numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

getrf.py (2450B)

  1 # "SciPy-transpiled" version of LAPACK's GETRF family of subroutines!
  2 import numpy as np
  3 import scipy.linalg as la
  4 
  5 def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):
  6     """Returns the P, L, U
  7 
  8     * A is m by n
  9     * P is m by m
 10     * L is m by n if m >= n and m by m if m <= n
 11     * U is n by n if m >= n and m by n if m <= n
 12     """
 13     m, n = A.shape
 14     
 15     # A is a row
 16     if m == 1:
 17         return np.eye(1), np.eye(1), A
 18 
 19     # A is a column
 20     elif n == 1:
 21         i0 = 0
 22 
 23         for i in range(m):
 24             if abs(A[i, 0]) > abs(A[i0, 0]): i0 = i
 25 
 26         # P permutes the 0-th and i0-th basis vectors
 27         P = np.eye(m)
 28         P[0,0],  P[i0,i0] = 0, 0
 29         P[i0,0], P[0,i0]  = 1, 1
 30 
 31         if A[i0, 0] != 0:
 32             L = P@A / A[i0, 0]
 33             U = A[i0, 0] * np.eye(1)
 34         else:
 35             L = A
 36             U = np.zeros((1, 1))
 37 
 38         return P, L, U
 39     else:
 40         n1 = min(m, n)//2
 41         n2 = n - n1
 42 
 43         # Write
 44         #
 45         #   A = [[A11, A12],
 46         #        [A21, A22]],
 47         #
 48         #   A1 = [[A11, 
 49         #          A21]],
 50         #
 51         #   A2 = [[A12, 
 52         #          A22]]
 53         #
 54         # where A11 is n1 by n1 and A22 is n2 by n2
 55         A11, A12 = A[:n1,:n1], A[:n1,n1:]
 56         A21, A22 = A[n1:,:n1], A[n1:,n1:]
 57         A1, A2   = A[:,:n1],   A[:,n1:]
 58 
 59         # Solve the A1 block
 60         P1, L1, U11 = getrf(A1)
 61 
 62         # Apply pivots
 63         # A2 is m by n2
 64         A2 = la.inv(P1) @ A2
 65         A12, A22 = A2[:n1,:], A2[n1:,:]
 66         
 67         # Solve A12 
 68         L11, L21 = L1[:n1,:], L1[n1:,:]
 69         A12 = la.inv(L11) @ A12
 70 
 71         # Update A22
 72         A22 = -L21@A12 + A22
 73         
 74         # Solve the A22 block
 75         P2, L22, U22 = getrf(A22)
 76 
 77         # Take P = P1 @ P2_ for
 78         #
 79         # P2_ = [[1, 0,
 80         #         0, P2]]
 81         P2_ = np.eye(m)
 82         P2_[n1:,n1:] = P2
 83         P = P1 @ P2_
 84 
 85         # Apply interchanges to L21
 86         L21 = la.inv(P2) @ L21
 87 
 88         # Take
 89         # 
 90         # L = [[L11, 0],
 91         #      [L21, L22]],
 92         #
 93         # U = [[U11, A12],
 94         #      [0, U22]]
 95         if m >= n:
 96             L, U = np.zeros((m, n)), np.zeros((n, n))
 97         else:
 98             L, U = np.zeros((m, m)), np.zeros((m, n))
 99         L[:n1,:n1] = L11
100         L[n1:,:n1] = L21
101         L[n1:,n1:] = L22
102         U[:n1,:n1] = U11
103         U[:n1,n1:] = A12
104         U[n1:,n1:] = U22
105 
106         return P, L, U