Faster LU decomposition
This commit is contained in:
parent
01bb5c400a
commit
fc02d40a66
|
@ -25,181 +25,31 @@
|
|||
; 5. Eigenvalues and eigenvectors
|
||||
|
||||
(provide
|
||||
matrix-inverse
|
||||
; row and column
|
||||
matrix-scale-row
|
||||
matrix-scale-column
|
||||
matrix-swap-rows
|
||||
matrix-swap-columns
|
||||
matrix-add-scaled-row
|
||||
; reduction
|
||||
;; Gaussian elimination
|
||||
matrix-gauss-elim
|
||||
matrix-row-echelon
|
||||
; invariant
|
||||
;; Derived functions
|
||||
matrix-rank
|
||||
matrix-nullity
|
||||
matrix-determinant
|
||||
matrix-determinant/row-reduction ; for testing
|
||||
matrix-invertible?
|
||||
; solvers
|
||||
matrix-solve
|
||||
; spaces
|
||||
;; Spaces (TODO: null space, row space, left null space)
|
||||
matrix-column-space
|
||||
; projection
|
||||
;; Solving
|
||||
matrix-invertible?
|
||||
matrix-inverse
|
||||
matrix-solve
|
||||
;; Projection
|
||||
projection-on-orthogonal-basis
|
||||
projection-on-orthonormal-basis
|
||||
projection-on-subspace
|
||||
gram-schmidt-orthogonal
|
||||
gram-schmidt-orthonormal
|
||||
; factorization
|
||||
;; Decomposition
|
||||
matrix-lu
|
||||
matrix-qr
|
||||
)
|
||||
|
||||
;;;
|
||||
;;; Row and column
|
||||
;;;
|
||||
|
||||
(: matrix-scale-row : (Matrix Number) Integer Number -> (Matrix Number))
|
||||
(define (matrix-scale-row a i c)
|
||||
((inline-matrix-scale-row i c) a))
|
||||
|
||||
(define-syntax (inline-matrix-scale-row stx)
|
||||
(syntax-case stx ()
|
||||
[(_ i c)
|
||||
(syntax/loc stx
|
||||
(λ (arr)
|
||||
(define ds (array-shape arr))
|
||||
(define g (unsafe-array-proc arr))
|
||||
(cond
|
||||
[(< i 0)
|
||||
(error 'matrix-scale-row "row index must be non-negative, got ~a" i)]
|
||||
[(not (< i (vector-ref ds 0)))
|
||||
(error 'matrix-scale-row "row index must be smaller than the number of rows, got ~a" i)]
|
||||
[else
|
||||
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
|
||||
(if (= i (vector-ref js 0))
|
||||
(* c (g js))
|
||||
(g js))))])))]))
|
||||
|
||||
(: matrix-scale-column : (Matrix Number) Integer Number -> (Matrix Number))
|
||||
(define (matrix-scale-column a i c)
|
||||
((inline-matrix-scale-column i c) a))
|
||||
|
||||
(define-syntax (inline-matrix-scale-column stx)
|
||||
(syntax-case stx ()
|
||||
[(_ j c)
|
||||
(syntax/loc stx
|
||||
(λ (arr)
|
||||
(define ds (array-shape arr))
|
||||
(define g (unsafe-array-proc arr))
|
||||
(cond
|
||||
[(< j 0)
|
||||
(error 'matrix-scale-row "column index must be non-negative, got ~a" j)]
|
||||
[(not (< j (vector-ref ds 1)))
|
||||
(error 'matrix-scale-row
|
||||
"column index must be smaller than the number of rows, got ~a" j)]
|
||||
[else
|
||||
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
|
||||
(if (= j (vector-ref js 1))
|
||||
(* c (g js))
|
||||
(g js))))])))]))
|
||||
|
||||
(: matrix-swap-rows : (Matrix Number) Integer Integer -> (Matrix Number))
|
||||
(define (matrix-swap-rows a i j)
|
||||
((inline-matrix-swap-rows i j) a))
|
||||
|
||||
(define-syntax (inline-matrix-swap-rows stx)
|
||||
(syntax-case stx ()
|
||||
[(_ i j)
|
||||
(syntax/loc stx
|
||||
(λ (arr)
|
||||
(define ds (array-shape arr))
|
||||
(define g (unsafe-array-proc arr))
|
||||
(cond
|
||||
[(< i 0)
|
||||
(error 'matrix-swap-rows "row index must be non-negative, got ~a" i)]
|
||||
[(< j 0)
|
||||
(error 'matrix-swap-rows "row index must be non-negative, got ~a" j)]
|
||||
[(not (< i (vector-ref ds 0)))
|
||||
(error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" i)]
|
||||
[(not (< j (vector-ref ds 0)))
|
||||
(error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" j)]
|
||||
[else
|
||||
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
|
||||
(cond
|
||||
[(= i (vector-ref js 0))
|
||||
(g (vector j (vector-ref js 1)))]
|
||||
[(= j (vector-ref js 0))
|
||||
(g (vector i (vector-ref js 1)))]
|
||||
[else
|
||||
(g js)])))])))]))
|
||||
|
||||
(: matrix-swap-columns : (Matrix Number) Integer Integer -> (Matrix Number))
|
||||
(define (matrix-swap-columns a i j)
|
||||
((inline-matrix-swap-columns i j) a))
|
||||
|
||||
(define-syntax (inline-matrix-swap-columns stx)
|
||||
(syntax-case stx ()
|
||||
[(_ i j)
|
||||
(syntax/loc stx
|
||||
(λ (arr)
|
||||
(define ds (array-shape arr))
|
||||
(define g (unsafe-array-proc arr))
|
||||
(cond
|
||||
[(< i 0)
|
||||
(error 'matrix-swap-columns "column index must be non-negative, got ~a" i)]
|
||||
[(< j 0)
|
||||
(error 'matrix-swap-columns "column index must be non-negative, got ~a" j)]
|
||||
[(not (< i (vector-ref ds 0)))
|
||||
(error 'matrix-swap-columns
|
||||
"column index must be smaller than the number of columns, got ~a" i)]
|
||||
[(not (< j (vector-ref ds 0)))
|
||||
(error 'matrix-swap-columns
|
||||
"column index must be smaller than the number of columns, got ~a" j)]
|
||||
[else
|
||||
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
|
||||
(cond
|
||||
[(= i (vector-ref js 1))
|
||||
(g (vector j (vector-ref js 1)))]
|
||||
[(= j (vector-ref js 1))
|
||||
(g (vector i (vector-ref js 1)))]
|
||||
[else
|
||||
(g js)])))])))]))
|
||||
|
||||
(: matrix-add-scaled-row : (Matrix Number) Integer Number Integer -> (Matrix Number))
|
||||
(define (matrix-add-scaled-row a i c j)
|
||||
((inline-matrix-add-scaled-row i c j) a))
|
||||
|
||||
(: flmatrix-add-scaled-row : (Matrix Flonum) Index Flonum Index -> (Matrix Flonum))
|
||||
(define (flmatrix-add-scaled-row a i c j)
|
||||
((inline-matrix-add-scaled-row i c j) a))
|
||||
|
||||
(define-syntax (inline-matrix-add-scaled-row stx)
|
||||
(syntax-case stx ()
|
||||
[(_ i c j)
|
||||
(syntax/loc stx
|
||||
(λ (arr)
|
||||
(define ds (array-shape arr))
|
||||
(define g (unsafe-array-proc arr))
|
||||
(cond
|
||||
[(< i 0)
|
||||
(error 'matrix-add-scaled-row "row index must be non-negative, got ~a" i)]
|
||||
[(< j 0)
|
||||
(error 'matrix-add-scaled-row "row index must be non-negative, got ~a" j)]
|
||||
[(not (< i (vector-ref ds 0)))
|
||||
(error 'matrix-add-scaled-row
|
||||
"row index must be smaller than the number of rows, got ~a" i)]
|
||||
[(not (< j (vector-ref ds 0)))
|
||||
(error 'matrix-add-scaled-row
|
||||
"row index must be smaller than the number of rows, got ~a" j)]
|
||||
[else
|
||||
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
|
||||
(if (= i (vector-ref js 0))
|
||||
(+ (g js) (* c (g (vector j (vector-ref js 1)))))
|
||||
(g js))))])))]))
|
||||
|
||||
|
||||
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
|
||||
(define (unsafe-vector2d-ref vss i j)
|
||||
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
|
||||
|
@ -289,17 +139,17 @@
|
|||
M))
|
||||
|
||||
(: matrix-rank : (Matrix Number) -> Index)
|
||||
;; Returns the dimension of the column space (equiv. row space) of M
|
||||
(define (matrix-rank M)
|
||||
; TODO: Use QR or SVD instead for inexact matrices
|
||||
; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd
|
||||
; rank = dimension of column space = dimension of row space
|
||||
(define n (matrix-num-cols M))
|
||||
(define-values (_ cols-without-pivot) (matrix-gauss-elim M))
|
||||
(assert (- n (length cols-without-pivot)) index?))
|
||||
|
||||
(: matrix-nullity : (Matrix Number) -> Index)
|
||||
;; Returns the dimension of the null space of M
|
||||
(define (matrix-nullity M)
|
||||
; nullity = dimension of null space
|
||||
(define-values (_ cols-without-pivot)
|
||||
(matrix-gauss-elim (ensure-matrix 'matrix-nullity M)))
|
||||
(length cols-without-pivot))
|
||||
|
@ -310,22 +160,27 @@
|
|||
(cond [(= j0 j1) Bs]
|
||||
[else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
|
||||
|
||||
(: matrix-column-space (case-> ((Matrix Real) -> (Array Real))
|
||||
((Matrix Number) -> (Array Number))))
|
||||
(define (matrix-column-space M)
|
||||
(define n (matrix-num-cols M))
|
||||
(define-values (_ wps) (matrix-gauss-elim M))
|
||||
(cond [(empty? wps) M]
|
||||
[(= (length wps) n) (make-array (vector 0 n) 0)]
|
||||
[else
|
||||
(define next-j (first wps))
|
||||
(define Bs (maybe-cons-submatrix M 0 next-j empty))
|
||||
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
|
||||
(cond [(empty? wps)
|
||||
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
|
||||
[else
|
||||
(define next-j (first wps))
|
||||
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))]))
|
||||
(: matrix-column-space (All (A) (case-> ((Matrix Real) -> (Matrix Real))
|
||||
((Matrix Real) (-> A) -> (U A (Matrix Real)))
|
||||
((Matrix Number) -> (Matrix Number))
|
||||
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
|
||||
(define matrix-column-space
|
||||
(case-lambda
|
||||
[(M) (matrix-column-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
|
||||
[(M fail)
|
||||
(define n (matrix-num-cols M))
|
||||
(define-values (_ wps) (matrix-gauss-elim M))
|
||||
(cond [(empty? wps) M]
|
||||
[(= (length wps) n) (fail)]
|
||||
[else
|
||||
(define next-j (first wps))
|
||||
(define Bs (maybe-cons-submatrix M 0 next-j empty))
|
||||
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
|
||||
(cond [(empty? wps)
|
||||
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
|
||||
[else
|
||||
(define next-j (first wps))
|
||||
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Determinant
|
||||
|
@ -376,17 +231,13 @@
|
|||
(define (matrix-invertible? M)
|
||||
(not (zero? (matrix-determinant M))))
|
||||
|
||||
(: make-invertible-fail (Symbol (Matrix Any) -> (-> Nothing)))
|
||||
(define ((make-invertible-fail name M))
|
||||
(raise-argument-error name "matrix-invertible?" M))
|
||||
|
||||
(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real))
|
||||
((Matrix Real) (-> A) -> (U A (Matrix Real)))
|
||||
((Matrix Number) -> (Matrix Number))
|
||||
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
|
||||
(define matrix-inverse
|
||||
(case-lambda
|
||||
[(M) (matrix-inverse M (make-invertible-fail 'matrix-inverse M))]
|
||||
[(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))]
|
||||
[(M fail)
|
||||
(define m (square-matrix-size M))
|
||||
(define I (identity-matrix m))
|
||||
|
@ -402,7 +253,7 @@
|
|||
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number))))))
|
||||
(define matrix-solve
|
||||
(case-lambda
|
||||
[(M B) (matrix-solve M B (make-invertible-fail 'matrix-solve M))]
|
||||
[(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))]
|
||||
[(M B fail)
|
||||
(define m (square-matrix-size M))
|
||||
(define-values (s t) (matrix-shape B))
|
||||
|
@ -421,63 +272,53 @@
|
|||
|
||||
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
|
||||
|
||||
(: matrix-lu :
|
||||
(Matrix Number) -> (U False (List (Matrix Number) (Matrix Number))))
|
||||
(define (matrix-lu M)
|
||||
(define-values (m _) (matrix-shape M))
|
||||
(define: ms : (Listof Number) '())
|
||||
(define V
|
||||
(let/ec: return : (U False (Matrix Number))
|
||||
(let: i-loop : (Matrix Number)
|
||||
([i : Integer 0]
|
||||
[V : (Matrix Number) M])
|
||||
(cond
|
||||
[(= i m) V]
|
||||
[else
|
||||
; Gauss: find non-zero element
|
||||
; LU: this has to be the first
|
||||
(let ([x_ii (matrix-ref V i i)])
|
||||
(cond
|
||||
[(zero? x_ii)
|
||||
(return #f)] ; no LU - factorization possible
|
||||
[else
|
||||
; remove elements below pivot
|
||||
(let j-loop ([j (+ i 1)] [V V])
|
||||
(cond
|
||||
[(= j m) (i-loop (+ i 1) V)]
|
||||
[else
|
||||
(let* ([x_ji (matrix-ref V j i)]
|
||||
[m_ij (/ x_ji x_ii)])
|
||||
(set! ms (cons m_ij ms))
|
||||
(j-loop (+ j 1)
|
||||
(if (zero? x_ji)
|
||||
V
|
||||
(matrix-add-scaled-row V j (- m_ij) i))))]))]))]))))
|
||||
|
||||
; Now M has been transformed to U.
|
||||
(if (eq? V #f)
|
||||
#f
|
||||
(let ()
|
||||
(define: L-matrix : (Vectorof Number) (make-vector (* m m) 0))
|
||||
; fill below diagonal
|
||||
(set! ms (reverse ms))
|
||||
(for*: ([j : Integer (in-range 0 m)]
|
||||
[i : Integer (in-range (+ j 1) m)])
|
||||
(vector-set! L-matrix (+ (* i m) j) (car ms))
|
||||
(set! ms (cdr ms)))
|
||||
; fill diagonal
|
||||
(for: ([i : Integer (in-range 0 m)])
|
||||
(vector-set! L-matrix (+ (* i m) i) 1))
|
||||
|
||||
(define: L : (Matrix Number)
|
||||
(let ([ds (array-shape M)])
|
||||
(unsafe-build-array
|
||||
ds (λ: ([js : (Vectorof Index)])
|
||||
(define i (unsafe-vector-ref js 0))
|
||||
(define j (unsafe-vector-ref js 1))
|
||||
(vector-ref L-matrix (+ (* i m) j))))))
|
||||
(list L V))))
|
||||
(: matrix-lu
|
||||
(All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
|
||||
((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
|
||||
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
|
||||
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
|
||||
(define matrix-lu
|
||||
(case-lambda
|
||||
[(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
|
||||
[(M fail)
|
||||
(define m (square-matrix-size M))
|
||||
(define rows (matrix->vector* M))
|
||||
;; Construct L in a weird way to prove to TR that it has the right type
|
||||
(define L (array->mutable-array (matrix-scale M (ann 0 Real))))
|
||||
;; Going to fill in the lower triangle by banging values into `ys'
|
||||
(define ys (mutable-array-data L))
|
||||
(let loop ([#{i : Nonnegative-Fixnum} 0])
|
||||
(cond
|
||||
[(i . fx< . m)
|
||||
;; Pivot must be on the diagonal
|
||||
(define pivot (unsafe-vector2d-ref rows i i))
|
||||
(cond
|
||||
[(zero? pivot) (values (fail) M)]
|
||||
[else
|
||||
;; Zero out everything below the pivot
|
||||
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
|
||||
(cond
|
||||
[(l . fx< . m)
|
||||
(define x_li (unsafe-vector2d-ref rows l i))
|
||||
(define y_li (/ x_li pivot))
|
||||
(unless (zero? x_li)
|
||||
;; Fill in lower triangle of L
|
||||
(unsafe-vector-set! ys (+ (* l m) i) y_li)
|
||||
;; Add row i, scaled
|
||||
(vector-scaled-add! (unsafe-vector-ref rows l)
|
||||
(unsafe-vector-ref rows i)
|
||||
(- y_li)))
|
||||
(l-loop (fx+ l 1))]
|
||||
[else
|
||||
(loop (fx+ i 1))]))])]
|
||||
[else
|
||||
;; L's lower triangle has been filled; now fill the diagonal with 1s
|
||||
(for: ([i : Integer (in-range 0 m)])
|
||||
(vector-set! ys (+ (* i m) i) 1))
|
||||
(values L (vector*->matrix rows))]))]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Projections and orthogonalization
|
||||
|
||||
(: projection-on-orthogonal-basis :
|
||||
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
(make-array #(0 1) 0)
|
||||
(make-array #(0 0) 0)
|
||||
(make-array #(1 1 1) 0)))
|
||||
|
||||
#|
|
||||
;; ===================================================================================================
|
||||
;; Literal syntax
|
||||
|
||||
|
@ -796,6 +796,39 @@
|
|||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-inverse a))))
|
||||
|#
|
||||
;; ===================================================================================================
|
||||
;; LU decomposition
|
||||
|
||||
(let ([M (matrix [[ 1 1 0 3]
|
||||
[ 2 1 -1 1]
|
||||
[ 3 -1 -1 2]
|
||||
[-1 2 3 -1]])])
|
||||
(define-values (L V) (matrix-lu M))
|
||||
(check-equal? L (matrix [[ 1 0 0 0]
|
||||
[ 2 1 0 0]
|
||||
[ 3 4 1 0]
|
||||
[-1 -3 0 1]]))
|
||||
(check-equal? V (matrix [[1 1 0 3]
|
||||
[0 -1 -1 -5]
|
||||
[0 0 3 13]
|
||||
[0 0 0 -13]]))
|
||||
(check-equal? (matrix* L V) M))
|
||||
|
||||
(: matrix-l ((Matrix Number) -> Any))
|
||||
(define (matrix-l M)
|
||||
(define-values (L U) (matrix-lu M))
|
||||
L)
|
||||
|
||||
(check-exn exn:fail? (λ () (matrix-l (matrix [[1 1 0 2]
|
||||
[0 2 0 1]
|
||||
[1 0 0 0]
|
||||
[1 1 2 1]]))))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 3 4))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 4 3))))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-l a))))
|
||||
|
||||
#|
|
||||
;; ===================================================================================================
|
||||
|
@ -877,30 +910,6 @@
|
|||
4 4 ((inst vector Number)
|
||||
2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
|
||||
0 0 0.0 4.440892098500626e-16 0 0 0 0.0))))))
|
||||
(let ()
|
||||
(define M (list*->matrix '[[1 1 0 3]
|
||||
[2 1 -1 1]
|
||||
[3 -1 -1 2]
|
||||
[-1 2 3 -1]]))
|
||||
(define LU (matrix-lu M))
|
||||
(if (eq? LU #f)
|
||||
(list 'matrix-lu #f)
|
||||
(let ()
|
||||
(define L (if (list? LU) (first LU) #f))
|
||||
(define V (if (list? LU) (second LU) #f))
|
||||
(list
|
||||
'matrix-lu
|
||||
(equal? L (list*->matrix
|
||||
'[[1 0 0 0]
|
||||
[2 1 0 0]
|
||||
[3 4 1 0]
|
||||
[-1 -3 0 1]]))
|
||||
(equal? V (list*->matrix
|
||||
'[[1 1 0 3]
|
||||
[0 -1 -1 -5]
|
||||
[0 0 3 13]
|
||||
[0 0 0 -13]]))
|
||||
(equal? (matrix* L V) M)))))
|
||||
#;
|
||||
(begin
|
||||
"matrix-2d.rkt"
|
||||
|
|
Loading…
Reference in New Issue
Block a user