numeric-linalg
Educational material on the SciPy implementation of numerical linear algebra algorithms
linear-solvers.ipynb (10886B)
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1791697-e733-4721-9e10-fcdcfbda9064",
   "metadata": {},
   "source": [
    "# Solving Linear Systems in SciPy\n",
    "\n",
    "A linear system is a system of equations of the form\n",
    "$$ \\left\\{ \\begin{aligned} a_{1 1} x_1 + a_{1 2} x_2 + \\cdots + a_{1 n} x_n &= b_1 \\\\ a_{2 1} x_1 + a_{2 2} x_2 + \\cdots + a_{2 n} x_n &= b_2 \\\\ & \\vdots \\\\ a_{n 1} x_1 + a_{n 2} x_2 + \\cdots + a_{n n} x_n &= b_n\\end{aligned} \\right.$$\n",
    "in the variables $x_1, \\ldots, x_n$.\n",
    "\n",
    "Solving this system is equivalent to solving the equation $A x = b$ for $x$, where $A = (a_{ij})_{ij}$ is an $n \\times n$ matrix and $x = (x_1, \\ldots, x_n)$ and $b = (b_1, \\ldots, b_n)$ are vectors. This can always be done, and the solution is unique, provided $A$ is invertible."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26bc6a30-7853-4c78-adfb-320f0a65dd10",
   "metadata": {},
   "source": [
    "In SciPy, we can solve linear systems using the `la.solve` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b1ced47f-6783-48ed-a4e4-4f1f4e2b8835",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-0.12672176],\n",
       "       [ 0.1046832 ],\n",
       "       [ 1.19008264]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import scipy.linalg as la\n",
    "\n",
    "# Coefficient matrix and right-hand side (as a column vector)\n",
    "A = np.array([[6, 15, 1], [8, 7, 12], [2, 7, 8]])\n",
    "b = np.array([[2], [14], [10]])\n",
    "la.solve(A, b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfdacdbf-150f-4b4d-8072-ee18758d3b60",
   "metadata": {},
   "source": [
    "## But how does `la.solve` work?\n",
    "\n",
    "Internally, the `la.solve` function uses the [LAPACK library](https://netlib.org/lapack), a Fortran package for numerical linear algebra. The LAPACK generic linear solver algorithm goes something like the following:\n",
    "\n",
    "1. Decompose $A$ as $A = P L U$, where $P$ is a permutation matrix, $L$ is a lower triangular matrix with $1$s on the diagonal and $U$ is an upper triangular matrix.\n",
    "2. Solve $P b' = b$ for $b'$, i.e. compute $b' = P^{-1} b$. Since $P$ is a permutation matrix, this operation is $O(n)$.\n",
    "3. Solve $L b'' = b'$ for $b''$, i.e. compute $b'' = L^{-1} P^{-1} b$. Since $L$ is known to be lower triangular, this operation is $O(n^2)$.\n",
    "4. Solve $U x = b''$ for $x$, i.e. compute $x = U^{-1} L^{-1} P^{-1} b = A^{-1} b$. Since $U$ is known to be upper triangular, this operation is $O(n^2)$.\n",
    "\n",
    "Steps 2 to 4 are implemented in the `GETRS` family of subroutines."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7ad45f-ea81-4d46-9097-735cc159cf1a",
   "metadata": {},
   "source": [
    "As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple recursive algorithm using this method might look something like the following:\n",
    "\n",
    "1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
    "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
    "3. Write\n",
    "   $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
    "   where $a_{i_0} = a_{i_0, 1}$ is the pivot, $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$ (given that $A$ is invertible). Since $S_{i_0}$ acts on $A$ by swapping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
    "4. We want to solve the equation\n",
    "   $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
    "   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$s on the diagonal and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
    "   $$\n",
    "   \\begin{aligned}\n",
    "   a_{i_0} &= u_{11} & A_{12}' &= U_{12} \\\\\n",
    "   A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
    "   \\end{aligned}\n",
    "   $$\n",
    "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$, so it remains to solve the bottom-right equation.\n",
    "5. Since $P_{22} L_{21} U_{12} = a_{i_0}^{-1} A_{21}' A_{12}'$, the bottom-right equation becomes $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}' = P_{22} L_{22} U_{22}$. Solve it recursively, obtaining a permutation matrix $P_{22}$, a lower triangular matrix $L_{22}$ with $1$s on the diagonal and an upper triangular matrix $U_{22}$. Since $A_{21}'$ is a column and $A_{12}'$ is a row, computing the outer product $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is an $O(n^2)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$ is an $O(n)$ operation.\n",
    "6. Take\n",
    "   $$\n",
    "   \\begin{aligned}\n",
    "   L &= \\begin{pmatrix} 1 & 0 \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
    "   U &= \\begin{pmatrix} u_{11} & U_{12} \\\\ 0 & U_{22} \\end{pmatrix}\n",
    "   \\end{aligned}\n",
    "   $$\n",
    "   for $L_{21}, L_{22}, u_{11}, U_{12}, U_{22}$ as above, so that\n",
    "   $$\n",
    "   S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
    "   $$\n",
    "7. Hence, since $S_{i_0}^{-1} = S_{i_0}$, by taking\n",
    "   $$\n",
    "   P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
    "   $$\n",
    "   we get $A = P L U$ as desired!\n",
    "\n",
    "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^2)$ operations, the total complexity of our factorization algorithm is $O(n^3)$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
   "metadata": {},
   "source": [
    "The LAPACK algorithm for factorizing an $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
    "$$\n",
    "A =\n",
    "\\left(\n",
    "\\begin{array}{c|c}\n",
    "  A_{11} & A_{12} \\\\ \\hline\n",
    "  A_{21} & A_{22}\n",
    "\\end{array}\n",
    "\\right),\n",
    "$$\n",
    "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix, and recursing on the left block column and on the updated $A_{22}$ block.\n",
    "This is implemented in the `GETRF` family of subroutines (in the reference implementation, through the recursive `GETRF2` routines). A SciPy version of this algorithm might look something like the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.linalg as la\n",
    "\n",
    "def getrf(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"Returns P, L, U such that A = P @ L @ U\n",
    "\n",
    "    * A is m by n\n",
    "    * P is m by m\n",
    "    * L is m by n if m >= n and m by m if m < n\n",
    "    * U is n by n if m >= n and m by n if m < n\n",
    "    \"\"\"\n",
    "    m, n = A.shape\n",
    "\n",
    "    # A is a row\n",
    "    if m == 1:\n",
    "        return np.eye(1), np.eye(1), A\n",
    "\n",
    "    # A is a column\n",
    "    elif n == 1:\n",
    "        # Find the pivot: the entry of largest absolute value\n",
    "        i0 = 0\n",
    "\n",
    "        for i in range(m):\n",
    "            if abs(A[i, 0]) > abs(A[i0, 0]):\n",
    "                i0 = i\n",
    "\n",
    "        # P swaps rows 0 and i0\n",
    "        P = np.eye(m)\n",
    "        P[0,0], P[i0,i0] = 0, 0\n",
    "        P[i0,0], P[0,i0] = 1, 1\n",
    "\n",
    "        if A[i0, 0] != 0:\n",
    "            L = P@A / A[i0, 0]\n",
    "            U = A[i0, 0] * np.eye(1)\n",
    "        else:\n",
    "            # A is the zero column\n",
    "            L = A\n",
    "            U = np.zeros((1, 1))\n",
    "\n",
    "        return P, L, U\n",
    "    else:\n",
    "        n1 = min(m, n)//2\n",
    "        n2 = n - n1\n",
    "\n",
    "        # Write\n",
    "        #\n",
    "        #     A = [[A11, A12],\n",
    "        #          [A21, A22]],\n",
    "        #\n",
    "        #     A1 = [[A11],\n",
    "        #           [A21]],\n",
    "        #\n",
    "        #     A2 = [[A12],\n",
    "        #           [A22]]\n",
    "        #\n",
    "        # where A11 is n1 by n1 and A22 is (m - n1) by n2\n",
    "        A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
    "        A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
    "        A1, A2 = A[:,:n1], A[:,n1:]\n",
    "\n",
    "        # Solve the A1 block\n",
    "        P1, L1, U11 = getrf(A1)\n",
    "\n",
    "        # Apply pivots (the inverse of a permutation matrix is its transpose)\n",
    "        A2 = P1.T @ A2\n",
    "        A12, A22 = A2[:n1,:], A2[n1:,:]\n",
    "\n",
    "        # Solve A12\n",
    "        L11, L21 = L1[:n1,:], L1[n1:,:]\n",
    "        A12 = la.inv(L11) @ A12\n",
    "\n",
    "        # Update A22\n",
    "        A22 = -L21@A12 + A22\n",
    "\n",
    "        # Solve the A22 block\n",
    "        P2, L22, U22 = getrf(A22)\n",
    "\n",
    "        # Take P = P1 @ P2_ for\n",
    "        #\n",
    "        #     P2_ = [[1,  0],\n",
    "        #            [0, P2]]\n",
    "        P2_ = np.eye(m)\n",
    "        P2_[n1:,n1:] = P2\n",
    "        P = P1 @ P2_\n",
    "\n",
    "        # Apply interchanges to L21\n",
    "        L21 = P2.T @ L21\n",
    "\n",
    "        # Take\n",
    "        #\n",
    "        #     L = [[L11,   0],\n",
    "        #          [L21, L22]],\n",
    "        #\n",
    "        #     U = [[U11, A12],\n",
    "        #          [  0, U22]]\n",
    "        if m >= n:\n",
    "            L, U = np.zeros((m, n)), np.zeros((n, n))\n",
    "        else:\n",
    "            L, U = np.zeros((m, m)), np.zeros((m, n))\n",
    "        L[:n1,:n1] = L11\n",
    "        L[n1:,:n1] = L21\n",
    "        L[n1:,n1:] = L22\n",
    "        U[:n1,:n1] = U11\n",
    "        U[:n1,n1:] = A12\n",
    "        U[n1:,n1:] = U22\n",
    "\n",
    "        return P, L, U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}