numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

linear-solvers.ipynb (10886B)

  1 {
  2  "cells": [
  3   {
  4    "cell_type": "markdown",
  5    "id": "e1791697-e733-4721-9e10-fcdcfbda9064",
  6    "metadata": {},
  7    "source": [
  8     "# Solving Linear Systems in SciPy\n",
  9     "\n",
 10     "A linear system is a system of equations of the form\n",
 11     "$$ \\left\\{ \\begin{aligned} a_{1 1} x_1 + a_{1 2} x_2 + \\cdots + a_{1 n} x_n &= b_1 \\\\ a_{2 1} x_1 + a_{2 2} x_2 + \\cdots + a_{2 n} x_n &= b_2 \\\\ & \\vdots \\\\ a_{n 1} x_1 + a_{n 2} x_2 + \\cdots + a_{n n} x_n &= b_n\\end{aligned} \\right.$$\n",
 12     "in the variables $x_1, \\ldots, x_n$.\n",
 13     "\n",
 14     "Solving this system is equivalent to solving the equation $A x = b$ for $x$, where $A = (a_{ij})_{ij}$ is an $n \\times n$ matrix and $x = (x_1, \\ldots, x_n)$ and $b = (b_1, \\ldots, b_n)$ are vectors. A unique solution exists for every $b$ provided $A$ is invertible."
 15    ]
 16   },
 17   {
 18    "cell_type": "markdown",
 19    "id": "26bc6a30-7853-4c78-adfb-320f0a65dd10",
 20    "metadata": {},
 21    "source": [
 22     "In SciPy, we can solve linear systems using the `la.solve` function."
 23    ]
 24   },
 25   {
 26    "cell_type": "code",
 27    "execution_count": 1,
 28    "id": "b1ced47f-6783-48ed-a4e4-4f1f4e2b8835",
 29    "metadata": {},
 30    "outputs": [
 31     {
 32      "data": {
 33       "text/plain": [
 34        "array([[-0.12672176],\n",
 35        "       [ 0.1046832 ],\n",
 36        "       [ 1.19008264]])"
 37       ]
 38      },
 39      "execution_count": 1,
 40      "metadata": {},
 41      "output_type": "execute_result"
 42     }
 43    ],
 44    "source": [
 45     "import numpy as np\n",
 46     "import scipy.linalg as la\n",
 47     "\n",
 48     "A, b = np.array([[6,15,1],[8,7,12],[2,7,8]]), np.array([[2], [14], [10]])\n",
 49     "la.solve(A, b)"
 50    ]
 51   },
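As a quick sanity check on the cell above, the following sketch verifies that the computed $x$ really satisfies $A x = b$ and looks at two common indicators of invertibility. The variable names `residual`, `solution_ok`, `det_A` and `cond_A` are illustrative only and do not appear in the original notebook.

```python
import numpy as np
import scipy.linalg as la

A = np.array([[6, 15, 1], [8, 7, 12], [2, 7, 8]], dtype=float)
b = np.array([[2], [14], [10]], dtype=float)

x = la.solve(A, b)

# The residual A x - b should vanish up to floating-point rounding.
residual = A @ x - b
solution_ok = np.allclose(A @ x, b)

# A is invertible exactly when its determinant is nonzero. In floating
# point, the condition number is usually the more informative quantity:
# a huge value signals that A is close to singular.
det_A = la.det(A)
cond_A = np.linalg.cond(A)

print(solution_ok, det_A, cond_A)
```

For this particular matrix the determinant is $-726$, so the system has a unique solution.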
 52   {
 53    "cell_type": "markdown",
 54    "id": "bfdacdbf-150f-4b4d-8072-ee18758d3b60",
 55    "metadata": {},
 56    "source": [
 57     "## But how does `la.solve` work???\n",
 58     "\n",
 59     "Internally, the `la.solve` function uses the [LAPACK library](https://netlib.org/lapack), a Fortran package for numerical linear algebra. The LAPACK generic linear solver algorithm goes something like the following:\n",
 60     "\n",
 61     "1. Decompose $A$ as $A = P L U$, where $P$ is a permutation matrix, $L$ is a lower triangular matrix with $1$ on the diagonal and $U$ is an upper triangular matrix.\n",
 62     "2. Solve $P b' = b$ for $b'$, i.e. compute $b' = P^{-1} b$. Since $P$ is a permutation matrix, this operation is $O(n)$.\n",
 63     "3. Solve $L b'' = b'$ for $b''$, i.e. compute $b'' = L^{-1} P^{-1} b$. Since $L$ is known to be lower triangular, this operation is $O(n^2)$.\n",
 64     "4. Solve $U x = b''$ for $x$, i.e. compute $x = U^{-1} L^{-1} P^{-1} b = A^{-1} b$. Since $U$ is known to be upper triangular, this operation is $O(n^2)$.\n",
 65     "\n",
 66     "Steps 2 to 4 are implemented in the `GETRS` family of subroutines."
 67    ]
 68   },
 69   {
 70    "cell_type": "markdown",
 71    "id": "5b7ad45f-ea81-4d46-9097-735cc159cf1a",
 72    "metadata": {},
 73    "source": [
 74     "As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple recursive algorithm using this method might look something like the following:\n",
 75     "\n",
 76     "1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
 77     "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
 78     "3. Write\n",
 79     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
 80     "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$ (since $A$ is invertible). Since $S_{i_0}$ acts on $A$ by swapping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
 81     "4. We want to solve the equation\n",
 82     "   $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
 83     "   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
 84     "   $$\n",
 85     "   \\begin{aligned}\n",
 86     "       a_{i_0} &= u_{11} & A_{12}' &= U_{12} \\\\\n",
 87     "       A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
 88     "   \\end{aligned}\n",
 89     "   $$\n",
 90     "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$, so it remains to solve the bottom-right equation.\n",
 91     "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$ by applying the algorithm recursively, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ on the diagonal and $U_{22}$ is upper triangular. Since $A_{21}'$ is a column and $A_{12}'$ is a row, computing the outer product $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is an $O(n^2)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$ is an $O(n)$ operation.\n",
 92     "6. Take\n",
 93     "   $$\n",
 94     "   \\begin{aligned}\n",
 95     "       L &= \\begin{pmatrix}      1 & 0      \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
 96     "       U &= \\begin{pmatrix} u_{11} & U_{12} \\\\      0 & U_{22} \\end{pmatrix}\n",
 97     "   \\end{aligned}\n",
 98     "   $$\n",
 99     "   for $L_{21}, L_{22}, u_{11}, U_{12}, U_{22}$ as above, so that\n",
100     "   $$\n",
101     "   S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
102     "   $$\n",
103     "7. Hence, since $S_{i_0}^{-1} = S_{i_0}$, by taking\n",
104     "   $$\n",
105     "   P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
106     "   $$\n",
107     "   we get $A = P L U$ as desired!\n",
108     "\n",
109     "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^2)$ operations, the total complexity of our factorization algorithm is $O(n^3)$."
110    ]
111   },
112   {
113    "cell_type": "markdown",
114    "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
115    "metadata": {},
116    "source": [
117     "The LAPACK algorithm for factorizing an $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
118     "$$\n",
119     "A =\n",
120     "\\left(\n",
121     "\\begin{array}{c|c}\n",
122     "    A_{11} & A_{12} \\\\ \\hline\n",
123     "    A_{21} & A_{22}\n",
124     "\\end{array}\n",
125     "\\right),\n",
126     "$$\n",
127     "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix.\n",
128     "This is implemented in the `GETRF` family of subroutines. A SciPy version of their precise algorithm might look something like the following:"
129    ]
130   },
131   {
132    "cell_type": "code",
133    "execution_count": 3,
134    "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
135    "metadata": {},
136    "outputs": [],
137    "source": [
138     "import numpy as np\n",
139     "import scipy.linalg as la\n",
140     "\n",
141     "def getrf(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
142     "    \"\"\"Returns P, L, U such that A = P @ L @ U\n",
143     "\n",
144     "    * A is m by n\n",
145     "    * P is m by m\n",
146     "    * L is m by n if m >= n and m by m if m <= n\n",
147     "    * U is n by n if m >= n and m by n if m <= n\n",
148     "    \"\"\"\n",
149     "    m, n = A.shape\n",
150     "    \n",
151     "    # A is a row\n",
152     "    if m == 1:\n",
153     "        return np.eye(1), np.eye(1), A\n",
154     "\n",
155     "    # A is a column\n",
156     "    elif n == 1:\n",
157     "        i0 = 0\n",
158     "\n",
159     "        for i in range(m):\n",
160     "            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i\n",
161     "\n",
162     "        P = np.eye(m)\n",
163     "        P[0,0],  P[i0,i0] = 0, 0\n",
164     "        P[i0,0], P[0,i0]  = 1, 1\n",
165     "\n",
166     "        if A[i0, 0] != 0:\n",
167     "            L = P@A / A[i0, 0]\n",
168     "            U = A[i0, 0] * np.eye(1)\n",
169     "        else:\n",
170     "            L = np.eye(m, 1)  # A is the zero column: keep a unit diagonal in L\n",
171     "            U = np.zeros((1, 1))\n",
172     "\n",
173     "        return P, L, U\n",
174     "    else:\n",
175     "        n1 = min(m, n)//2\n",
176     "        n2 = n - n1\n",
177     "\n",
178     "        # Write\n",
179     "        #\n",
180     "        #   A = [[A11, A12],\n",
181     "        #        [A21, A22]],\n",
182     "        #\n",
183     "        #   A1 = [[A11, \n",
184     "        #          A21]],\n",
185     "        #\n",
186     "        #   A2 = [[A12, \n",
187     "        #          A22]]\n",
188     "        #\n",
189     "        # where A11 is n1 by n1 and A22 is (m - n1) by n2\n",
190     "        A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
191     "        A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
192     "        A1, A2   = A[:,:n1],   A[:,n1:]\n",
193     "\n",
194     "        # Solve the A1 block\n",
195     "        P1, L1, U11 = getrf(A1)\n",
196     "\n",
197     "        # Apply pivots\n",
198     "        A2 = P1.T @ A2  # the inverse of a permutation matrix is its transpose\n",
199     "        A12, A22 = A2[:n1,:], A2[n1:,:]\n",
200     "        \n",
201     "        # Solve A12 \n",
202     "        L11, L21 = L1[:n1,:], L1[n1:,:]\n",
203     "        A12 = la.solve_triangular(L11, A12, lower=True)\n",
204     "\n",
205     "        # Update A22\n",
206     "        A22 = -L21@A12 + A22\n",
207     "        \n",
208     "        # Solve the A22 block\n",
209     "        P2, L22, U22 = getrf(A22)\n",
210     "\n",
211     "        # Take P = P1 @ P2_ for\n",
212     "        #\n",
213     "        # P2_ = [[I,  0],\n",
214     "        #        [0, P2]]\n",
215     "        P2_ = np.eye(m)\n",
216     "        P2_[n1:,n1:] = P2\n",
217     "        P = P1 @ P2_\n",
218     "\n",
219     "        # Apply interchanges to L21\n",
220     "        L21 = P2.T @ L21\n",
221     "\n",
222     "        # Take\n",
223     "        # \n",
224     "        # L = [[L11, 0],\n",
225     "        #      [L21, L22]],\n",
226     "        #\n",
227     "        # U = [[U11, A12],\n",
228     "        #      [  0, U22]]\n",
229     "        if m >= n:\n",
230     "            L, U = np.zeros((m, n)), np.zeros((n, n))\n",
231     "        else:\n",
232     "            L, U = np.zeros((m, m)), np.zeros((m, n))\n",
233     "        L[:n1,:n1] = L11\n",
234     "        L[n1:,:n1] = L21\n",
235     "        L[n1:,n1:] = L22\n",
236     "        U[:n1,:n1] = U11\n",
237     "        U[:n1,n1:] = A12\n",
238     "        U[n1:,n1:] = U22\n",
239     "\n",
240     "        return P, L, U"
241    ]
242   },
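The shape conventions listed in the `getrf` docstring above match those of SciPy's built-in `la.lu`, which can serve as a reference when checking a hand-written factorization. The sketch below, with an illustrative `shapes` dictionary that is not part of the original notebook, factors a tall, a wide and a square random matrix and records the shapes of the factors.

```python
import numpy as np
import scipy.linalg as la

rng = np.random.default_rng(0)

shapes = {}
for m, n in [(5, 3), (3, 5), (4, 4)]:
    A = rng.standard_normal((m, n))

    # la.lu returns P (m x m), L (m x k) and U (k x n) with k = min(m, n).
    P, L, U = la.lu(A)
    shapes[(m, n)] = (P.shape, L.shape, U.shape)

    # The factors must reproduce A up to rounding.
    assert np.allclose(P @ L @ U, A)

for key, value in shapes.items():
    print(key, value)
```

Comparing `P @ L @ U` against the input, rather than comparing the factors themselves, is the safer check, since different pivoting choices can yield different but equally valid factorizations.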
243   {
244    "cell_type": "code",
245    "execution_count": null,
246    "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
247    "metadata": {},
248    "outputs": [],
249    "source": []
250   }
251  ],
252  "metadata": {
253   "kernelspec": {
254    "display_name": "Python 3 (ipykernel)",
255    "language": "python",
256    "name": "python3"
257   },
258   "language_info": {
259    "codemirror_mode": {
260     "name": "ipython",
261     "version": 3
262    },
263    "file_extension": ".py",
264    "mimetype": "text/x-python",
265    "name": "python",
266    "nbconvert_exporter": "python",
267    "pygments_lexer": "ipython3",
268    "version": "3.12.6"
269   }
270  },
271  "nbformat": 4,
272  "nbformat_minor": 5
273 }