numeric-linalg

Educational material on the SciPy implementation of numerical linear algebra algorithms

Name    Size    Mode
..
getrf/getrf.py 2450B -rw-r--r--
001
002
003
004
005
006
007
008
009
010
011
012
013
014
015
016
017
018
019
020
021
022
023
024
025
026
027
028
029
030
031
032
033
034
035
036
037
038
039
040
041
042
043
044
045
046
047
048
049
050
051
052
053
054
055
056
057
058
059
060
061
062
063
064
065
066
067
068
069
070
071
072
073
074
075
076
077
078
079
080
081
082
083
084
085
086
087
088
089
090
091
092
093
094
095
096
097
098
099
100
101
102
103
104
105
106
# "SciPy-transpiled" version of LAPACK's GETRF family of subroutines!
import numpy as np
import scipy.linalg as la

def getrf(A: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return P, L, U such that A = P @ L @ U (recursive LU with row pivoting).

    With k = min(m, n):

    * A is m by n
    * P is m by m (a permutation matrix)
    * L is m by k, i.e. m by n if m >= n and m by m if m <= n
      (unit lower-trapezoidal: ones on the diagonal, zeros above it)
    * U is k by n, i.e. n by n if m >= n and m by n if m <= n
      (upper-trapezoidal)

    Real input yields float64 factors and complex input yields complex128
    factors. Singular matrices are factorized as well: a column whose
    candidate pivots are all zero simply produces a zero on the diagonal
    of U. The returned arrays never share memory with A.
    """
    A = np.asarray(A)
    m, n = A.shape

    # Working dtype of L and U: float64 for real input, complex128 for
    # complex input, so that imaginary parts are never silently discarded.
    dtype = np.result_type(A.dtype, np.float64)

    # Degenerate (empty) matrix: nothing to eliminate.
    if m == 0 or n == 0:
        k = min(m, n)
        return (np.eye(m),
                np.zeros((m, k), dtype=dtype),
                np.zeros((k, n), dtype=dtype))

    # A is a row: it already is upper-trapezoidal
    if m == 1:
        return np.eye(1), np.eye(1, dtype=dtype), A.astype(dtype)

    # A is a column
    if n == 1:
        # Partial pivoting: first entry of largest absolute value
        i0 = int(np.argmax(np.abs(A[:, 0])))
        pivot = A[i0, 0]

        # P permutes the 0-th and i0-th basis vectors
        P = np.eye(m)
        P[[0, i0]] = P[[i0, 0]]

        if pivot != 0:
            L = (P @ A) / pivot
        else:
            # The whole column is zero. Keep L unit lower-trapezoidal (the
            # first column of the identity) so that callers can still solve
            # with its top block; the zero ends up on the diagonal of U.
            L = np.zeros((m, 1), dtype=dtype)
            L[0, 0] = 1
        U = np.full((1, 1), pivot, dtype=dtype)

        return P, L, U

    n1 = min(m, n)//2

    # Write
    #
    #   A = [[A11, A12],
    #        [A21, A22]],
    #
    #   A1 = [[A11,
    #          A21]],
    #
    #   A2 = [[A12,
    #          A22]]
    #
    # where A11 is n1 by n1, A1 is m by n1 and A2 is m by (n - n1)

    # Solve the A1 block
    P1, L1, U11 = getrf(A[:, :n1])

    # Apply pivots to A2 (the inverse of a permutation matrix is its
    # transpose)
    A2 = P1.T @ A[:, n1:]
    A12, A22 = A2[:n1, :], A2[n1:, :]

    # Solve L11 @ U12 = A12 for U12; L11 is unit lower-triangular, so a
    # triangular solve is enough (no explicit inverse needed)
    L11, L21 = L1[:n1, :], L1[n1:, :]
    U12 = la.solve_triangular(L11, A12, lower=True, unit_diagonal=True)

    # Update A22 (Schur complement)
    A22 = A22 - L21 @ U12

    # Solve the A22 block
    P2, L22, U22 = getrf(A22)

    # Take P = P1 @ P2_ for
    #
    # P2_ = [[1, 0],
    #        [0, P2]]
    P2_ = np.eye(m)
    P2_[n1:, n1:] = P2
    P = P1 @ P2_

    # Apply interchanges to L21
    L21 = P2.T @ L21

    # Take
    #
    # L = [[L11, 0],
    #      [L21, L22]],
    #
    # U = [[U11, U12],
    #      [0,   U22]]
    k = min(m, n)
    L = np.zeros((m, k), dtype=dtype)
    U = np.zeros((k, n), dtype=dtype)
    L[:n1, :n1] = L11
    L[n1:, :n1] = L21
    L[n1:, n1:] = L22
    U[:n1, :n1] = U11
    U[:n1, n1:] = U12
    U[n1:, n1:] = U22

    return P, L, U