numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

Commit
e981a0453b4e873149d9ae31dea1f8a91d8428c9
Parent
9c7392a31b59f45a69505c0415b9b7e6387205d5
Author
Pablo <pablo-pie@riseup.net>
Date

Finished transpiling the Fortran code to Python

Diffstats

5 files changed, 313 insertions, 236 deletions

Status Name Changes Insertions Deletions
Deleted getrf.py 1 file changed 0 106
Added getrf/getrf.py 1 file changed 106 0
Added getrf/test.py 1 file changed 78 0
Modified linear-solvers.ipynb 2 files changed 129 52
Deleted test_getrf.py 1 file changed 0 78
diff --git a/getrf.py /dev/null
@@ -1,106 +0,0 @@
-# "SciPy-transpiled" version of LAPACK's GETRF family of subroutines!
-import numpy as np
-import scipy.linalg as la
-
-def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):
-    """Returns the P, L, U
-
-    * A is m by n
-    * P is m by m
-    * L is m by n if m >= n and m by m if m <= n
-    * U is n by n if m >= n and m by n if m <= n
-    """
-    m, n = A.shape
-    
-    # A is a row
-    if m == 1:
-        return np.eye(1), np.eye(1), A
-
-    # A is a column
-    elif n == 1:
-        i0 = 0
-
-        for i in range(m):
-            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i
-
-        # P permutes the 0-th and i0-th basis vectors
-        P = np.eye(m)
-        P[0,0],  P[i0,i0] = 0, 0
-        P[i0,0], P[0,i0]  = 1, 1
-
-        if A[i0, 0] != 0:
-            L = P@A / A[i0, 0]
-            U = A[i0, 0] * np.eye(1)
-        else:
-            L = A
-            U = np.zeros((1, 1))
-
-        return P, L, U
-    else:
-        n1 = min(m, n)//2
-        n2 = n - n1
-
-        # Write
-        #
-        #   A = [[A11, A12],
-        #        [A21, A22]],
-        #
-        #   A1 = [[A11, 
-        #          A21]],
-        #
-        #   A2 = [[A12, 
-        #          A22]]
-        #
-        # where A11 is n1 by n1 and A22 is n2 by n2
-        A11, A12 = A[:n1,:n1], A[:n1,n1:]
-        A21, A22 = A[n1:,:n1], A[n1:,n1:]
-        A1, A2   = A[:,:n1],   A[:,n1:]
-
-        # Solve the A1 block
-        P1, L1, U11 = getrf(A1)
-
-        # Apply pivots
-        # A2 is m by n2
-        A2 = la.inv(P1) @ A2
-        A12, A22 = A2[:n1,:], A2[n1:,:]
-        
-        # Solve A12 
-        L11, L21 = L1[:n1,:], L1[n1:,:]
-        A12 = la.inv(L11) @ A12
-
-        # Update A22
-        A22 = -L21@A12 + A22
-        
-        # Solve the A22 block
-        P2, L22, U22 = getrf(A22)
-
-        # Take P = P1 @ P2_ for
-        #
-        # P2_ = [[1, 0,
-        #         0, P2]]
-        P2_ = np.eye(m)
-        P2_[n1:,n1:] = P2
-        P = P1 @ P2_
-
-        # Apply interchanges to L21
-        L21 = la.inv(P2) @ L21
-
-        # Take
-        # 
-        # L = [[L11, 0],
-        #      [L21, L22]],
-        #
-        # U = [[U11, A12],
-        #      [0, U22]]
-        if m >= n:
-            L, U = np.zeros((m, n)), np.zeros((n, n))
-        else:
-            L, U = np.zeros((m, m)), np.zeros((m, n))
-        L[:n1,:n1] = L11
-        L[n1:,:n1] = L21
-        L[n1:,n1:] = L22
-        U[:n1,:n1] = U11
-        U[:n1,n1:] = A12
-        U[n1:,n1:] = U22
-
-        return P, L, U
diff --git /dev/null b/getrf/getrf.py
@@ -0,0 +1,106 @@
+# "SciPy-transpiled" version of LAPACK's GETRF family of subroutines!
+import numpy as np
+import scipy.linalg as la
+
+def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):
+    """Returns the P, L, U
+
+    * A is m by n
+    * P is m by m
+    * L is m by n if m >= n and m by m if m <= n
+    * U is n by n if m >= n and m by n if m <= n
+    """
+    m, n = A.shape
+    
+    # A is a row
+    if m == 1:
+        return np.eye(1), np.eye(1), A
+
+    # A is a column
+    elif n == 1:
+        i0 = 0
+
+        for i in range(m):
+            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i
+
+        # P permutes the 0-th and i0-th basis vectors
+        P = np.eye(m)
+        P[0,0],  P[i0,i0] = 0, 0
+        P[i0,0], P[0,i0]  = 1, 1
+
+        if A[i0, 0] != 0:
+            L = P@A / A[i0, 0]
+            U = A[i0, 0] * np.eye(1)
+        else:
+            L = A
+            U = np.zeros((1, 1))
+
+        return P, L, U
+    else:
+        n1 = min(m, n)//2
+        n2 = n - n1
+
+        # Write
+        #
+        #   A = [[A11, A12],
+        #        [A21, A22]],
+        #
+        #   A1 = [[A11, 
+        #          A21]],
+        #
+        #   A2 = [[A12, 
+        #          A22]]
+        #
+        # where A11 is n1 by n1 and A22 is n2 by n2
+        A11, A12 = A[:n1,:n1], A[:n1,n1:]
+        A21, A22 = A[n1:,:n1], A[n1:,n1:]
+        A1, A2   = A[:,:n1],   A[:,n1:]
+
+        # Solve the A1 block
+        P1, L1, U11 = getrf(A1)
+
+        # Apply pivots
+        # A2 is m by n2
+        A2 = la.inv(P1) @ A2
+        A12, A22 = A2[:n1,:], A2[n1:,:]
+        
+        # Solve A12 
+        L11, L21 = L1[:n1,:], L1[n1:,:]
+        A12 = la.inv(L11) @ A12
+
+        # Update A22
+        A22 = -L21@A12 + A22
+        
+        # Solve the A22 block
+        P2, L22, U22 = getrf(A22)
+
+        # Take P = P1 @ P2_ for
+        #
+        # P2_ = [[1, 0,
+        #         0, P2]]
+        P2_ = np.eye(m)
+        P2_[n1:,n1:] = P2
+        P = P1 @ P2_
+
+        # Apply interchanges to L21
+        L21 = la.inv(P2) @ L21
+
+        # Take
+        # 
+        # L = [[L11, 0],
+        #      [L21, L22]],
+        #
+        # U = [[U11, A12],
+        #      [0, U22]]
+        if m >= n:
+            L, U = np.zeros((m, n)), np.zeros((n, n))
+        else:
+            L, U = np.zeros((m, m)), np.zeros((m, n))
+        L[:n1,:n1] = L11
+        L[n1:,:n1] = L21
+        L[n1:,n1:] = L22
+        U[:n1,:n1] = U11
+        U[:n1,n1:] = A12
+        U[n1:,n1:] = U22
+
+        return P, L, U
diff --git /dev/null b/getrf/test.py
@@ -0,0 +1,78 @@
+import numpy as np
+import scipy.linalg as la
+from getrf import getrf
+import json
+
+N_TESTS = 1000
+EPS = 1e-12
+
+def matrix_eq(A: np.array, B: np.array) -> bool:
+    return A.shape == B.shape and np.all(np.abs(A - B) < EPS)
+
+def test(M: int, N: int, n_tests = N_TESTS):
+    print(f"Comparing getrf and la.lu (running {n_tests} tests)")
+    print("=" * 20)
+
+    fails = []
+
+    for i in range(n_tests):
+        m = np.random.randint(1, M+1)
+        n = np.random.randint(1, N+1)
+
+        A = np.random.rand(m, n)
+        P, L, U = la.lu(A)
+
+        print(f"{i:05d}: Testing a {m:03d}x{n:03d} matrix: ", end="")
+
+        try:
+            my_P, my_L, my_U = getrf(A)
+        except Exception as e:
+            d = {"A": A.tolist(), "runtime error": str(e)}
+            fails.append(d)
+
+            print(f"\033[91mfailed!\033[0m")
+            print("  > getrf raised an error!")
+            continue
+
+        P_test = matrix_eq(P, my_P)
+        L_test = matrix_eq(L, my_L)
+        U_test = matrix_eq(U, my_U)
+
+        if (not P_test) or (not L_test) or (not U_test):
+            d = {"A": A.tolist(),}
+
+            if not P_test:
+                d["expected P"] = P.tolist()
+                d["actual P"]   = my_P.tolist()
+
+            if not L_test:
+                d["expected L"] = L.tolist()
+                d["actual L"]   = my_L.tolist()
+
+            if not U_test:
+                d["expected U"] = U.tolist()
+                d["actual U"]   = my_U.tolist()
+
+            fails.append(d)
+
+            print(f"\033[91mfailed!\033[0m")
+            print("  > getrf returned wrong values")
+            continue
+
+        print(f"\033[32mpassed!\033[0m")
+
+    print("\n" + "="  * 20)
+    if len(fails) == 0:
+        print(f"All {n_tests} tests passed!")
+    else:
+        testlogs = "./testlogs.json"
+        with open(testlogs, "w") as f:
+            json.dump(fails, f)
+
+        print(f"{n_tests-len(fails)} tests passed ({len(fails)} fails)")
+        print(f"Check {testlogs} for the failing inputs")
+
+m, n = 200, 200
+
+test(m, n, 10000)
+
diff --git a/linear-solvers.ipynb b/linear-solvers.ipynb
@@ -74,10 +74,10 @@
     "As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple recursive algorithm using such a method might look something like the following:\n",
     "\n",
     "1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
-    "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors.\n",
+    "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
     "3. Write\n",
     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
-    "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$. Since $A$ is invertible, $a_{i_0} \\ne 0$.\n",
+    "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$, since $A$ is invertible. Since $S_{i_0}$ acts on $A$ by swapping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
     "4. We want to solve the equation\n",
     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
     "   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
@@ -87,9 +87,9 @@
     "       A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
     "   \\end{aligned}\n",
     "   $$\n",
-    "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$, so it remains to solve the bottom-right equation.\n",
-    "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular.\n",
-    "6. Take\n",
+    "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$, so it remains to solve the bottom-right equation.\n",
+    "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular. Computing $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is, of course, an $O(n^3)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$ is an $O(n^2)$ operation.\n",
+    "7. Take\n",
     "   $$\n",
     "   \\begin{aligned}\n",
     "       L &= \\begin{pmatrix}      1 & 0      \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
@@ -100,73 +100,150 @@
     "   $$\n",
     "   S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
     "   $$\n",
-    "7. Hence by taking\n",
+    "8. Hence by taking\n",
     "   $$\n",
     "   P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
     "   $$\n",
-    "   we get $A = P L U$ as desired!"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0e1e9c6b-5e65-4cf1-aa53-a7202abff788",
-   "metadata": {
-    "jp-MarkdownHeadingCollapsed": true
-   },
-   "source": [
-    "## Scrap"
+    "   we get $A = P L U$ as desired!\n",
+    "\n",
+    "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^3)$ operations, the total complexity of our factorization algorithm is $O(n^4)$."
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "4dde0b0b-08cc-4bca-9dd5-c2bced871cef",
+   "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
    "metadata": {},
    "source": [
-    "As for the decomposition of $A$ in the first step, given some $m \\times n$ matrix $A$, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting) to write $A = P L U$, where $P$ is a $m \\times m$ permutation matrix, $L$ is a $m \\times n$ lower trapezoidal matrix with $1$ in the diagona entries and $U$ is a $n \\times n$ upper triangular matrix. Their algorithm goes something like the following:\n",
-    "\n",
-    "1. If $A = 0$ then take $P = 1$, $L$ the $m \\times n$ matrix with $1$ in the diagonals and zero elsewhere and $U = 0$.\n",
-    "2. If $m = 1$, so that $A = \\begin{pmatrix} a_1 & \\cdots & a_n \\end{pmatrix}$, *DIE*\n",
-    "3. If $n = 1$, so that\n",
-    "   $$ A = \\begin{pmatrix} a_1 \\\\ \\vdots \\\\ a_m \\end{pmatrix} $$\n",
-    "   choose $i_0$ that maximizes $|a_{i_0}|$. Since $A \\ne 0$, $a_{i_0} \\ne 0$. Return $P = \\sigma_{1, i_0}$, $L = a_{i_0}^{-1} P^{-1} A$ and $U = 1$.\n",
-    "4. Otherwise do the following series of steps.\n",
-    "   1. Take $n_1 = \\left\\lfloor \\frac{\\min \\{n, m\\}}{2} \\right \\rfloor$ and $n_2 = n - n1$.\n",
-    "   2. Write\n",
-    "      $$A = \\left( \\begin{array}{c|c} A_{11} & A_{12} \\\\ \\hline A_{21} & A_{22} \\end{array} \\right), $$\n",
-    "      where $A_{11}$ is $n_1 \\times n_2$ and $A_{22}$ is $n_2 \\times n_2$.\n",
-    "   3. Write\n",
-    "      $$\\begin{pmatrix} A_{11} \\\\ A_{21} \\end{pmatrix} = P_1 L_1 U_1 \\quad L_1 = \\begin{pmatrix} L_1' \\\\ L_1'' \\end{pmatrix} ,$$\n",
-    "      where $P_1$ is a $m \\times m$ permutation matrix, $L_1$ is a $m \\times n_1$ lower trapezoidal matrix, $L_1'$ is $n_1 \\times n_1$ and $U_1$ is a $n_1 \\times n_1$ upper triangular matrix.\n",
-    "   4. Write\n",
-    "      $$\\begin{pmatrix} A_{12}' \\\\ A_{22}' \\end{pmatrix} = P_1^{-1} \\begin{pmatrix} A_{12} \\\\ A_{22} \\end{pmatrix}$$\n",
-    "   5. Take\n",
-    "      $$A_{12}'' = L_1'^{-1} A_{12}' \\quad A_{22}'' = -L_1'1A_{12}'' + A_{22}'$$\n",
-    "   6. Write $A_{12}'' = P_2 L_2 U_2$, where $P_2$ is a $n_1 \\times n_1$ permutation matrix, $L_2$ is a $n_1 \\times n_2$ lower trapezoidal matrix and $U_2$ is a $n_2 \\times n_2$ upper triangular matrix.\n",
-    "   7. Take\n",
-    "      $$\n",
-    "      \\begin{aligned}\n",
-    "          P &= \\begin{pmatrix} P_1 & 0 \\\\ 0 & P_2 \\end{pmatrix} &\n",
-    "          L &= \\begin{pmatrix} L_1 & 0 \\\\ A_{21}'' & L_2 \\end{pmatrix} &\n",
-    "          U &= \\begin{pmatrix} U_1 & A_{12}'' \\\\ 0 & U_2 \\end{pmatrix},\n",
-    "      \\end{aligned}\n",
-    "      $$\n",
-    "      where $A_{21}'' = P_2^{-1} L_1'1$."
+    "The LAPACK algorithm for factorizing an $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
+    "$$\n",
+    "A =\n",
+    "\\left(\n",
+    "\\begin{array}{c|c}\n",
+    "    A_{11} & A_{12} \\\\ \\hline\n",
+    "    A_{21} & A_{22}\n",
+    "\\end{array}\n",
+    "\\right),\n",
+    "$$\n",
+    "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix.\n",
+    "This is implemented in the `GETRF` family of subroutines. A SciPy version of their precise algorithm might look something like the following:"
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "93106465-de78-418d-8c7b-9774b97f6b07",
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Althought similar to the naive method of computing $A^{-1}$ and then applying it to $b$, this algorithm has better performance characteristics since it only computes $A^{-1} b$ (without calculating the product $A^{-1} = U^{-1} L^{-1} P^{-1}$). The `la.solve` function also uses optimized linear solvers when $A$ is known to be symmetric or Hermitian.\n",
+    "import numpy as np\n",
+    "import scipy.linalg as la\n",
+    "\n",
+    "def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):\n",
+    "    \"\"\"Returns the P, L, U\n",
+    "\n",
+    "    * A is m by n\n",
+    "    * P is m by m\n",
+    "    * L is m by n if m >= n and m by m if m <= n\n",
+    "    * U is n by n if m >= n and m by n if m <= n\n",
+    "    \"\"\"\n",
+    "    m, n = A.shape\n",
+    "    \n",
+    "    # A is a row\n",
+    "    if m == 1:\n",
+    "        return np.eye(1), np.eye(1), A\n",
+    "\n",
+    "    # A is a column\n",
+    "    elif n == 1:\n",
+    "        i0 = 0\n",
+    "\n",
+    "        for i in range(m):\n",
+    "            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i\n",
+    "\n",
+    "        P = np.eye(m)\n",
+    "        P[0,0],  P[i0,i0] = 0, 0\n",
+    "        P[i0,0], P[0,i0]  = 1, 1\n",
+    "\n",
+    "        if A[i0, 0] != 0:\n",
+    "            L = P@A / A[i0, 0]\n",
+    "            U = A[i0, 0] * np.eye(1)\n",
+    "        else:\n",
+    "            L = A\n",
+    "            U = np.zeros((1, 1))\n",
+    "\n",
+    "        return P, L, U\n",
+    "    else:\n",
+    "        n1 = min(m, n)//2\n",
+    "        n2 = n - n1\n",
+    "\n",
+    "        # Write\n",
+    "        #\n",
+    "        #   A = [[A11, A12],\n",
+    "        #        [A21, A22]],\n",
+    "        #\n",
+    "        #   A1 = [[A11, \n",
+    "        #          A21]],\n",
+    "        #\n",
+    "        #   A2 = [[A12, \n",
+    "        #          A22]]\n",
+    "        #\n",
+    "        # where A11 is n1 by n1 and A22 is n2 by n2\n",
+    "        A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
+    "        A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
+    "        A1, A2   = A[:,:n1],   A[:,n1:]\n",
+    "\n",
+    "        # Solve the A1 block\n",
+    "        P1, L1, U11 = getrf(A1)\n",
+    "\n",
+    "        # Apply pivots\n",
+    "        A2 = la.inv(P1) @ A2\n",
+    "        A12, A22 = A2[:n1,:], A2[n1:,:]\n",
+    "        \n",
+    "        # Solve A12 \n",
+    "        L11, L21 = L1[:n1,:], L1[n1:,:]\n",
+    "        A12 = la.inv(L11) @ A12\n",
+    "\n",
+    "        # Update A22\n",
+    "        A22 = -L21@A12 + A22\n",
+    "        \n",
+    "        # Solve the A22 block\n",
+    "        P2, L22, U22 = getrf(A22)\n",
+    "\n",
+    "        # Take P = P1 @ P2_ for\n",
+    "        #\n",
+    "        # P2_ = [[1, 0,\n",
+    "        #         0, P2]]\n",
+    "        P2_ = np.eye(m)\n",
+    "        P2_[n1:,n1:] = P2\n",
+    "        P = P1 @ P2_\n",
+    "\n",
+    "        # Apply interchanges to L21\n",
+    "        L21 = la.inv(P2) @ L21\n",
+    "\n",
+    "        # Take\n",
+    "        # \n",
+    "        # L = [[L11, 0],\n",
+    "        #      [L21, L22]],\n",
+    "        #\n",
+    "        # U = [[U11, A12],\n",
+    "        #      [  0, U22]]\n",
+    "        if m >= n:\n",
+    "            L, U = np.zeros((m, n)), np.zeros((n, n))\n",
+    "        else:\n",
+    "            L, U = np.zeros((m, m)), np.zeros((m, n))\n",
+    "        L[:n1,:n1] = L11\n",
+    "        L[n1:,:n1] = L21\n",
+    "        L[n1:,n1:] = L22\n",
+    "        U[:n1,:n1] = U11\n",
+    "        U[:n1,n1:] = A12\n",
+    "        U[n1:,n1:] = U22\n",
     "\n",
-    "For more details please refer to the implementations of the `GETRS`, `SYSV` and `HESV` families of subroutines in LAPACK."
+    "        return P, L, U"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6d77288f-56a8-4037-965d-56f2c443a73c",
+   "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
    "metadata": {},
    "outputs": [],
    "source": []
diff --git a/test_getrf.py /dev/null
@@ -1,78 +0,0 @@
-import numpy as np
-import scipy.linalg as la
-from getrf import getrf
-import json
-
-N_TESTS = 1000
-EPS = 1e-12
-
-def matrix_eq(A: np.array, B: np.array) -> bool:
-    return A.shape == B.shape and np.all(np.abs(A - B) < EPS)
-
-def test(M: int, N: int, n_tests = N_TESTS):
-    print(f"Comparing getrf and la.lu (running {n_tests} tests)")
-    print("=" * 20)
-
-    fails = []
-
-    for i in range(n_tests):
-        m = np.random.randint(1, M+1)
-        n = np.random.randint(1, N+1)
-
-        A = np.random.rand(m, n)
-        P, L, U = la.lu(A)
-
-        print(f"{i:05d}: Testing a {m:03d}x{n:03d} matrix: ", end="")
-
-        try:
-            my_P, my_L, my_U = getrf(A)
-        except Exception as e:
-            d = {"A": A.tolist(), "runtime error": str(e)}
-            fails.append(d)
-
-            print(f"\033[91mfailed!\033[0m")
-            print("  > getrf raised an error!")
-            continue
-
-        P_test = matrix_eq(P, my_P)
-        L_test = matrix_eq(L, my_L)
-        U_test = matrix_eq(U, my_U)
-
-        if (not P_test) or (not L_test) or (not U_test):
-            d = {"A": A.tolist(),}
-
-            if not P_test:
-                d["expected P"] = P.tolist()
-                d["actual P"]   = my_P.tolist()
-
-            if not L_test:
-                d["expected L"] = L.tolist()
-                d["actual L"]   = my_L.tolist()
-
-            if not U_test:
-                d["expected U"] = U.tolist()
-                d["actual U"]   = my_U.tolist()
-
-            fails.append(d)
-
-            print(f"\033[91mfailed!\033[0m")
-            print("  > getrf returned wrong values")
-            continue
-
-        print(f"\033[32mpassed!\033[0m")
-
-    print("\n" + "="  * 20)
-    if len(fails) == 0:
-        print(f"All {n_tests} tests passed!")
-    else:
-        testlogs = "./testlogs.json"
-        with open(testlogs, "w") as f:
-            json.dump(fails, f)
-
-        print(f"{n_tests-len(fails)} tests passed ({len(fails)} fails)")
-        print(f"Check {testlogs} for the failing inputs")
-
-m, n = 200, 200
-
-test(m, n, 10000)
-