From fc02d40a66d7e1d18ec7eac222d982a580123f31 Mon Sep 17 00:00:00 2001 From: Neil Toronto Date: Sat, 22 Dec 2012 11:45:38 -0700 Subject: [PATCH] Faster LU decomposition --- .../math/private/matrix/matrix-operations.rkt | 319 +++++------------- collects/math/tests/matrix-tests.rkt | 59 ++-- 2 files changed, 114 insertions(+), 264 deletions(-) diff --git a/collects/math/private/matrix/matrix-operations.rkt b/collects/math/private/matrix/matrix-operations.rkt index 24cfee1c2b..1582d3f497 100644 --- a/collects/math/private/matrix/matrix-operations.rkt +++ b/collects/math/private/matrix/matrix-operations.rkt @@ -25,181 +25,31 @@ ; 5. Eigenvalues and eigenvectors (provide - matrix-inverse - ; row and column - matrix-scale-row - matrix-scale-column - matrix-swap-rows - matrix-swap-columns - matrix-add-scaled-row - ; reduction + ;; Gaussian elimination matrix-gauss-elim matrix-row-echelon - ; invariant + ;; Derived functions matrix-rank matrix-nullity matrix-determinant matrix-determinant/row-reduction ; for testing - matrix-invertible? - ; solvers - matrix-solve - ; spaces + ;; Spaces (TODO: null space, row space, left null space) matrix-column-space - ; projection + ;; Solving + matrix-invertible? + matrix-inverse + matrix-solve + ;; Projection projection-on-orthogonal-basis projection-on-orthonormal-basis projection-on-subspace gram-schmidt-orthogonal gram-schmidt-orthonormal - ; factorization + ;; Decomposition matrix-lu matrix-qr ) -;;; -;;; Row and column -;;; - -(: matrix-scale-row : (Matrix Number) Integer Number -> (Matrix Number)) -(define (matrix-scale-row a i c) - ((inline-matrix-scale-row i c) a)) - -(define-syntax (inline-matrix-scale-row stx) - (syntax-case stx () - [(_ i c) - (syntax/loc stx - (λ (arr) - (define ds (array-shape arr)) - (define g (unsafe-array-proc arr)) - (cond - [(< i 0) - (error 'matrix-scale-row "row index must be non-negative, got ~a" i)] - [(not (< i (vector-ref ds 0))) - (error 'matrix-scale-row "row index must be smaller than the number of rows, got ~a" i)] - [else - (unsafe-build-array ds (λ: ([js : (Vectorof Index)]) - (if (= i (vector-ref js 0)) - (* c (g js)) - (g js))))])))])) - -(: matrix-scale-column : (Matrix Number) Integer Number -> (Matrix Number)) -(define (matrix-scale-column a i c) - ((inline-matrix-scale-column i c) a)) - -(define-syntax (inline-matrix-scale-column stx) - (syntax-case stx () - [(_ j c) - (syntax/loc stx - (λ (arr) - (define ds (array-shape arr)) - (define g (unsafe-array-proc arr)) - (cond - [(< j 0) - (error 'matrix-scale-row "column index must be non-negative, got ~a" j)] - [(not (< j (vector-ref ds 1))) - (error 'matrix-scale-row - "column index must be smaller than the number of rows, got ~a" j)] - [else - (unsafe-build-array ds (λ: ([js : (Vectorof Index)]) - (if (= j (vector-ref js 1)) - (* c (g js)) - (g js))))])))])) - -(: matrix-swap-rows : (Matrix Number) Integer Integer -> (Matrix Number)) -(define (matrix-swap-rows a i j) - ((inline-matrix-swap-rows i j) a)) - -(define-syntax (inline-matrix-swap-rows stx) - (syntax-case stx () - [(_ i j) - (syntax/loc stx - (λ (arr) - (define ds (array-shape arr)) - (define g (unsafe-array-proc arr)) - (cond - [(< i 0) - (error 'matrix-swap-rows "row index must be non-negative, got ~a" i)] - [(< j 0) - (error 'matrix-swap-rows "row index must be non-negative, got ~a" j)] - [(not (< i (vector-ref ds 0))) - (error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" i)] - [(not (< j (vector-ref ds 0))) - (error 'matrix-swap-rows "row index must be smaller than the number of rows, got ~a" j)] - [else - (unsafe-build-array ds (λ: ([js : (Vectorof Index)]) - (cond - [(= i (vector-ref js 0)) - (g (vector j (vector-ref js 1)))] - [(= j (vector-ref js 0)) - (g (vector i (vector-ref js 1)))] - [else - (g js)])))])))])) - -(: matrix-swap-columns : (Matrix Number) Integer Integer -> (Matrix Number)) -(define (matrix-swap-columns a i j) - ((inline-matrix-swap-columns i j) a)) - -(define-syntax (inline-matrix-swap-columns stx) - (syntax-case stx () - [(_ i j) - (syntax/loc stx - (λ (arr) - (define ds (array-shape arr)) - (define g (unsafe-array-proc arr)) - (cond - [(< i 0) - (error 'matrix-swap-columns "column index must be non-negative, got ~a" i)] - [(< j 0) - (error 'matrix-swap-columns "column index must be non-negative, got ~a" j)] - [(not (< i (vector-ref ds 0))) - (error 'matrix-swap-columns - "column index must be smaller than the number of columns, got ~a" i)] - [(not (< j (vector-ref ds 0))) - (error 'matrix-swap-columns - "column index must be smaller than the number of columns, got ~a" j)] - [else - (unsafe-build-array ds (λ: ([js : (Vectorof Index)]) - (cond - [(= i (vector-ref js 1)) - (g (vector j (vector-ref js 1)))] - [(= j (vector-ref js 1)) - (g (vector i (vector-ref js 1)))] - [else - (g js)])))])))])) - -(: matrix-add-scaled-row : (Matrix Number) Integer Number Integer -> (Matrix Number)) -(define (matrix-add-scaled-row a i c j) - ((inline-matrix-add-scaled-row i c j) a)) - -(: flmatrix-add-scaled-row : (Matrix Flonum) Index Flonum Index -> (Matrix Flonum)) -(define (flmatrix-add-scaled-row a i c j) - ((inline-matrix-add-scaled-row i c j) a)) - -(define-syntax (inline-matrix-add-scaled-row stx) - (syntax-case stx () - [(_ i c j) - (syntax/loc stx - (λ (arr) - (define ds (array-shape arr)) - (define g (unsafe-array-proc arr)) - (cond - [(< i 0) - (error 'matrix-add-scaled-row "row index must be non-negative, got ~a" i)] - [(< j 0) - (error 'matrix-add-scaled-row "row index must be non-negative, got ~a" j)] - [(not (< i (vector-ref ds 0))) - (error 'matrix-add-scaled-row - "row index must be smaller than the number of rows, got ~a" i)] - [(not (< j (vector-ref ds 0))) - (error 'matrix-add-scaled-row - "row index must be smaller than the number of rows, got ~a" j)] - [else - (unsafe-build-array ds (λ: ([js : (Vectorof Index)]) - (if (= i (vector-ref js 0)) - (+ (g js) (* c (g (vector j (vector-ref js 1))))) - (g js))))])))])) - - (: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A))) (define (unsafe-vector2d-ref vss i j) (unsafe-vector-ref (unsafe-vector-ref vss i) j)) @@ -289,17 +139,17 @@ M)) (: matrix-rank : (Matrix Number) -> Index) +;; Returns the dimension of the column space (equiv. row space) of M (define (matrix-rank M) ; TODO: Use QR or SVD instead for inexact matrices ; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd - ; rank = dimension of column space = dimension of row space (define n (matrix-num-cols M)) (define-values (_ cols-without-pivot) (matrix-gauss-elim M)) (assert (- n (length cols-without-pivot)) index?)) (: matrix-nullity : (Matrix Number) -> Index) +;; Returns the dimension of the null space of M (define (matrix-nullity M) - ; nullity = dimension of null space (define-values (_ cols-without-pivot) (matrix-gauss-elim (ensure-matrix 'matrix-nullity M))) (length cols-without-pivot)) @@ -310,22 +160,27 @@ (cond [(= j0 j1) Bs] [else (cons (submatrix M (::) (:: j0 j1)) Bs)])) -(: matrix-column-space (case-> ((Matrix Real) -> (Array Real)) - ((Matrix Number) -> (Array Number)))) -(define (matrix-column-space M) - (define n (matrix-num-cols M)) - (define-values (_ wps) (matrix-gauss-elim M)) - (cond [(empty? wps) M] - [(= (length wps) n) (make-array (vector 0 n) 0)] - [else - (define next-j (first wps)) - (define Bs (maybe-cons-submatrix M 0 next-j empty)) - (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) - (cond [(empty? wps) - (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] - [else - (define next-j (first wps)) - (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])) +(: matrix-column-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) +(define matrix-column-space + (case-lambda + [(M) (matrix-column-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))] + [(M fail) + (define n (matrix-num-cols M)) + (define-values (_ wps) (matrix-gauss-elim M)) + (cond [(empty? wps) M] + [(= (length wps) n) (fail)] + [else + (define next-j (first wps)) + (define Bs (maybe-cons-submatrix M 0 next-j empty)) + (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) + (cond [(empty? wps) + (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] + [else + (define next-j (first wps)) + (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])])) ;; =================================================================================================== ;; Determinant @@ -376,17 +231,13 @@ (define (matrix-invertible? M) (not (zero? (matrix-determinant M)))) -(: make-invertible-fail (Symbol (Matrix Any) -> (-> Nothing))) -(define ((make-invertible-fail name M)) - (raise-argument-error name "matrix-invertible?" M)) - (: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real)) ((Matrix Real) (-> A) -> (U A (Matrix Real))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) (define matrix-inverse (case-lambda - [(M) (matrix-inverse M (make-invertible-fail 'matrix-inverse M))] + [(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))] [(M fail) (define m (square-matrix-size M)) (define I (identity-matrix m)) @@ -402,7 +253,7 @@ ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))) (define matrix-solve (case-lambda - [(M B) (matrix-solve M B (make-invertible-fail 'matrix-solve M))] + [(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))] [(M B fail) (define m (square-matrix-size M)) (define-values (s t) (matrix-shape B)) @@ -421,63 +272,53 @@ ;; An LU factorization exists iff Gaussian elimination can be done without row swaps. -(: matrix-lu : - (Matrix Number) -> (U False (List (Matrix Number) (Matrix Number)))) -(define (matrix-lu M) - (define-values (m _) (matrix-shape M)) - (define: ms : (Listof Number) '()) - (define V - (let/ec: return : (U False (Matrix Number)) - (let: i-loop : (Matrix Number) - ([i : Integer 0] - [V : (Matrix Number) M]) - (cond - [(= i m) V] - [else - ; Gauss: find non-zero element - ; LU: this has to be the first - (let ([x_ii (matrix-ref V i i)]) - (cond - [(zero? x_ii) - (return #f)] ; no LU - factorization possible - [else - ; remove elements below pivot - (let j-loop ([j (+ i 1)] [V V]) - (cond - [(= j m) (i-loop (+ i 1) V)] - [else - (let* ([x_ji (matrix-ref V j i)] - [m_ij (/ x_ji x_ii)]) - (set! ms (cons m_ij ms)) - (j-loop (+ j 1) - (if (zero? x_ji) - V - (matrix-add-scaled-row V j (- m_ij) i))))]))]))])))) - - ; Now M has been transformed to U. - (if (eq? V #f) - #f - (let () - (define: L-matrix : (Vectorof Number) (make-vector (* m m) 0)) - ; fill below diagonal - (set! ms (reverse ms)) - (for*: ([j : Integer (in-range 0 m)] - [i : Integer (in-range (+ j 1) m)]) - (vector-set! L-matrix (+ (* i m) j) (car ms)) - (set! ms (cdr ms))) - ; fill diagonal - (for: ([i : Integer (in-range 0 m)]) - (vector-set! L-matrix (+ (* i m) i) 1)) - - (define: L : (Matrix Number) - (let ([ds (array-shape M)]) - (unsafe-build-array - ds (λ: ([js : (Vectorof Index)]) - (define i (unsafe-vector-ref js 0)) - (define j (unsafe-vector-ref js 1)) - (vector-ref L-matrix (+ (* i m) j)))))) - (list L V)))) +(: matrix-lu + (All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real))) + ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))) +(define matrix-lu + (case-lambda + [(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))] + [(M fail) + (define m (square-matrix-size M)) + (define rows (matrix->vector* M)) + ;; Construct L in a weird way to prove to TR that it has the right type + (define L (array->mutable-array (matrix-scale M (ann 0 Real)))) + ;; Going to fill in the lower triangle by banging values into `ys' + (define ys (mutable-array-data L)) + (let loop ([#{i : Nonnegative-Fixnum} 0]) + (cond + [(i . fx< . m) + ;; Pivot must be on the diagonal + (define pivot (unsafe-vector2d-ref rows i i)) + (cond + [(zero? pivot) (values (fail) M)] + [else + ;; Zero out everything below the pivot + (let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)]) + (cond + [(l . fx< . m) + (define x_li (unsafe-vector2d-ref rows l i)) + (define y_li (/ x_li pivot)) + (unless (zero? x_li) + ;; Fill in lower triangle of L + (unsafe-vector-set! ys (+ (* l m) i) y_li) + ;; Add row i, scaled + (vector-scaled-add! (unsafe-vector-ref rows l) + (unsafe-vector-ref rows i) + (- y_li))) + (l-loop (fx+ l 1))] + [else + (loop (fx+ i 1))]))])] + [else + ;; L's lower triangle has been filled; now fill the diagonal with 1s + (for: ([i : Integer (in-range 0 m)]) + (vector-set! ys (+ (* i m) i) 1)) + (values L (vector*->matrix rows))]))])) +;; =================================================================================================== +;; Projections and orthogonalization (: projection-on-orthogonal-basis : (Column Number) (Listof (Column Number)) -> (Result-Column Number)) diff --git a/collects/math/tests/matrix-tests.rkt b/collects/math/tests/matrix-tests.rkt index 24febab056..c58fba12dc 100644 --- a/collects/math/tests/matrix-tests.rkt +++ b/collects/math/tests/matrix-tests.rkt @@ -20,7 +20,7 @@ (make-array #(0 1) 0) (make-array #(0 0) 0) (make-array #(1 1 1) 0))) - +#| ;; =================================================================================================== ;; Literal syntax @@ -796,6 +796,39 @@ (for: ([a (in-list nonmatrices)]) (check-exn exn:fail:contract? (λ () (matrix-inverse a)))) +|# +;; =================================================================================================== +;; LU decomposition + +(let ([M (matrix [[ 1 1 0 3] + [ 2 1 -1 1] + [ 3 -1 -1 2] + [-1 2 3 -1]])]) + (define-values (L V) (matrix-lu M)) + (check-equal? L (matrix [[ 1 0 0 0] + [ 2 1 0 0] + [ 3 4 1 0] + [-1 -3 0 1]])) + (check-equal? V (matrix [[1 1 0 3] + [0 -1 -1 -5] + [0 0 3 13] + [0 0 0 -13]])) + (check-equal? (matrix* L V) M)) + +(: matrix-l ((Matrix Number) -> Any)) +(define (matrix-l M) + (define-values (L U) (matrix-lu M)) + L) + +(check-exn exn:fail? (λ () (matrix-l (matrix [[1 1 0 2] + [0 2 0 1] + [1 0 0 0] + [1 1 2 1]])))) + +(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 3 4)))) +(check-exn exn:fail:contract? (λ () (matrix-l (random-matrix 4 3)))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-l a)))) #| ;; =================================================================================================== @@ -877,30 +910,6 @@ 4 4 ((inst vector Number) 2 4 9 11 0 0.0 2.23606797749979 2.23606797749979 0 0 0.0 4.440892098500626e-16 0 0 0 0.0)))))) - (let () - (define M (list*->matrix '[[1 1 0 3] - [2 1 -1 1] - [3 -1 -1 2] - [-1 2 3 -1]])) - (define LU (matrix-lu M)) - (if (eq? LU #f) - (list 'matrix-lu #f) - (let () - (define L (if (list? LU) (first LU) #f)) - (define V (if (list? LU) (second LU) #f)) - (list - 'matrix-lu - (equal? L (list*->matrix - '[[1 0 0 0] - [2 1 0 0] - [3 4 1 0] - [-1 -3 0 1]])) - (equal? V (list*->matrix - '[[1 1 0 3] - [0 -1 -1 -5] - [0 0 3 13] - [0 0 0 -13]])) - (equal? (matrix* L V) M))))) #; (begin "matrix-2d.rkt"