numeric-linalg
Educational material on the SciPy implementation of numerical linear algebra algorithms
getrf.py (2450B)
# "SciPy-transpiled" version of LAPACK's GETRF family of subroutines!
import numpy as np
import scipy.linalg as la


def getrf(A: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):
    """Return P, L, U with A = P @ L @ U (LU with partial pivoting).

    The factorization is computed recursively: the left block column is
    factored first, the right block column is updated with the result, and
    the trailing (Schur complement) block is then factored in turn.

    Parameters
    ----------
    A : array_like
        The m by n matrix to factor. Real or complex. It is not modified.

    Returns
    -------
    P : np.ndarray
        m by m permutation matrix.
    L : np.ndarray
        Unit lower-triangular (trapezoidal) factor:
        m by n if m >= n and m by m if m <= n.
    U : np.ndarray
        Upper-triangular (trapezoidal) factor:
        n by n if m >= n and m by n if m <= n.

    Raises
    ------
    ValueError
        If A is not two-dimensional.

    Notes
    -----
    A singular A is still factored: the corresponding diagonal entry of U
    is zero while L keeps its unit diagonal.
    """
    A = np.asarray(A)
    if A.ndim != 2:
        raise ValueError(f"getrf expects a 2-D array, got {A.ndim}-D")

    m, n = A.shape
    k = min(m, n)

    # Working dtype of the factors: at least float64, complex when A is
    # complex (so imaginary parts are not silently discarded).
    dtype = np.result_type(A.dtype, np.float64)

    # A is empty: nothing to factor (and n1 == 0 below would never terminate)
    if k == 0:
        return np.eye(m), np.zeros((m, 0), dtype=dtype), np.zeros((0, n), dtype=dtype)

    # A is a row
    if m == 1:
        # Copy so the returned U never aliases the caller's array
        return np.eye(1), np.eye(1), A.copy()

    # A is a column
    if n == 1:
        # Partial pivoting: first entry of largest absolute value
        i0 = int(np.argmax(np.abs(A[:, 0])))

        # P permutes the 0-th and i0-th basis vectors
        P = np.eye(m)
        P[[0, i0]] = P[[i0, 0]]

        pivot = A[i0, 0]
        if pivot != 0:
            # P is a transposition, hence its own inverse
            L = (P @ A) / pivot
            U = np.array([[pivot]], dtype=dtype)
        else:
            # The whole column is zero. L must keep a unit diagonal so that
            # the caller's triangular solve with L11 stays well defined.
            L = np.zeros((m, 1), dtype=dtype)
            L[0, 0] = 1
            U = np.zeros((1, 1), dtype=dtype)

        return P, L, U

    n1 = k // 2

    # Write
    #
    #   A = [[A11, A12],
    #        [A21, A22]],
    #
    #   A1 = [[A11],
    #         [A21]],
    #
    #   A2 = [[A12],
    #         [A22]]
    #
    # where A11 is n1 by n1 and A22 is (m - n1) by (n - n1)
    A1, A2 = A[:, :n1], A[:, n1:]

    # Factor the A1 block column
    P1, L1, U11 = getrf(A1)

    # Apply pivots to A2 (the inverse of a permutation is its transpose)
    A2 = P1.T @ A2
    A12, A22 = A2[:n1, :], A2[n1:, :]

    # Solve L11 @ X = A12 for the top-right block of U
    L11, L21 = L1[:n1, :], L1[n1:, :]
    A12 = la.solve_triangular(L11, A12, lower=True, unit_diagonal=True)

    # Update A22 (Schur complement)
    A22 = A22 - L21 @ A12

    # Factor the A22 block
    P2, L22, U22 = getrf(A22)

    # Take P = P1 @ P2_full for
    #
    #   P2_full = [[I, 0],
    #              [0, P2]]
    P2_full = np.eye(m)
    P2_full[n1:, n1:] = P2
    P = P1 @ P2_full

    # Apply interchanges to L21
    L21 = P2.T @ L21

    # Take
    #
    #   L = [[L11, 0],
    #        [L21, L22]],
    #
    #   U = [[U11, A12],
    #        [0,   U22]]
    L = np.zeros((m, k), dtype=dtype)
    U = np.zeros((k, n), dtype=dtype)
    L[:n1, :n1] = L11
    L[n1:, :n1] = L21
    L[n1:, n1:] = L22
    U[:n1, :n1] = U11
    U[:n1, n1:] = A12
    U[n1:, n1:] = U22

    return P, L, U