diff --git a/linear-solvers.ipynb b/linear-solvers.ipynb
@@ -74,10 +74,10 @@
"As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple simple recurssive algorithm using such method might look something like the following:\n",
"\n",
"1. If $A = a_{11}$ is $1 \\times 1$ then take $P = L = 1$ and $U = a_{11}$.\n",
- "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors.\n",
+ "2. If $A$ is $n \\times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \\times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.\n",
"3. Write\n",
" $$S_{i_0} A = \\left( \\begin{array}{c|c} a_{i_0} & A_{12}' \\\\ \\hline A_{21}' & A_{22}' \\end{array} \\right), $$\n",
- " where $A_{22}'$ is $(n - 1) \\times (n - 1)$. Since $A$ is invertible, $a_{i_0} \\ne 0$.\n",
+ " where $A_{22}'$ is $(n - 1) \\times (n - 1)$ and $a_{i_0} \\ne 0$ — given $A$ is invertible. Since $S_{i_0}$ acts on $A$ by swaping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.\n",
"4. We want to solve the equation\n",
" $$S_{i_0} A = \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline 0 & P_{22} \\end{array} \\right) \\left( \\begin{array}{c|c} 1 & 0 \\\\ \\hline L_{21} & L_{22} \\end{array} \\right) \\cdot \\left( \\begin{array}{c|c} u_{11} & U_{12} \\\\ \\hline 0 & U_{22} \\end{array} \\right),$$\n",
" where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations\n",
@@ -87,9 +87,9 @@
" A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.\n",
" \\end{aligned}\n",
" $$\n",
- " We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$, so it remains to solve the bottom-right equation.\n",
- "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular.\n",
- "6. Take\n",
+ " We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}'$, so it remains to solve the bottom-right equation.\n",
+ "5. Write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonals and $U_{22}$ is upper triangular. Computing $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is, of course, a $O(n^3)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{12}$ is an $O(n^2)$ operation.\n",
+ "7. Take\n",
" $$\n",
" \\begin{aligned}\n",
" L &= \\begin{pmatrix} 1 & 0 \\\\ L_{21} & L_{22} \\end{pmatrix} &\n",
@@ -100,73 +100,150 @@
" $$\n",
" S_{i_0} A = \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix} L U.\n",
" $$\n",
- "7. Hence by taking\n",
+ "8. Hence by taking\n",
" $$\n",
" P = S_{i_0} \\begin{pmatrix} 1 & 0 \\\\ 0 & P_{22} \\end{pmatrix}\n",
" $$\n",
- " we get $A = P L U$ as desired!"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0e1e9c6b-5e65-4cf1-aa53-a7202abff788",
- "metadata": {
- "jp-MarkdownHeadingCollapsed": true
- },
- "source": [
- "## Scrap"
+ " we get $A = P L U$ as desired!\n",
+ "\n",
+ "In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^3)$ operations, the total complexity of our facotization algorithm is $O(n^4)$."
]
},
{
"cell_type": "markdown",
- "id": "4dde0b0b-08cc-4bca-9dd5-c2bced871cef",
+ "id": "7c6678f2-9ae6-4daa-922c-828771c1a796",
"metadata": {},
"source": [
- "As for the decomposition of $A$ in the first step, given some $m \\times n$ matrix $A$, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting) to write $A = P L U$, where $P$ is a $m \\times m$ permutation matrix, $L$ is a $m \\times n$ lower trapezoidal matrix with $1$ in the diagona entries and $U$ is a $n \\times n$ upper triangular matrix. Their algorithm goes something like the following:\n",
- "\n",
- "1. If $A = 0$ then take $P = 1$, $L$ the $m \\times n$ matrix with $1$ in the diagonals and zero elsewhere and $U = 0$.\n",
- "2. If $m = 1$, so that $A = \\begin{pmatrix} a_1 & \\cdots & a_n \\end{pmatrix}$, *DIE*\n",
- "3. If $n = 1$, so that\n",
- " $$ A = \\begin{pmatrix} a_1 \\\\ \\vdots \\\\ a_m \\end{pmatrix} $$\n",
- " choose $i_0$ that maximizes $|a_{i_0}|$. Since $A \\ne 0$, $a_{i_0} \\ne 0$. Return $P = \\sigma_{1, i_0}$, $L = a_{i_0}^{-1} P^{-1} A$ and $U = 1$.\n",
- "4. Otherwise do the following series of steps.\n",
- " 1. Take $n_1 = \\left\\lfloor \\frac{\\min \\{n, m\\}}{2} \\right \\rfloor$ and $n_2 = n - n1$.\n",
- " 2. Write\n",
- " $$A = \\left( \\begin{array}{c|c} A_{11} & A_{12} \\\\ \\hline A_{21} & A_{22} \\end{array} \\right), $$\n",
- " where $A_{11}$ is $n_1 \\times n_2$ and $A_{22}$ is $n_2 \\times n_2$.\n",
- " 3. Write\n",
- " $$\\begin{pmatrix} A_{11} \\\\ A_{21} \\end{pmatrix} = P_1 L_1 U_1 \\quad L_1 = \\begin{pmatrix} L_1' \\\\ L_1'' \\end{pmatrix} ,$$\n",
- " where $P_1$ is a $m \\times m$ permutation matrix, $L_1$ is a $m \\times n_1$ lower trapezoidal matrix, $L_1'$ is $n_1 \\times n_1$ and $U_1$ is a $n_1 \\times n_1$ upper triangular matrix.\n",
- " 4. Write\n",
- " $$\\begin{pmatrix} A_{12}' \\\\ A_{22}' \\end{pmatrix} = P_1^{-1} \\begin{pmatrix} A_{12} \\\\ A_{22} \\end{pmatrix}$$\n",
- " 5. Take\n",
- " $$A_{12}'' = L_1'^{-1} A_{12}' \\quad A_{22}'' = -L_1'1A_{12}'' + A_{22}'$$\n",
- " 6. Write $A_{12}'' = P_2 L_2 U_2$, where $P_2$ is a $n_1 \\times n_1$ permutation matrix, $L_2$ is a $n_1 \\times n_2$ lower trapezoidal matrix and $U_2$ is a $n_2 \\times n_2$ upper triangular matrix.\n",
- " 7. Take\n",
- " $$\n",
- " \\begin{aligned}\n",
- " P &= \\begin{pmatrix} P_1 & 0 \\\\ 0 & P_2 \\end{pmatrix} &\n",
- " L &= \\begin{pmatrix} L_1 & 0 \\\\ A_{21}'' & L_2 \\end{pmatrix} &\n",
- " U &= \\begin{pmatrix} U_1 & A_{12}'' \\\\ 0 & U_2 \\end{pmatrix},\n",
- " \\end{aligned}\n",
- " $$\n",
- " where $A_{21}'' = P_2^{-1} L_1'1$."
+ "The LAPACK algorithm for factorizing a $m \\times n$ matrix $A$ improves on this simple concept by instead taking the decomposition\n",
+ "$$\n",
+ "A =\n",
+ "\\left(\n",
+ "\\begin{array}{c|c}\n",
+ " A_{11} & A_{12} \\\\ \\hline\n",
+ " A_{21} & A_{22}\n",
+ "\\end{array}\n",
+ "\\right),\n",
+ "$$\n",
+ "where $A_{11}$ is a $\\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor \\times \\left\\lfloor \\frac{\\min \\{m, n\\}}{2} \\right\\rfloor$ matrix.\n",
+ "This is implemented in the `GETRF` family of subroutines. A SciPy version of their precise algorithm might look something like the following:"
]
},
{
- "cell_type": "markdown",
- "id": "93106465-de78-418d-8c7b-9774b97f6b07",
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f942e1ba-0449-4f35-b894-4dc5cad385a6",
"metadata": {},
+ "outputs": [],
"source": [
- "Althought similar to the naive method of computing $A^{-1}$ and then applying it to $b$, this algorithm has better performance characteristics since it only computes $A^{-1} b$ (without calculating the product $A^{-1} = U^{-1} L^{-1} P^{-1}$). The `la.solve` function also uses optimized linear solvers when $A$ is known to be symmetric or Hermitian.\n",
+ "import numpy as np\n",
+ "import scipy.linalg as la\n",
+ "\n",
+ "def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):\n",
+ " \"\"\"Returns the P, L, U\n",
+ "\n",
+ " * A is m by n\n",
+ " * P is m by m\n",
+ " * L is m by n if m >= n and m by m if m <= n\n",
+ " * U is n by n if m >= n and m by n if m <= n\n",
+ " \"\"\"\n",
+ " m, n = A.shape\n",
+ " \n",
+ " # A is a row\n",
+ " if m == 1:\n",
+ " return np.eye(1), np.eye(1), A\n",
+ "\n",
+ " # A is a column\n",
+ " elif n == 1:\n",
+ " i0 = 0\n",
+ "\n",
+ " for i in range(m):\n",
+ " if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i\n",
+ "\n",
+ " P = np.eye(m)\n",
+ " P[0,0], P[i0,i0] = 0, 0\n",
+ " P[i0,0], P[0,i0] = 1, 1\n",
+ "\n",
+ " if A[i0, 0] != 0:\n",
+ " L = P@A / A[i0, 0]\n",
+ " U = A[i0, 0] * np.eye(1)\n",
+ " else:\n",
+ " L = A\n",
+ " U = np.zeros((1, 1))\n",
+ "\n",
+ " return P, L, U\n",
+ " else:\n",
+ " n1 = min(m, n)//2\n",
+ " n2 = n - n1\n",
+ "\n",
+ " # Write\n",
+ " #\n",
+ " # A = [[A11, A12],\n",
+ " # [A21, A22]],\n",
+ " #\n",
+ " # A1 = [[A11, \n",
+ " # A21]],\n",
+ " #\n",
+ " # A2 = [[A12, \n",
+ " # A22]]\n",
+ " #\n",
+ " # where A11 is n1 by n1 and A22 is n2 by n2\n",
+ " A11, A12 = A[:n1,:n1], A[:n1,n1:]\n",
+ " A21, A22 = A[n1:,:n1], A[n1:,n1:]\n",
+ " A1, A2 = A[:,:n1], A[:,n1:]\n",
+ "\n",
+ " # Solve the A1 block\n",
+ " P1, L1, U11 = getrf(A1)\n",
+ "\n",
+ " # Apply pivots\n",
+ " A2 = la.inv(P1) @ A2\n",
+ " A12, A22 = A2[:n1,:], A2[n1:,:]\n",
+ " \n",
+ " # Solve A12 \n",
+ " L11, L21 = L1[:n1,:], L1[n1:,:]\n",
+ " A12 = la.inv(L11) @ A12\n",
+ "\n",
+ " # Update A22\n",
+ " A22 = -L21@A12 + A22\n",
+ " \n",
+ " # Solve the A22 block\n",
+ " P2, L22, U22 = getrf(A22)\n",
+ "\n",
+ " # Take P = P1 @ P2_ for\n",
+ " #\n",
+ " # P2_ = [[1, 0,\n",
+ " # 0, P2]]\n",
+ " P2_ = np.eye(m)\n",
+ " P2_[n1:,n1:] = P2\n",
+ " P = P1 @ P2_\n",
+ "\n",
+ " # Apply interchanges to L21\n",
+ " L21 = la.inv(P2) @ L21\n",
+ "\n",
+ " # Take\n",
+ " # \n",
+ " # L = [[L11, 0],\n",
+ " # [L21, L22]],\n",
+ " #\n",
+ " # U = [[U11, A12],\n",
+ " # [ 0, U22]]\n",
+ " if m >= n:\n",
+ " L, U = np.zeros((m, n)), np.zeros((n, n))\n",
+ " else:\n",
+ " L, U = np.zeros((m, m)), np.zeros((m, n))\n",
+ " L[:n1,:n1] = L11\n",
+ " L[n1:,:n1] = L21\n",
+ " L[n1:,n1:] = L22\n",
+ " U[:n1,:n1] = U11\n",
+ " U[:n1,n1:] = A12\n",
+ " U[n1:,n1:] = U22\n",
"\n",
- "For more details please refer to the implementations of the `GETRS`, `SYSV` and `HESV` families of subroutines in LAPACK."
+ " return P, L, U"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "6d77288f-56a8-4037-965d-56f2c443a73c",
+ "id": "229ed6e3-493e-49c3-835a-0b8582aa6586",
"metadata": {},
"outputs": [],
"source": []