- Commit
- e981a0453b4e873149d9ae31dea1f8a91d8428c9
- Parent
- 9c7392a31b59f45a69505c0415b9b7e6387205d5
- Author
- Pablo <pablo-pie@riseup.net>
- Date
Finished transpiling the Fortran code to Python
Educational material on the SciPy implementation of numerical linear algebra algorithms
3 files changed, 129 insertions, 52 deletions
| Status | File Name | N° Changes | Insertions | Deletions |
| --- | --- | --- | --- | --- |
| Renamed | getrf.py -> getrf/getrf.py | 0 | 0 | 0 |
| Renamed | test_getrf.py -> getrf/test.py | 0 | 0 | 0 |
| Modified | linear-solvers.ipynb | 181 | 129 | 52 |
diff --git a/getrf.py b/getrf/getrf.py
diff --git a/test_getrf.py b/getrf/test.py
diff --git a/linear-solvers.ipynb b/linear-solvers.ipynb
@@ -74,10 +74,10 @@
  "As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple recursive algorithm using such a method might look something like the following:\n",
  "\n",
  "1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
- "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors.\n",
+ "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
  "3. Write\n",
  "   $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
- "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$. Since $A$ is invertible, $a_{i_0} \\ne 0$.\n",
+ "   where $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$, given $A$ is invertible. Since $S_{i_0}$ acts on $A$ by swapping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
  "4. We want to solve the equation\n",
  "   $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
  "   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
@@ -87,9 +87,9 @@
  "   A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
  "   \\end{aligned}\n",
  "   $$\n",
- "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$, so it remains to solve the bottom-right equation.\n",
- "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular.\n",
- "6. Take\n",
+ "   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}'$, so it remains to solve the bottom-right equation.\n",
+ "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular. Computing $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is, of course, an $O(n^3)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$ is an $O(n^2)$ operation.\n",
+ "7. Take\n",
  "   $$\n",
  "   \\begin{aligned}\n",
  "   L &= \\begin{pmatrix} 1 & 0 \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
@@ -100,73 +100,150 @@
  "   $$\n",
  "   S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
  "   $$\n",
- "7. Hence by taking\n",
+ "8. Hence by taking\n",
  "   $$\n",
  "   P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
  "   $$\n",
- "   we get $A = P L U$ as desired!"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e1e9c6b-5e65-4cf1-aa53-a7202abff788",
- "metadata": {
- "jp-MarkdownHeadingCollapsed": true
- },
- "source": [
- "## Scrap"
+ "   we get $A = P L U$ as desired!\n",
+ "\n",
+ "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^3)$ operations, the total complexity of our factorization algorithm is $O(n^4)$."
  ]
  },
  {
  "cell_type": "markdown",
- "id": "4dde0b0b-08cc-4bca-9dd5-c2bced871cef",
+ "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
  "metadata": {},
  "source": [
- "As for the decomposition of $A$ in the first step, given some $m \\times n$ matrix $A$, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting) to write $A = P L U$, where $P$ is a $m \\times m$ permutation matrix, $L$ is a $m \\times n$ lower trapezoidal matrix with $1$ in the diagonal entries and $U$ is a $n \\times n$ upper triangular matrix. Their algorithm goes something like the following:\n",
- "\n",
- "1. If $A = 0$ then take $P = 1$, $L$ the $m \\times n$ matrix with $1$ in the diagonals and zero elsewhere and $U = 0$.\n",
- "2. If $m = 1$, so that $A = \\begin{pmatrix} a_1 & \\cdots & a_n \\end{pmatrix}$, *DIE*\n",
- "3. If $n = 1$, so that\n",
- "   $$ A = \\begin{pmatrix} a_1 \\\\ \\vdots \\\\ a_m \\end{pmatrix} $$\n",
- "   choose $i_0$ that maximizes $|a_{i_0}|$. Since $A \\ne 0$, $a_{i_0} \\ne 0$. Return $P = \\sigma_{1, i_0}$, $L = a_{i_0}^{-1} P^{-1} A$ and $U = 1$.\n",
- "4. Otherwise do the following series of steps.\n",
- "   1. Take $n_1 = \\left\\lfloor \\frac{\\min \\{n, m\\}}{2} \\right \\rfloor$ and $n_2 = n - n1$.\n",
- "   2. Write\n",
- "      $$A = \\left( \\begin{array}{c|c} A_{11} & A_{12} \\\\ \\hline A_{21} & A_{22} \\end{array} \\right), $$\n",
- "      where $A_{11}$ is $n_1 \\times n_2$ and $A_{22}$ is $n_2 \\times n_2$.\n",
- "   3. Write\n",
- "      $$\\begin{pmatrix} A_{11} \\\\ A_{21} \\end{pmatrix} = P_1 L_1 U_1 \\quad L_1 = \\begin{pmatrix} L_1' \\\\ L_1'' \\end{pmatrix} ,$$\n",
- "      where $P_1$ is a $m \\times m$ permutation matrix, $L_1$ is a $m \\times n_1$ lower trapezoidal matrix, $L_1'$ is $n_1 \\times n_1$ and $U_1$ is a $n_1 \\times n_1$ upper triangular matrix.\n",
- "   4. Write\n",
- "      $$\\begin{pmatrix} A_{12}' \\\\ A_{22}' \\end{pmatrix} = P_1^{-1} \\begin{pmatrix} A_{12} \\\\ A_{22} \\end{pmatrix}$$\n",
- "   5. Take\n",
- "      $$A_{12}'' = L_1'^{-1} A_{12}' \\quad A_{22}'' = -L_1'1A_{12}'' + A_{22}'$$\n",
- "   6. Write $A_{12}'' = P_2 L_2 U_2$, where $P_2$ is a $n_1 \\times n_1$ permutation matrix, $L_2$ is a $n_1 \\times n_2$ lower trapezoidal matrix and $U_2$ is a $n_2 \\times n_2$ upper triangular matrix.\n",
- "   7. Take\n",
- "      $$\n",
- "      \\begin{aligned}\n",
- "      P &= \\begin{pmatrix} P_1 & 0 \\\\ 0 & P_2 \\end{pmatrix} &\n",
- "      L &= \\begin{pmatrix} L_1 & 0 \\\\ A_{21}'' & L_2 \\end{pmatrix} &\n",
- "      U &= \\begin{pmatrix} U_1 & A_{12}'' \\\\ 0 & U_2 \\end{pmatrix},\n",
- "      \\end{aligned}\n",
- "      $$\n",
- "      where $A_{21}'' = P_2^{-1} L_1'1$."
+ "The LAPACK algorithm for factorizing an $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
+ "$$\n",
+ "A =\n",
+ "\\left(\n",
+ "\\begin{array}{c|c}\n",
+ "  A_{11} & A_{12} \\\\ \\hline\n",
+ "  A_{21} & A_{22}\n",
+ "\\end{array}\n",
+ "\\right),\n",
+ "$$\n",
+ "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix.\n",
+ "This is implemented in the `GETRF` family of subroutines. A SciPy version of their precise algorithm might look something like the following:"
  ]
  },
  {
- "cell_type": "markdown",
- "id": "93106465-de78-418d-8c7b-9774b97f6b07",
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
  "metadata": {},
+ "outputs": [],
  "source": [
- "Although similar to the naive method of computing $A^{-1}$ and then applying it to $b$, this algorithm has better performance characteristics since it only computes $A^{-1} b$ (without calculating the product $A^{-1} = U^{-1} L^{-1} P^{-1}$). The `la.solve` function also uses optimized linear solvers when $A$ is known to be symmetric or Hermitian.\n",
+ "import numpy as np\n",
+ "import scipy.linalg as la\n",
+ "\n",
+ "def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):\n",
+ "    \"\"\"Returns the P, L, U\n",
+ "\n",
+ "    * A is m by n\n",
+ "    * P is m by m\n",
+ "    * L is m by n if m >= n and m by m if m <= n\n",
+ "    * U is n by n if m >= n and m by n if m <= n\n",
+ "    \"\"\"\n",
+ "    m, n = A.shape\n",
+ "    \n",
+ "    # A is a row\n",
+ "    if m == 1:\n",
+ "        return np.eye(1), np.eye(1), A\n",
+ "\n",
+ "    # A is a column\n",
+ "    elif n == 1:\n",
+ "        i0 = 0\n",
+ "\n",
+ "        for i in range(m):\n",
+ "            if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i\n",
+ "\n",
+ "        P = np.eye(m)\n",
+ "        P[0,0], P[i0,i0] = 0, 0\n",
+ "        P[i0,0], P[0,i0] = 1, 1\n",
+ "\n",
+ "        if A[i0, 0] != 0:\n",
+ "            L = P@A / A[i0, 0]\n",
+ "            U = A[i0, 0] * np.eye(1)\n",
+ "        else:\n",
+ "            L = A\n",
+ "            U = np.zeros((1, 1))\n",
+ "\n",
+ "        return P, L, U\n",
+ "    else:\n",
+ "        n1 = min(m, n)//2\n",
+ "        n2 = n - n1\n",
+ "\n",
+ "        # Write\n",
+ "        #\n",
+ "        #   A = [[A11, A12],\n",
+ "        #        [A21, A22]],\n",
+ "        #\n",
+ "        #   A1 = [[A11, \n",
+ "        #          A21]],\n",
+ "        #\n",
+ "        #   A2 = [[A12, \n",
+ "        #          A22]]\n",
+ "        #\n",
+ "        # where A11 is n1 by n1 and A22 is n2 by n2\n",
+ "        A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
+ "        A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
+ "        A1, A2 = A[:,:n1], A[:,n1:]\n",
+ "\n",
+ "        # Solve the A1 block\n",
+ "        P1, L1, U11 = getrf(A1)\n",
+ "\n",
+ "        # Apply pivots\n",
+ "        A2 = la.inv(P1) @ A2\n",
+ "        A12, A22 = A2[:n1,:], A2[n1:,:]\n",
+ "        \n",
+ "        # Solve A12 \n",
+ "        L11, L21 = L1[:n1,:], L1[n1:,:]\n",
+ "        A12 = la.inv(L11) @ A12\n",
+ "\n",
+ "        # Update A22\n",
+ "        A22 = -L21@A12 + A22\n",
+ "        \n",
+ "        # Solve the A22 block\n",
+ "        P2, L22, U22 = getrf(A22)\n",
+ "\n",
+ "        # Take P = P1 @ P2_ for\n",
+ "        #\n",
+ "        #   P2_ = [[1, 0,\n",
+ "        #           0, P2]]\n",
+ "        P2_ = np.eye(m)\n",
+ "        P2_[n1:,n1:] = P2\n",
+ "        P = P1 @ P2_\n",
+ "\n",
+ "        # Apply interchanges to L21\n",
+ "        L21 = la.inv(P2) @ L21\n",
+ "\n",
+ "        # Take\n",
+ "        # \n",
+ "        #   L = [[L11, 0],\n",
+ "        #        [L21, L22]],\n",
+ "        #\n",
+ "        #   U = [[U11, A12],\n",
+ "        #        [  0, U22]]\n",
+ "        if m >= n:\n",
+ "            L, U = np.zeros((m, n)), np.zeros((n, n))\n",
+ "        else:\n",
+ "            L, U = np.zeros((m, m)), np.zeros((m, n))\n",
+ "        L[:n1,:n1] = L11\n",
+ "        L[n1:,:n1] = L21\n",
+ "        L[n1:,n1:] = L22\n",
+ "        U[:n1,:n1] = U11\n",
+ "        U[:n1,n1:] = A12\n",
+ "        U[n1:,n1:] = U22\n",
  "\n",
- "For more details please refer to the implementations of the `GETRS`, `SYSV` and `HESV` families of subroutines in LAPACK."
+ "        return P, L, U"
  ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "6d77288f-56a8-4037-965d-56f2c443a73c",
+ "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
  "metadata": {},
  "outputs": [],
  "source": []
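
The simple recursive algorithm described in the notebook's first markdown cell (pivot search, row swap, update of the bottom-right block, recursion) can be sketched in a few lines. This is a minimal illustration, not the notebook's `getrf` and not LAPACK's routine. The name `naive_plu` is hypothetical, the sketch uses plain Python lists instead of NumPy arrays, and it assumes a square invertible input. As a further assumption, the permutation matrix $P$ is represented by an index list `perm`, where row `i` of `L @ U` equals row `perm[i]` of `A`. The recursion is passed the block $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$, and $L_{21}$ is obtained from $A_{21}' = u_{11} P_{22} L_{21}$.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def naive_plu(A: Matrix) -> Tuple[List[int], Matrix, Matrix]:
    """Sketch of the recursive factorization A = P L U with partial pivoting.

    Hypothetical helper, assuming A is square and invertible. P is returned
    as an index list `perm`: row i of L @ U is row perm[i] of A.
    """
    n = len(A)

    # The 1 x 1 case: P = L = 1 and U = a11.
    if n == 1:
        return [0], [[1.0]], [[float(A[0][0])]]

    # Pick the row i0 whose first entry is largest in absolute value.
    i0 = max(range(n), key=lambda i: abs(A[i][0]))
    pivot = A[i0][0]
    if pivot == 0:
        raise ValueError("matrix is singular")

    # Swap rows 0 and i0, then split the result into blocks.
    rows = [list(row) for row in A]
    rows[0], rows[i0] = rows[i0], rows[0]
    a12 = rows[0][1:]                   # A12', a row of length n - 1
    a21 = [row[0] for row in rows[1:]]  # A21', a column of length n - 1

    # A22' - A21' A12' / pivot: an outer-product update that visits each
    # of the (n - 1)^2 entries of the bottom-right block once.
    schur = [
        [rows[i + 1][j + 1] - a21[i] * a12[j] / pivot for j in range(n - 1)]
        for i in range(n - 1)
    ]
    perm22, L22, U22 = naive_plu(schur)

    # L21 = P22^{-1} A21' / pivot: permute the column, then scale it.
    l21 = [a21[perm22[i]] / pivot for i in range(n - 1)]

    # Assemble L and U from their blocks.
    L = [[1.0] + [0.0] * (n - 1)]
    L += [[l21[i]] + L22[i] for i in range(n - 1)]
    U = [[float(pivot)] + [float(x) for x in a12]]
    U += [[0.0] + U22[i] for i in range(n - 1)]

    # P = S_{i0} diag(1, P22), expressed on the row indices of A.
    swap = list(range(n))
    swap[0], swap[i0] = i0, 0
    perm = [swap[0]] + [swap[1 + perm22[i]] for i in range(n - 1)]
    return perm, L, U


A = [[0.0, 2.0, 1.0], [1.0, 1.0, 0.0], [2.0, 0.0, 3.0]]
perm, L, U = naive_plu(A)
print(perm)  # [2, 0, 1]
```

Returning an index list rather than a dense permutation matrix keeps the row swaps cheap and avoids explicit inverses such as `la.inv(P)`. A dense $P$ can be rebuilt by setting entry `(perm[i], i)` to 1 for each `i`.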