# Solving Linear Systems in SciPy

A linear system is a system of equations of the form
$$ \left\{ \begin{aligned} a_{1 1} x_1 + a_{1 2} x_2 + \cdots + a_{1 n} x_n &= b_1 \\ a_{2 1} x_1 + a_{2 2} x_2 + \cdots + a_{2 n} x_n &= b_2 \\ & \vdots \\ a_{n 1} x_1 + a_{n 2} x_2 + \cdots + a_{n n} x_n &= b_n\end{aligned} \right.$$
in the variables $x_1, \ldots, x_n$.

Solving this system is equivalent to solving the equation $A x = b$ for $x$, where $A = (a_{ij})_{ij}$ is an $n \times n$ matrix and $x = (x_1, \ldots, x_n)$ and $b = (b_1, \ldots, b_n)$ are vectors. This can always be done, and the solution is unique, provided $A$ is invertible.

In SciPy, we can solve linear systems using the `la.solve` function.

In [1]:
import numpy as np
import scipy.linalg as la

A, b = np.array([[6,15,1],[8,7,12],[2,7,8]]), np.array([[2], [14], [10]])
la.solve(A, b)

array([[-0.12672176],
       [ 0.1046832 ],
       [ 1.19008264]])
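
As a quick sanity check, the solution can be plugged back into the system. The sketch below reuses the same $A$ and $b$, but passes $b$ as a 1-D array (an illustrative choice, not something the cell above requires) and checks that the residual $A x - b$ vanishes up to rounding.

```python
import numpy as np
import scipy.linalg as la

A = np.array([[6, 15, 1], [8, 7, 12], [2, 7, 8]])
b = np.array([2, 14, 10])   # 1-D right-hand side

x = la.solve(A, b)          # x is 1-D as well, with shape (3,)
residual = A @ x - b

# The residual should be zero up to floating-point rounding.
print(np.allclose(residual, 0.0))
```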

## But how does `la.solve` work?

Internally, the `la.solve` function uses the [LAPACK library](https://netlib.org/lapack), a Fortran package for numerical linear algebra. The generic LAPACK linear solver goes something like the following:

1. Decompose $A$ as $A = P L U$, where $P$ is a permutation matrix, $L$ is a lower triangular matrix with $1$ on the diagonal and $U$ is an upper triangular matrix.
2. Solve $P b' = b$ for $b'$, i.e. compute $b' = P^{-1} b$. Since $P$ is a permutation matrix, this operation is $O(n)$.
3. Solve $L b'' = b'$ for $b''$, i.e. compute $b'' = L^{-1} P^{-1} b$. Since $L$ is known to be lower triangular, this operation is $O(n^2)$.
4. Solve $U x = b''$ for $x$, i.e. compute $x = U^{-1} L^{-1} P^{-1} b = A^{-1} b$. Since $U$ is known to be upper triangular, this operation is $O(n^2)$.

The last three steps, which take an existing factorization and a right-hand side, are implemented in the `GETRS` family of subroutines.
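
The four steps above can be traced by hand with SciPy's own building blocks. The following is only a sketch: it uses `la.lu` and `la.solve_triangular` (SciPy functions not mentioned above) and a dense permutation matrix, so step 2 costs $O(n^2)$ here rather than the $O(n)$ achieved when the permutation is stored as a list of row swaps.

```python
import numpy as np
import scipy.linalg as la

A = np.array([[6.0, 15.0, 1.0], [8.0, 7.0, 12.0], [2.0, 7.0, 8.0]])
b = np.array([2.0, 14.0, 10.0])

# Step 1: factor A = P @ L @ U.
P, L, U = la.lu(A)

# Step 2: undo the permutation; P is orthogonal, so its inverse is P.T.
b1 = P.T @ b

# Step 3: forward substitution with the unit lower triangular factor.
b2 = la.solve_triangular(L, b1, lower=True, unit_diagonal=True)

# Step 4: back substitution with the upper triangular factor.
x = la.solve_triangular(U, b2, lower=False)

# The result should agree with the one-call solver.
print(np.allclose(x, la.solve(A, b)))
```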

As for the decomposition of $A$ in the first step, LAPACK uses a method called [_partial pivoting_](https://en.wikipedia.org/wiki/LU_decomposition#LU_factorization_with_partial_pivoting). A simple recursive algorithm using this method might look something like the following:

1. If $A = a_{11}$ is $1 \times 1$ then take $P = L = 1$ and $U = a_{11}$.
2. If $A$ is $n \times n$ for $n > 1$, choose $i_0$ that maximizes $|a_{i_0, 1}|$ and consider the $n \times n$ permutation matrix $S_{i_0}$ that swaps the first and $i_0$-th basis vectors. Searching for $i_0$ is an $O(n)$ operation.
3. Write
   $$S_{i_0} A = \left( \begin{array}{c|c} a_{i_0} & A_{12}' \\ \hline A_{21}' & A_{22}' \end{array} \right), $$
   where $A_{22}'$ is $(n - 1) \times (n - 1)$ and $a_{i_0} \ne 0$, given that $A$ is invertible. Since $S_{i_0}$ acts on $A$ by swapping the first and $i_0$-th rows, computing $S_{i_0} A$ is an $O(n)$ operation.
4. We want to solve the equation
   $$S_{i_0} A = \left( \begin{array}{c|c} 1 & 0 \\ \hline 0 & P_{22} \end{array} \right) \left( \begin{array}{c|c} 1 & 0 \\ \hline L_{21} & L_{22} \end{array} \right) \cdot \left( \begin{array}{c|c} u_{11} & U_{12} \\ \hline 0 & U_{22} \end{array} \right),$$
   where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ in the diagonal entries and $U_{22}$ is upper triangular. In other words, we want to solve the equations
   $$
   \begin{aligned}
       a_{i_0} &= u_{11} & A_{12}' &= U_{12} \\
       A_{21}' &= u_{11} P_{22} L_{21} & A_{22}' &= P_{22} L_{21} U_{12} + P_{22} L_{22} U_{22}.
   \end{aligned}
   $$
   We must take $u_{11} = a_{i_0}$, $U_{12} = A_{12}'$ and $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$, so it remains to solve the bottom-right equation.
5. Recursively write $(A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}') = P_{22} L_{22} U_{22}$, where $P_{22}$ is a permutation matrix, $L_{22}$ is lower triangular with $1$ on the diagonal and $U_{22}$ is upper triangular. Since $A_{21}'$ is a column and $A_{12}'$ is a row, computing the outer product $A_{21}' A_{12}'$ (and thus $A_{22}' - a_{i_0}^{-1} A_{21}' A_{12}'$) is an $O(n^2)$ operation. Now since $P_{22}$ is a permutation matrix, computing $L_{21} = a_{i_0}^{-1} P_{22}^{-1} A_{21}'$ is an $O(n)$ operation.
6. Take
   $$
   \begin{aligned}
       L &= \begin{pmatrix}      1 & 0      \\ L_{21} & L_{22} \end{pmatrix} &
       U &= \begin{pmatrix} u_{11} & U_{12} \\      0 & U_{22} \end{pmatrix}
   \end{aligned}
   $$
   for $L_{21}, L_{22}, u_{11}, U_{12}, U_{22}$ as above, so that
   $$
   S_{i_0} A = \begin{pmatrix} 1 & 0 \\ 0 & P_{22} \end{pmatrix} L U.
   $$
7. Hence by taking
   $$
   P = S_{i_0} \begin{pmatrix} 1 & 0 \\ 0 & P_{22} \end{pmatrix}
   $$
   we get $A = P L U$ as desired!

In total, this algorithm takes $n$ recursive steps to solve $A = P L U$. Since each step involves $O(n^2)$ operations, the total complexity of our factorization algorithm is $O(n^3)$.
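
The recursion described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions rather than what LAPACK does: the name `plu_recursive` is made up for this example, the input is assumed square and invertible, and permutations are stored as dense matrices for readability, which costs more than the row swaps counted above.

```python
import numpy as np

def plu_recursive(A):
    """Return P, L, U with A = P @ L @ U (A assumed square and invertible)."""
    A = np.array(A, dtype=float)
    n = A.shape[0]

    # Base case: a 1 x 1 matrix.
    if n == 1:
        return np.eye(1), np.eye(1), A

    # Pick the row i0 whose first entry is largest in absolute value.
    i0 = int(np.argmax(np.abs(A[:, 0])))
    S = np.eye(n)
    S[[0, i0]] = S[[i0, 0]]          # permutation matrix swapping rows 0 and i0
    SA = S @ A

    a = SA[0, 0]                     # the pivot, nonzero when A is invertible
    A12, A21, A22 = SA[:1, 1:], SA[1:, :1], SA[1:, 1:]

    # Recurse on the (n - 1) x (n - 1) matrix A22 - A21 A12 / a.
    P22, L22, U22 = plu_recursive(A22 - (A21 @ A12) / a)

    # L21 = a^{-1} P22^{-1} A21, using that P22^{-1} = P22.T.
    L21 = P22.T @ A21 / a

    L = np.block([[np.eye(1), np.zeros((1, n - 1))], [L21, L22]])
    U = np.block([[np.full((1, 1), a), A12], [np.zeros((n - 1, 1)), U22]])

    # P = S @ diag(1, P22).
    P_tail = np.eye(n)
    P_tail[1:, 1:] = P22
    return S @ P_tail, L, U

A = np.array([[6, 15, 1], [8, 7, 12], [2, 7, 8]])
P, L, U = plu_recursive(A)
print(np.allclose(P @ L @ U, A))
```

Because the pivot has the largest absolute value in its column, every entry of $L$ below the diagonal is at most $1$ in absolute value, which is the point of partial pivoting.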

The LAPACK algorithm for factorizing an $m \times n$ matrix $A$ improves on this simple idea by instead splitting $A$ into blocks
$$
A =
\left(
\begin{array}{c|c}
    A_{11} & A_{12} \\ \hline
    A_{21} & A_{22}
\end{array}
\right),
$$
where $A_{11}$ is a $\left\lfloor \frac{\min \{m, n\}}{2} \right\rfloor \times \left\lfloor \frac{\min \{m, n\}}{2} \right\rfloor$ matrix.
This is implemented in the `GETRF` family of subroutines. A NumPy/SciPy sketch of their algorithm, using dense permutation matrices instead of pivot indices, might look something like the following:

In [3]:
import numpy as np
import scipy.linalg as la

def getrf(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return P, L, U such that A = P @ L @ U.

    * A is m by n
    * P is m by m
    * L is m by n if m >= n and m by m if m <= n
    * U is n by n if m >= n and m by n if m <= n
    """
    m, n = A.shape
    
    # A is a row
    if m == 1:
        return np.eye(1), np.eye(1), A

    # A is a column
    elif n == 1:
        # Index of the (first) entry of largest absolute value
        i0 = int(np.argmax(np.abs(A[:, 0])))

        P = np.eye(m)
        P[0,0],  P[i0,i0] = 0, 0
        P[i0,0], P[0,i0]  = 1, 1

        if A[i0, 0] != 0:
            L = P@A / A[i0, 0]
            U = A[i0, 0] * np.eye(1)
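        else:
            # Zero column: keep L unit lower triangular so that the
            # top block L11 used by the caller stays invertible
            L = np.eye(m, 1)
            U = np.zeros((1, 1))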

        return P, L, U
    else:
        n1 = min(m, n)//2
        n2 = n - n1

        # Write
        #
        #   A = [[A11, A12],
        #        [A21, A22]],
        #
        #   A1 = [[A11],
        #         [A21]],
        #
        #   A2 = [[A12],
        #         [A22]]
        #
        # where A11 is n1 by n1 and A22 is (m - n1) by n2
        A11, A12 = A[:n1,:n1], A[:n1,n1:]
        A21, A22 = A[n1:,:n1], A[n1:,n1:]
        A1, A2   = A[:,:n1],   A[:,n1:]

        # Solve the A1 block
        P1, L1, U11 = getrf(A1)

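        # Apply pivots (P1 is orthogonal, so its inverse is its transpose)
        A2 = P1.T @ A2
        A12, A22 = A2[:n1,:], A2[n1:,:]
        
        # Solve L11 @ X = A12 by forward substitution
        L11, L21 = L1[:n1,:], L1[n1:,:]
        A12 = la.solve_triangular(L11, A12, lower=True, unit_diagonal=True)

        # Update A22
        A22 = A22 - L21 @ A12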
        
        # Solve the A22 block
        P2, L22, U22 = getrf(A22)

        # Take P = P1 @ P2_ for
        #
        # P2_ = [[I,  0],
        #        [0, P2]]
        P2_ = np.eye(m)
        P2_[n1:,n1:] = P2
        P = P1 @ P2_

        # Apply interchanges to L21 (the inverse of P2 is its transpose)
        L21 = P2.T @ L21

        # Take
        # 
        # L = [[L11, 0],
        #      [L21, L22]],
        #
        # U = [[U11, A12],
        #      [  0, U22]]
        if m >= n:
            L, U = np.zeros((m, n)), np.zeros((n, n))
        else:
            L, U = np.zeros((m, m)), np.zeros((m, n))
        L[:n1,:n1] = L11
        L[n1:,:n1] = L21
        L[n1:,n1:] = L22
        U[:n1,:n1] = U11
        U[:n1,n1:] = A12
        U[n1:,n1:] = U22

        return P, L, U
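
One way to sanity-check a routine like `getrf` is to compare it against `scipy.linalg.lu`, which returns factors with the same shape conventions as the docstring above: with $k = \min \{m, n\}$, $P$ is $m \times m$, $L$ is $m \times k$ and $U$ is $k \times n$. The sketch below only exercises `la.lu` on random rectangular matrices; the sizes and the seed are arbitrary choices made for illustration.

```python
import numpy as np
import scipy.linalg as la

rng = np.random.default_rng(0)

for m, n in [(5, 3), (3, 5), (4, 4)]:
    A = rng.standard_normal((m, n))
    P, L, U = la.lu(A)

    # Print the shapes of the factors and whether they reproduce A.
    print((m, n), P.shape, L.shape, U.shape, np.allclose(P @ L @ U, A))
```

Running the same loop with `getrf(A)` in place of `la.lu(A)` should give matching shapes and a valid factorization, although the factors themselves may differ when pivot choices are tied.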