numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

test.py (2036B)

import json
import traceback

import numpy as np
import scipy.linalg as la

from getrf import getrf
 5 
 6 N_TESTS = 1000
 7 EPS = 1e-12
 8 
 9 def matrix_eq(A: np.array, B: np.array) -> bool:
10     return A.shape == B.shape and np.all(np.abs(A - B) < EPS)
11 
def _compare_with_lu(A: np.ndarray):
    """Factor ``A`` with both ``la.lu`` and ``getrf`` and compare the results.

    Returns ``None`` when P, L and U all agree (per ``matrix_eq``).
    Otherwise returns a ``(record, reason)`` pair:

    - ``record`` is the JSON-serialisable failure entry written to the test
      log. It always holds the input under ``"A"``.
    - ``reason`` is the one-line explanation printed to the console.
    """
    expected = dict(zip("PLU", la.lu(A)))

    try:
        my_P, my_L, my_U = getrf(A)
        # Coerce inside the try. If getrf hands back something that is not
        # array-like, report it as a failure of this case instead of letting
        # matrix_eq / .tolist() crash the whole run further down.
        actual = {
            "P": np.asarray(my_P),
            "L": np.asarray(my_L),
            "U": np.asarray(my_U),
        }
    except Exception as e:
        # getrf is the code under test, so any exception it raises is a
        # recorded failure rather than a harness error. Keep the traceback:
        # str(e) alone loses the exception type and the failing line.
        record = {
            "A": A.tolist(),
            "runtime error": str(e),
            "traceback": traceback.format_exc(),
        }
        return record, "getrf raised an error!"

    record = {"A": A.tolist()}
    for name in "PLU":
        if not matrix_eq(expected[name], actual[name]):
            record[f"expected {name}"] = expected[name].tolist()
            record[f"actual {name}"] = actual[name].tolist()

    if len(record) == 1:  # only the input was recorded: every factor matched
        return None
    return record, "getrf returned wrong values"


def test(M: int, N: int, n_tests: int = N_TESTS):
    """Compare ``getrf`` with ``scipy.linalg.lu`` on random matrices.

    Each case works as follows:

    - Draw a shape ``m x n`` with ``1 <= m <= M`` and ``1 <= n <= N``, and
      fill it with ``np.random.rand``.
    - Check that getrf's P, L and U match scipy's within ``EPS``, using
      ``matrix_eq``.

    ``n_tests`` such cases are run. A pass/fail line is printed per case.
    If any case fails, the failing inputs are dumped to
    ``./testlogs.json``, together with the mismatching factors or the
    raised error and its traceback.

    Args:
        M: maximum number of rows of the random test matrices.
        N: maximum number of columns of the random test matrices.
        n_tests: number of random cases to run.

    Raises:
        ValueError: if ``M`` or ``N`` is smaller than 1 (no valid shape).
    """
    if M < 1 or N < 1:
        raise ValueError(f"M and N must be at least 1 (got M={M}, N={N})")

    print(f"Comparing getrf and la.lu (running {n_tests} tests)")
    print("=" * 20)

    fails = []

    for i in range(n_tests):
        m = np.random.randint(1, M + 1)
        n = np.random.randint(1, N + 1)
        A = np.random.rand(m, n)

        print(f"{i:05d}: Testing a {m:03d}x{n:03d} matrix: ", end="")

        outcome = _compare_with_lu(A)
        if outcome is None:
            print("\033[32mpassed!\033[0m")
            continue

        record, reason = outcome
        fails.append(record)
        print("\033[91mfailed!\033[0m")
        print(f"  > {reason}")

    print("\n" + "=" * 20)
    if not fails:
        print(f"All {n_tests} tests passed!")
        return

    testlogs = "./testlogs.json"
    with open(testlogs, "w") as f:
        json.dump(fails, f)

    print(f"{n_tests - len(fails)} tests passed ({len(fails)} fails)")
    print(f"Check {testlogs} for the failing inputs")
74 
if __name__ == "__main__":
    # Run the comparison only when executed as a script, so importing this
    # module (e.g. to reuse matrix_eq or test) does not start a 10000-case run.
    m, n = 200, 200
    test(m, n, 10000)
78