Faster LU decomposition

This commit is contained in:
Neil Toronto 2012-12-22 11:45:38 -07:00
parent 01bb5c400a
commit fc02d40a66
2 changed files with 114 additions and 264 deletions

View File

@ -25,181 +25,31 @@
; 5. Eigenvalues and eigenvectors
(provide
matrix-inverse
; row and column
matrix-scale-row
matrix-scale-column
matrix-swap-rows
matrix-swap-columns
matrix-add-scaled-row
; reduction
;; Gaussian elimination
matrix-gauss-elim
matrix-row-echelon
; invariant
;; Derived functions
matrix-rank
matrix-nullity
matrix-determinant
matrix-determinant/row-reduction ; for testing
matrix-invertible?
; solvers
matrix-solve
; spaces
;; Spaces (TODO: null space, row space, left null space)
matrix-column-space
; projection
;; Solving
matrix-invertible?
matrix-inverse
matrix-solve
;; Projection
projection-on-orthogonal-basis
projection-on-orthonormal-basis
projection-on-subspace
gram-schmidt-orthogonal
gram-schmidt-orthonormal
; factorization
;; Decomposition
matrix-lu
matrix-qr
)
;;;
;;; Row and column
;;;
(: matrix-scale-row : (Matrix Number) Integer Number -> (Matrix Number))
(define (matrix-scale-row a i c)
((inline-matrix-scale-row i c) a))
(define-syntax (inline-matrix-scale-row stx)
(syntax-case stx ()
[(_ i c)
(syntax/loc stx
(λ (arr)
(define ds (array-shape arr))
(define g (unsafe-array-proc arr))
(cond
[(< i 0)
(error 'matrix-scale-row "row index must be non-negative, got ~a" i)]
[(not (< i (vector-ref ds 0)))
(error 'matrix-scale-row "row index must be smaller than the number of rows, got ~a" i)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(if (= i (vector-ref js 0))
(* c (g js))
(g js))))])))]))
(: matrix-scale-column : (Matrix Number) Integer Number -> (Matrix Number))
(define (matrix-scale-column a i c)
((inline-matrix-scale-column i c) a))
(define-syntax (inline-matrix-scale-column stx)
(syntax-case stx ()
[(_ j c)
(syntax/loc stx
(λ (arr)
(define ds (array-shape arr))
(define g (unsafe-array-proc arr))
(cond
[(< j 0)
(error 'matrix-scale-row "column index must be non-negative, got ~a" j)]
[(not (< j (vector-ref ds 1)))
(error 'matrix-scale-row
"column index must be smaller than the number of rows, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(if (= j (vector-ref js 1))
(* c (g js))
(g js))))])))]))
(: matrix-swap-rows : (Matrix Number) Integer Integer -> (Matrix Number))
(define (matrix-swap-rows a i j)
((inline-matrix-swap-rows i j) a))
(define-syntax (inline-matrix-swap-rows stx)
(syntax-case stx ()
[(_ i j)
(syntax/loc stx
(λ (arr)
(define ds (array-shape arr))
(define g (unsafe-array-proc arr))
(cond
[(< i 0)
(error 'matrix-swap-rows "row index must be non-negative, got ~a" i)]
[(< j 0)
(error 'matrix-swap-rows "row index must be non-negative, got ~a" j)]
[(not (< i (vector-ref ds 0)))
(error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" i)]
[(not (< j (vector-ref ds 0)))
(error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(cond
[(= i (vector-ref js 0))
(g (vector j (vector-ref js 1)))]
[(= j (vector-ref js 0))
(g (vector i (vector-ref js 1)))]
[else
(g js)])))])))]))
(: matrix-swap-columns : (Matrix Number) Integer Integer -> (Matrix Number))
(define (matrix-swap-columns a i j)
((inline-matrix-swap-columns i j) a))
(define-syntax (inline-matrix-swap-columns stx)
(syntax-case stx ()
[(_ i j)
(syntax/loc stx
(λ (arr)
(define ds (array-shape arr))
(define g (unsafe-array-proc arr))
(cond
[(< i 0)
(error 'matrix-swap-columns "column index must be non-negative, got ~a" i)]
[(< j 0)
(error 'matrix-swap-columns "column index must be non-negative, got ~a" j)]
[(not (< i (vector-ref ds 0)))
(error 'matrix-swap-columns
"column index must be smaller than the number of columns, got ~a" i)]
[(not (< j (vector-ref ds 0)))
(error 'matrix-swap-columns
"column index must be smaller than the number of columns, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(cond
[(= i (vector-ref js 1))
(g (vector j (vector-ref js 1)))]
[(= j (vector-ref js 1))
(g (vector i (vector-ref js 1)))]
[else
(g js)])))])))]))
(: matrix-add-scaled-row : (Matrix Number) Integer Number Integer -> (Matrix Number))
(define (matrix-add-scaled-row a i c j)
((inline-matrix-add-scaled-row i c j) a))
(: flmatrix-add-scaled-row : (Matrix Flonum) Index Flonum Index -> (Matrix Flonum))
(define (flmatrix-add-scaled-row a i c j)
((inline-matrix-add-scaled-row i c j) a))
(define-syntax (inline-matrix-add-scaled-row stx)
(syntax-case stx ()
[(_ i c j)
(syntax/loc stx
(λ (arr)
(define ds (array-shape arr))
(define g (unsafe-array-proc arr))
(cond
[(< i 0)
(error 'matrix-add-scaled-row "row index must be non-negative, got ~a" i)]
[(< j 0)
(error 'matrix-add-scaled-row "row index must be non-negative, got ~a" j)]
[(not (< i (vector-ref ds 0)))
(error 'matrix-add-scaled-row
"row index must be smaller than the number of rows, got ~a" i)]
[(not (< j (vector-ref ds 0)))
(error 'matrix-add-scaled-row
"row index must be smaller than the number of rows, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(if (= i (vector-ref js 0))
(+ (g js) (* c (g (vector j (vector-ref js 1)))))
(g js))))])))]))
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
(define (unsafe-vector2d-ref vss i j)
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
@ -289,17 +139,17 @@
M))
(: matrix-rank : (Matrix Number) -> Index)
;; Returns the dimension of the column space (equiv. row space) of M
(define (matrix-rank M)
; TODO: Use QR or SVD instead for inexact matrices
; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd
; rank = dimension of column space = dimension of row space
(define n (matrix-num-cols M))
(define-values (_ cols-without-pivot) (matrix-gauss-elim M))
(assert (- n (length cols-without-pivot)) index?))
(: matrix-nullity : (Matrix Number) -> Index)
;; Returns the dimension of the null space of M
(define (matrix-nullity M)
; nullity = dimension of null space
(define-values (_ cols-without-pivot)
(matrix-gauss-elim (ensure-matrix 'matrix-nullity M)))
(length cols-without-pivot))
@ -310,22 +160,27 @@
(cond [(= j0 j1) Bs]
[else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
(: matrix-column-space (case-> ((Matrix Real) -> (Array Real))
((Matrix Number) -> (Array Number))))
(define (matrix-column-space M)
(define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M]
[(= (length wps) n) (make-array (vector 0 n) 0)]
[else
(define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty))
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
(cond [(empty? wps)
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else
(define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))]))
(: matrix-column-space (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-column-space
(case-lambda
[(M) (matrix-column-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
[(M fail)
(define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M]
[(= (length wps) n) (fail)]
[else
(define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty))
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
(cond [(empty? wps)
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else
(define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])]))
;; ===================================================================================================
;; Determinant
@ -376,17 +231,13 @@
(define (matrix-invertible? M)
(not (zero? (matrix-determinant M))))
(: make-invertible-fail (Symbol (Matrix Any) -> (-> Nothing)))
(define ((make-invertible-fail name M))
(raise-argument-error name "matrix-invertible?" M))
(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-inverse
(case-lambda
[(M) (matrix-inverse M (make-invertible-fail 'matrix-inverse M))]
[(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))]
[(M fail)
(define m (square-matrix-size M))
(define I (identity-matrix m))
@ -402,7 +253,7 @@
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-solve
(case-lambda
[(M B) (matrix-solve M B (make-invertible-fail 'matrix-solve M))]
[(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))]
[(M B fail)
(define m (square-matrix-size M))
(define-values (s t) (matrix-shape B))
@ -421,63 +272,53 @@
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
(: matrix-lu :
(Matrix Number) -> (U False (List (Matrix Number) (Matrix Number))))
(define (matrix-lu M)
(define-values (m _) (matrix-shape M))
(define: ms : (Listof Number) '())
(define V
(let/ec: return : (U False (Matrix Number))
(let: i-loop : (Matrix Number)
([i : Integer 0]
[V : (Matrix Number) M])
(cond
[(= i m) V]
[else
; Gauss: find non-zero element
; LU: this has to be the first
(let ([x_ii (matrix-ref V i i)])
(cond
[(zero? x_ii)
(return #f)] ; no LU - factorization possible
[else
; remove elements below pivot
(let j-loop ([j (+ i 1)] [V V])
(cond
[(= j m) (i-loop (+ i 1) V)]
[else
(let* ([x_ji (matrix-ref V j i)]
[m_ij (/ x_ji x_ii)])
(set! ms (cons m_ij ms))
(j-loop (+ j 1)
(if (zero? x_ji)
V
(matrix-add-scaled-row V j (- m_ij) i))))]))]))]))))
; Now M has been transformed to U.
(if (eq? V #f)
#f
(let ()
(define: L-matrix : (Vectorof Number) (make-vector (* m m) 0))
; fill below diagonal
(set! ms (reverse ms))
(for*: ([j : Integer (in-range 0 m)]
[i : Integer (in-range (+ j 1) m)])
(vector-set! L-matrix (+ (* i m) j) (car ms))
(set! ms (cdr ms)))
; fill diagonal
(for: ([i : Integer (in-range 0 m)])
(vector-set! L-matrix (+ (* i m) i) 1))
(define: L : (Matrix Number)
(let ([ds (array-shape M)])
(unsafe-build-array
ds (λ: ([js : (Vectorof Index)])
(define i (unsafe-vector-ref js 0))
(define j (unsafe-vector-ref js 1))
(vector-ref L-matrix (+ (* i m) j))))))
(list L V))))
(: matrix-lu
(All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
(define matrix-lu
(case-lambda
[(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
[(M fail)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
;; Construct L in a weird way to prove to TR that it has the right type
(define L (array->mutable-array (matrix-scale M (ann 0 Real))))
;; Going to fill in the lower triangle by banging values into `ys'
(define ys (mutable-array-data L))
(let loop ([#{i : Nonnegative-Fixnum} 0])
(cond
[(i . fx< . m)
;; Pivot must be on the diagonal
(define pivot (unsafe-vector2d-ref rows i i))
(cond
[(zero? pivot) (values (fail) M)]
[else
;; Zero out everything below the pivot
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond
[(l . fx< . m)
(define x_li (unsafe-vector2d-ref rows l i))
(define y_li (/ x_li pivot))
(unless (zero? x_li)
;; Fill in lower triangle of L
(unsafe-vector-set! ys (+ (* l m) i) y_li)
;; Add row i, scaled
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- y_li)))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1))]))])]
[else
;; L's lower triangle has been filled; now fill the diagonal with 1s
(for: ([i : Integer (in-range 0 m)])
(vector-set! ys (+ (* i m) i) 1))
(values L (vector*->matrix rows))]))]))
;; ===================================================================================================
;; Projections and orthogonalization
(: projection-on-orthogonal-basis :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))

View File

@ -20,7 +20,7 @@
(make-array #(0 1) 0)
(make-array #(0 0) 0)
(make-array #(1 1 1) 0)))
#|
;; ===================================================================================================
;; Literal syntax
@ -796,6 +796,39 @@
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-inverse a))))
|#
;; ===================================================================================================
;; LU decomposition
(let ([M (matrix [[ 1 1 0 3]
[ 2 1 -1 1]
[ 3 -1 -1 2]
[-1 2 3 -1]])])
(define-values (L V) (matrix-lu M))
(check-equal? L (matrix [[ 1 0 0 0]
[ 2 1 0 0]
[ 3 4 1 0]
[-1 -3 0 1]]))
(check-equal? V (matrix [[1 1 0 3]
[0 -1 -1 -5]
[0 0 3 13]
[0 0 0 -13]]))
(check-equal? (matrix* L V) M))
(: matrix-l ((Matrix Number) -> Any))
(define (matrix-l M)
(define-values (L U) (matrix-lu M))
L)
(check-exn exn:fail? (λ () (matrix-l (matrix [[1 1 0 2]
[0 2 0 1]
[1 0 0 0]
[1 1 2 1]]))))
(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 3 4))))
(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 4 3))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-l a))))
#|
;; ===================================================================================================
@ -877,30 +910,6 @@
4 4 ((inst vector Number)
2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
0 0 0.0 4.440892098500626e-16 0 0 0 0.0))))))
(let ()
(define M (list*->matrix '[[1 1 0 3]
[2 1 -1 1]
[3 -1 -1 2]
[-1 2 3 -1]]))
(define LU (matrix-lu M))
(if (eq? LU #f)
(list 'matrix-lu #f)
(let ()
(define L (if (list? LU) (first LU) #f))
(define V (if (list? LU) (second LU) #f))
(list
'matrix-lu
(equal? L (list*->matrix
'[[1 0 0 0]
[2 1 0 0]
[3 4 1 0]
[-1 -3 0 1]]))
(equal? V (list*->matrix
'[[1 1 0 3]
[0 -1 -1 -5]
[0 0 3 13]
[0 0 0 -13]]))
(equal? (matrix* L V) M)))))
#;
(begin
"matrix-2d.rkt"