numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

NameSizeMode
..
getrf/test.py 2036B -rw-r--r--
01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import scipy.linalg as la
from getrf import getrf
import json

N_TESTS = 1000  # default number of random trials run by test()
EPS = 1e-12  # default absolute tolerance used by matrix_eq()

def matrix_eq(A: np.ndarray, B: np.ndarray, eps: float = EPS) -> bool:
    """Return True if A and B have the same shape and agree entry-wise within eps.

    The comparison is absolute: every entry must satisfy ``|A - B| < eps``.
    A NaN in either operand makes the result False.  Both inputs go through
    ``np.asarray``, so array-likes such as nested lists are accepted too.

    Args:
        A: First matrix (or array-like).
        B: Second matrix (or array-like).
        eps: Absolute tolerance; defaults to the module-level ``EPS``.

    Returns:
        A plain Python ``bool`` (never ``np.bool_``).
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Shape check comes first and ``and`` short-circuits, so mismatched
    # shapes never reach the subtraction, where NumPy broadcasting could make
    # e.g. a (1, n) and an (n,) array compare as equal.
    # NOTE(review): a purely absolute tolerance may be tight for large or
    # badly scaled matrices; pass a larger eps if spurious failures appear.
    return A.shape == B.shape and bool(np.all(np.abs(A - B) < eps))

def test(M: int, N: int, n_tests=N_TESTS, logfile="./testlogs.json"):
    """Compare ``getrf`` against ``la.lu`` on randomly shaped random matrices.

    Each trial draws ``m`` in ``[1, M]`` and ``n`` in ``[1, N]``, builds an
    ``m x n`` matrix with ``np.random.rand``, and checks that the ``(P, L, U)``
    triple returned by ``getrf`` matches the one from ``la.lu`` entry-wise
    (via ``matrix_eq``).  A coloured pass/fail line is printed per trial and
    a summary is printed at the end.

    Args:
        M: Largest row count to draw.
        N: Largest column count to draw.
        n_tests: Number of random trials to run.
        logfile: Path the failure records are dumped to (as JSON) when at
            least one trial fails.

    Returns:
        The list of failure records (empty when every trial passes).  Each
        record holds the input under ``"A"`` plus either ``"runtime error"``
        or an ``"expected X"`` / ``"actual X"`` pair for every mismatching
        factor X in P, L, U.
    """
    print(f"Comparing getrf and la.lu (running {n_tests} tests)")
    print("=" * 20)

    fails = []

    for i in range(n_tests):
        m = np.random.randint(1, M + 1)
        n = np.random.randint(1, N + 1)

        A = np.random.rand(m, n)
        expected = la.lu(A)  # reference (P, L, U) triple from SciPy

        print(f"{i:05d}: Testing a {m:03d}x{n:03d} matrix: ", end="")

        try:
            # Hand getrf a copy: if it factors in place (as LAPACK's getrf
            # does), the A recorded in the failure log must stay intact.
            # asarray makes the matrix_eq/.tolist() calls below safe even if
            # getrf returns array-likes instead of ndarrays.
            my_P, my_L, my_U = (np.asarray(X) for X in getrf(A.copy()))
        except Exception as e:  # harness boundary: record the failure, keep going
            error = f"{type(e).__name__}: {e}"
            fails.append({"A": A.tolist(), "runtime error": error})

            print("\033[91mfailed!\033[0m")
            print(f"  > getrf raised an error! ({error})")
            continue

        record = {"A": A.tolist()}
        for name, exp, act in zip("PLU", expected, (my_P, my_L, my_U)):
            if not matrix_eq(exp, act):
                record[f"expected {name}"] = exp.tolist()
                record[f"actual {name}"] = act.tolist()

        if len(record) > 1:  # at least one factor mismatched
            fails.append(record)

            print("\033[91mfailed!\033[0m")
            print("  > getrf returned wrong values")
            continue

        print("\033[32mpassed!\033[0m")

    print("\n" + "=" * 20)
    if not fails:
        print(f"All {n_tests} tests passed!")
    else:
        with open(logfile, "w", encoding="utf-8") as f:
            json.dump(fails, f)

        print(f"{n_tests - len(fails)} tests passed ({len(fails)} fails)")
        print(f"Check {logfile} for the failing inputs")

    return fails

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse matrix_eq or
    # test) does not kick off 10,000 random trials as a side effect.
    m, n = 200, 200  # upper bounds on the random matrix dimensions
    test(m, n, 10000)