numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

Commit
e981a0453b4e873149d9ae31dea1f8a91d8428c9
Parent
9c7392a31b59f45a69505c0415b9b7e6387205d5
Author
Pablo <pablo-pie@riseup.net>
Date

Finished transpiling the Fortran code to Python

Diffstat

3 files changed, 129 insertions, 52 deletions

Status File Name N° Changes Insertions Deletions
Renamed getrf.py -> getrf/getrf.py 0 0 0
Renamed test_getrf.py -> getrf/test.py 0 0 0
Modified linear-solvers.ipynb 181 129 52
diff --git a/getrf.py b/getrf/getrf.py
diff --git a/test_getrf.py b/getrf/test.py
diff --git a/linear-solvers.ipynb b/linear-solvers.ipynb
@@ -74,10 +74,10 @@
     "As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple simple recurssive algorithm using such method might look something like the following:\n",
     "\n",
     "1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
-    "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors.\n",
+    "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
     "3. Write\n",
     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
-    "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$. Since $A$ is invertible, $a_{i_0} \\ne 0$.\n",
+    "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$ — given $A$ is invertible. Since $S_{i_0}$ acts on $A$ by swaping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
     "4. We want to solve the equation\n",
     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
     "   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
@@ -87,9 +87,9 @@
     "       A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
     "   \\end{aligned}\n",
     "   $$\n",
-    "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$, so it remains to solve the bottom-right equation.\n",
-    "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular.\n",
-    "6. Take\n",
+    "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}'$, so it remains to solve the bottom-right equation.\n",
+    "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular. Computing $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is, of course, a $O(n^3)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$ is an $O(n^2)$ operation.\n",
+    "7. Take\n",
     "   $$\n",
     "   \\begin{aligned}\n",
     "       L &= \\begin{pmatrix}      1 & 0      \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
@@ -100,73 +100,150 @@
     "   $$\n",
     "   S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
     "   $$\n",
-    "7. Hence by taking\n",
+    "8. Hence by taking\n",
     "   $$\n",
     "   P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
     "   $$\n",
-    "   we get $A = P L U$ as desired!"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0e1e9c6b-5e65-4cf1-aa53-a7202abff788",
-   "metadata": {
-    "jp-MarkdownHeadingCollapsed": true
-   },
-   "source": [
-    "## Scrap"
+    "   we get $A = P L U$ as desired!\n",
+    "\n",
+    "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^3)$ operations, the total complexity of our facotization algorithm is $O(n^4)$."
    ]
   },
   {
    "cell_type": "markdown",
-   "id": "4dde0b0b-08cc-4bca-9dd5-c2bced871cef",
+   "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
    "metadata": {},
    "source": [
-    "As for the decomposition of $A$ in the first step, given some $m \\times n$ matrix $A$, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting) to write $A = P L U$, where $P$ is a $m \\times m$ permutation matrix, $L$ is a $m \\times n$ lower trapezoidal matrix with $1$ in the diagona entries and $U$ is a $n \\times n$ upper triangular matrix. Their algorithm goes something like the following:\n",
-    "\n",
-    "1. If $A = 0$ then take $P = 1$, $L$ the $m \\times n$ matrix with $1$ in the diagonals and zero elsewhere and $U = 0$.\n",
-    "2. If $m = 1$, so that $A = \\begin{pmatrix} a_1 & \\cdots & a_n \\end{pmatrix}$, *DIE*\n",
-    "3. If $n = 1$, so that\n",
-    "   $$ A = \\begin{pmatrix} a_1 \\\\ \\vdots \\\\ a_m \\end{pmatrix} $$\n",
-    "   choose $i_0$ that maximizes $|a_{i_0}|$. Since $A \\ne 0$, $a_{i_0} \\ne 0$. Return $P = \\sigma_{1, i_0}$, $L = a_{i_0}^{-1} P^{-1} A$ and $U = 1$.\n",
-    "4. Otherwise do the following series of steps.\n",
-    "   1. Take $n_1 = \\left\\lfloor \\frac{\\min \\{n, m\\}}{2} \\right \\rfloor$ and $n_2 = n - n1$.\n",
-    "   2. Write\n",
-    "      $$A = \\left( \\begin{array}{c|c} A_{11} & A_{12} \\\\ \\hline A_{21} & A_{22} \\end{array} \\right), $$\n",
-    "      where $A_{11}$ is $n_1 \\times n_2$ and $A_{22}$ is $n_2 \\times n_2$.\n",
-    "   3. Write\n",
-    "      $$\\begin{pmatrix} A_{11} \\\\ A_{21} \\end{pmatrix} = P_1 L_1 U_1 \\quad L_1 = \\begin{pmatrix} L_1' \\\\ L_1'' \\end{pmatrix} ,$$\n",
-    "      where $P_1$ is a $m \\times m$ permutation matrix, $L_1$ is a $m \\times n_1$ lower trapezoidal matrix, $L_1'$ is $n_1 \\times n_1$ and $U_1$ is a $n_1 \\times n_1$ upper triangular matrix.\n",
-    "   4. Write\n",
-    "      $$\\begin{pmatrix} A_{12}' \\\\ A_{22}' \\end{pmatrix} = P_1^{-1} \\begin{pmatrix} A_{12} \\\\ A_{22} \\end{pmatrix}$$\n",
-    "   5. Take\n",
-    "      $$A_{12}'' = L_1'^{-1} A_{12}' \\quad A_{22}'' = -L_1'1A_{12}'' + A_{22}'$$\n",
-    "   6. Write $A_{12}'' = P_2 L_2 U_2$, where $P_2$ is a $n_1 \\times n_1$ permutation matrix, $L_2$ is a $n_1 \\times n_2$ lower trapezoidal matrix and $U_2$ is a $n_2 \\times n_2$ upper triangular matrix.\n",
-    "   7. Take\n",
-    "      $$\n",
-    "      \\begin{aligned}\n",
-    "          P &= \\begin{pmatrix} P_1 & 0 \\\\ 0 & P_2 \\end{pmatrix} &\n",
-    "          L &= \\begin{pmatrix} L_1 & 0 \\\\ A_{21}'' & L_2 \\end{pmatrix} &\n",
-    "          U &= \\begin{pmatrix} U_1 & A_{12}'' \\\\ 0 & U_2 \\end{pmatrix},\n",
-    "      \\end{aligned}\n",
-    "      $$\n",
-    "      where $A_{21}'' = P_2^{-1} L_1'1$."
+    "The LAPACK algorithm for factorizing a $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
+    "$$\n",
+    "A =\n",
+    "\\left(\n",
+    "\\begin{array}{c|c}\n",
+    "    A_{11} & A_{12} \\\\ \\hline\n",
+    "    A_{21} & A_{22}\n",
+    "\\end{array}\n",
+    "\\right),\n",
+    "$$\n",
+    "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix.\n",
+    "This is implemented in the `GETRF` family of subroutines. A SciPy version of their precise algorithm might look something like the following:"
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "93106465-de78-418d-8c7b-9774b97f6b07",
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Althought similar to the naive method of computing $A^{-1}$ and then applying it to $b$, this algorithm has better performance characteristics since it only computes $A^{-1} b$ (without calculating the product $A^{-1} = U^{-1} L^{-1} P^{-1}$). The `la.solve` function also uses optimized linear solvers when $A$ is known to be symmetric or Hermitian.\n",
+    "import numpy as np\n",
+    "import scipy.linalg as la\n",
+    "\n",
+    "def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):\n",
+    "    \"\"\"Returns the P, L, U\n",
+    "\n",
+    "    * A is m by n\n",
+    "    * P is m by m\n",
+    "    * L is m by n if m >= n and m by m if m <= n\n",
+    "    * U is n by n if m >= n and m by n if m <= n\n",
+    "    \"\"\"\n",
+    "    m, n = A.shape\n",
+    "    \n",
+    "    # A is a row\n",
+    "    if m == 1:\n",
+    "        return np.eye(1), np.eye(1), A\n",
+    "\n",
+    "    # A is a column\n",
+    "    elif n == 1:\n",
+    "        i0 = 0\n",
+    "\n",
+    "        for i in range(m):\n",
+    "            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i\n",
+    "\n",
+    "        P = np.eye(m)\n",
+    "        P[0,0],  P[i0,i0] = 0, 0\n",
+    "        P[i0,0], P[0,i0]  = 1, 1\n",
+    "\n",
+    "        if A[i0, 0] != 0:\n",
+    "            L = P@A / A[i0, 0]\n",
+    "            U = A[i0, 0] * np.eye(1)\n",
+    "        else:\n",
+    "            L = A\n",
+    "            U = np.zeros((1, 1))\n",
+    "\n",
+    "        return P, L, U\n",
+    "    else:\n",
+    "        n1 = min(m, n)//2\n",
+    "        n2 = n - n1\n",
+    "\n",
+    "        # Write\n",
+    "        #\n",
+    "        #   A = [[A11, A12],\n",
+    "        #        [A21, A22]],\n",
+    "        #\n",
+    "        #   A1 = [[A11, \n",
+    "        #          A21]],\n",
+    "        #\n",
+    "        #   A2 = [[A12, \n",
+    "        #          A22]]\n",
+    "        #\n",
+    "        # where A11 is n1 by n1 and A22 is n2 by n2\n",
+    "        A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
+    "        A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
+    "        A1, A2   = A[:,:n1],   A[:,n1:]\n",
+    "\n",
+    "        # Solve the A1 block\n",
+    "        P1, L1, U11 = getrf(A1)\n",
+    "\n",
+    "        # Apply pivots\n",
+    "        A2 = la.inv(P1) @ A2\n",
+    "        A12, A22 = A2[:n1,:], A2[n1:,:]\n",
+    "        \n",
+    "        # Solve A12 \n",
+    "        L11, L21 = L1[:n1,:], L1[n1:,:]\n",
+    "        A12 = la.inv(L11) @ A12\n",
+    "\n",
+    "        # Update A22\n",
+    "        A22 = -L21@A12 + A22\n",
+    "        \n",
+    "        # Solve the A22 block\n",
+    "        P2, L22, U22 = getrf(A22)\n",
+    "\n",
+    "        # Take P = P1 @ P2_ for\n",
+    "        #\n",
+    "        # P2_ = [[1, 0,\n",
+    "        #         0, P2]]\n",
+    "        P2_ = np.eye(m)\n",
+    "        P2_[n1:,n1:] = P2\n",
+    "        P = P1 @ P2_\n",
+    "\n",
+    "        # Apply interchanges to L21\n",
+    "        L21 = la.inv(P2) @ L21\n",
+    "\n",
+    "        # Take\n",
+    "        # \n",
+    "        # L = [[L11, 0],\n",
+    "        #      [L21, L22]],\n",
+    "        #\n",
+    "        # U = [[U11, A12],\n",
+    "        #      [  0, U22]]\n",
+    "        if m >= n:\n",
+    "            L, U = np.zeros((m, n)), np.zeros((n, n))\n",
+    "        else:\n",
+    "            L, U = np.zeros((m, m)), np.zeros((m, n))\n",
+    "        L[:n1,:n1] = L11\n",
+    "        L[n1:,:n1] = L21\n",
+    "        L[n1:,n1:] = L22\n",
+    "        U[:n1,:n1] = U11\n",
+    "        U[:n1,n1:] = A12\n",
+    "        U[n1:,n1:] = U22\n",
     "\n",
-    "For more details please refer to the implementations of the `GETRS`, `SYSV` and `HESV` families of subroutines in LAPACK."
+    "        return P, L, U"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "6d77288f-56a8-4037-965d-56f2c443a73c",
+   "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
    "metadata": {},
    "outputs": [],
    "source": []