Moar matrix review/refactoring

* Consolidated Gauss and Gauss-Jordan elimination

* Fixed Gaussian elimination to return all indexes for pivotless columns,
  not just those < m

* Consolidated `matrix-row-echelon' and `matrix-reduced-row-echelon'

* Specialized row reduction for determinants; removed option to not do
  partial pivoting (it's never necessary otherwise)

* Added `matrix-invertible?'

* Removed `matrix-solve-many'; now `matrix-solve' solves for multiple
  columns

* Gave `matrix-inverse' and `matrix-solve' optional failure thunk arguments

* Made some functions that return multiple columns return arrays instead
  (i.e. `matrix-column-space')

* Added more tests
This commit is contained in:
Neil Toronto 2012-12-21 22:59:21 -07:00
parent 3bc4c1ffdc
commit 1aebd171c5
3 changed files with 325 additions and 333 deletions

View File

@ -2,6 +2,7 @@
(require racket/fixnum
racket/list
racket/match
math/array
(only-in typed/racket conjugate)
"../unsafe.rkt"
@ -12,11 +13,12 @@
"matrix-arithmetic.rkt"
"matrix-basic.rkt"
"matrix-column.rkt"
"utils.rkt"
(for-syntax racket))
; TODO:
; 1. compute null space from QR factorization
; (better numerical stability than from Gauss elimnation)
; (better numerical stability than from Gauss elimination)
; 2. S+N decomposition
; 3. Linear least squares problems (data fitting)
; 4. Pseudo inverse
@ -31,19 +33,16 @@
matrix-swap-columns
matrix-add-scaled-row
; reduction
matrix-gauss-eliminate
matrix-gauss-jordan-eliminate
matrix-row-echelon-form
matrix-reduced-row-echelon-form
matrix-gauss-elim
matrix-row-echelon
; invariant
matrix-rank
matrix-nullity
matrix-determinant
; spaces
;matrix-column+null-space
matrix-determinant/row-reduction ; for testing
matrix-invertible?
; solvers
matrix-solve
matrix-solve-many
; spaces
matrix-column-space
; projection
@ -98,7 +97,8 @@
[(< j 0)
(error 'matrix-scale-row "column index must be non-negative, got ~a" j)]
[(not (< j (vector-ref ds 1)))
(error 'matrix-scale-row "column index must be smaller than the number of rows, got ~a" j)]
(error 'matrix-scale-row
"column index must be smaller than the number of rows, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(if (= j (vector-ref js 1))
@ -152,9 +152,11 @@
[(< j 0)
(error 'matrix-swap-columns "column index must be non-negative, got ~a" j)]
[(not (< i (vector-ref ds 0)))
(error 'matrix-swap-columns "column index must be smaller than the number of columns, got ~a" i)]
(error 'matrix-swap-columns
"column index must be smaller than the number of columns, got ~a" i)]
[(not (< j (vector-ref ds 0)))
(error 'matrix-swap-columns "column index must be smaller than the number of columns, got ~a" j)]
(error 'matrix-swap-columns
"column index must be smaller than the number of columns, got ~a" j)]
[else
(unsafe-build-array ds (λ: ([js : (Vectorof Index)])
(cond
@ -198,279 +200,226 @@
(g js))))])))]))
;;; GAUSS ELIMINATION / ROW ECHELON FORM
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
(define (unsafe-vector2d-ref vss i j)
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
(: find-partial-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
;; ===================================================================================================
;; Gaussian elimination
(: find-partial-pivot
(case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Find the element with maximum magnitude in a column
(define (find-partial-pivot rows m i j)
(let loop ([#{l : Nonnegative-Fixnum} i] [#{max-current : Real} -inf.0] [#{max-index : Index} i])
(define l (fx+ i 1))
(define pivot (unsafe-vector2d-ref rows i j))
(define mag-pivot (magnitude pivot))
(let loop ([#{l : Nonnegative-Fixnum} l] [#{p : Index} i] [pivot pivot] [mag-pivot mag-pivot])
(cond [(l . fx< . m)
(define v (magnitude (unsafe-vector2d-ref rows l j)))
(cond [(> v max-current) (loop (fx+ l 1) v l)]
[else (loop (fx+ l 1) max-current max-index)])]
[else max-index])))
(define new-pivot (unsafe-vector2d-ref rows l j))
(define mag-new-pivot (magnitude new-pivot))
(cond [(mag-new-pivot . > . mag-pivot) (loop (fx+ l 1) l new-pivot mag-new-pivot)]
[else (loop (fx+ l 1) p pivot mag-pivot)])]
[else (values p pivot)])))
(: find-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
;; Find a non-zero element in a column
(define (find-pivot rows m i j)
(let loop ([#{l : Nonnegative-Fixnum} i])
(cond [(l . fx>= . m) #f]
[(not (zero? (unsafe-vector2d-ref rows l j))) l]
[else (loop (fx+ l 1))])))
(: elim-rows!
(case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void)
((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void)))
(define (elim-rows! rows m i j pivot start)
(let loop ([#{l : Nonnegative-Fixnum} start])
(when (l . fx< . m)
(unless (l . fx= . i)
(define x_lj (unsafe-vector2d-ref rows l j))
(unless (zero? x_lj)
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- (/ x_lj pivot)))))
(loop (fx+ l 1)))))
(: matrix-gauss-eliminate :
(case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Boolean -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Boolean Boolean -> (Values (Matrix Real) (Listof Index)))
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Index)))))
;; Returns the result of Gaussian elimination and a list of column indexes that had no pivot value
;; If `reduced?' is #t, the result is in *reduced* row-echelon form, and is unique (up to
;; floating-point error)
;; If `partial-pivoting?' is #t, the largest value in each column is used as the pivot
(define (matrix-gauss-eliminate M [reduced? #f] [partial-pivoting? #t])
(: matrix-gauss-elim (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index)))))
(define (matrix-gauss-elim M [jordan? #f] [unitize-pivot-row? #f])
(define-values (m n) (matrix-shape M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0]
[#{j : Nonnegative-Fixnum} 0]
[#{without-pivot : (Listof Index)} '()])
[#{without-pivot : (Listof Index)} empty])
(cond
[(and (i . fx< . m) (j . fx< . n))
;; Find the row with the pivot value
(define p (cond [partial-pivoting? (find-partial-pivot rows m i j)]
[else (find-pivot rows m i j)]))
(define pivot (if p (unsafe-vector2d-ref rows p j) 0))
(cond
[(or (not p) (zero? pivot)) ; didn't find pivot?
(loop i (fx+ j 1) (cons j without-pivot))]
[else
(vector-swap! rows i p) ; swap pivot row with current
(let ([pivot (cond [reduced? (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
(/ pivot pivot)]
[else pivot])])
;; Remove elements below pivot by scaling and adding the pivot's row to each row below
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond [(l . fx< . m)
(define x_lj (unsafe-vector2d-ref rows l j))
(unless (zero? x_lj)
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- (/ x_lj pivot))))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1) (fx+ j 1) without-pivot)])))])]
[else
[(j . fx>= . n)
(values (vector*->matrix rows)
(reverse without-pivot))])))
(reverse without-pivot))]
[(i . fx>= . m)
(values (vector*->matrix rows)
;; None of the rest of the columns can have pivots
(let loop ([#{j : Nonnegative-Fixnum} j] [without-pivot without-pivot])
(cond [(j . fx< . n) (loop (fx+ j 1) (cons j without-pivot))]
[else (reverse without-pivot)])))]
[else
(define-values (p pivot) (find-partial-pivot rows m i j))
(cond
[(zero? pivot) (loop i (fx+ j 1) (cons j without-pivot))]
[else
;; Swap pivot row with current
(vector-swap! rows i p)
;; Possibly unitize the new current row
(let ([pivot (if unitize-pivot-row?
(begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
1)
pivot)])
(elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1)))
(loop (fx+ i 1) (fx+ j 1) without-pivot))])])))
(: matrix-rank : (Matrix Number) -> Integer)
;; ===================================================================================================
;; Simple functions derived from Gaussian elimination
(: matrix-row-echelon
(case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Any -> (Matrix Real))
((Matrix Real) Any Any -> (Matrix Real))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Any -> (Matrix Number))
((Matrix Number) Any Any -> (Matrix Number))))
(define (matrix-row-echelon M [jordan? #f] [unitize-pivot-row? jordan?])
(let-values ([(M _) (matrix-gauss-elim M jordan? unitize-pivot-row?)])
M))
(: matrix-rank : (Matrix Number) -> Index)
(define (matrix-rank M)
; TODO: Use QR or SVD instead for inexact matrices
; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd
; rank = dimension of column space = dimension of row space
(define-values (m n) (matrix-shape M))
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
(- n (length cols-without-pivot)))
(define n (matrix-num-cols M))
(define-values (_ cols-without-pivot) (matrix-gauss-elim M))
(assert (- n (length cols-without-pivot)) index?))
(: matrix-nullity : (Matrix Number) -> Integer)
(: matrix-nullity : (Matrix Number) -> Index)
(define (matrix-nullity M)
; nullity = dimension of null space
(define-values (m n) (matrix-shape M))
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
(define-values (_ cols-without-pivot)
(matrix-gauss-elim (ensure-matrix 'matrix-nullity M)))
(length cols-without-pivot))
(: matrix-determinant : (Matrix Number) -> Number)
(: maybe-cons-submatrix (All (A) ((Matrix A) Nonnegative-Fixnum Nonnegative-Fixnum (Listof (Matrix A))
-> (Listof (Matrix A)))))
(define (maybe-cons-submatrix M j0 j1 Bs)
(cond [(= j0 j1) Bs]
[else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
(: matrix-column-space (case-> ((Matrix Real) -> (Array Real))
((Matrix Number) -> (Array Number))))
(define (matrix-column-space M)
(define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M]
[(= (length wps) n) (make-array (vector 0 n) 0)]
[else
(define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty))
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
(cond [(empty? wps)
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else
(define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))]))
;; ===================================================================================================
;; Determinant
(: matrix-determinant (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant M)
(define-values (m n) (matrix-shape M))
(define m (square-matrix-size M))
(cond
[(= m 1) (matrix-ref M 0 0)]
[(= m 2) (let ([a (matrix-ref M 0 0)]
[b (matrix-ref M 0 1)]
[c (matrix-ref M 1 0)]
[d (matrix-ref M 1 1)])
(- (* a d) (* b c)))]
[(= m 3) (let ([a (matrix-ref M 0 0)]
[b (matrix-ref M 0 1)]
[c (matrix-ref M 0 2)]
[d (matrix-ref M 1 0)]
[e (matrix-ref M 1 1)]
[f (matrix-ref M 1 2)]
[g (matrix-ref M 2 0)]
[h (matrix-ref M 2 1)]
[i (matrix-ref M 2 2)])
(+ (* a (- (* e i) (* f h)))
(* (- b) (- (* d i) (* f g)))
(* c (- (* d h) (* e g)))))]
[else
(let-values ([(M _) (matrix-gauss-eliminate M #f #f)])
; TODO: #f #f turns off partial pivoting
#; (for/product: : Number ([i (in-range 0 m)])
(matrix-ref M i i))
(let ()
(define: product : Number 1)
(for: ([i : Integer (in-range 0 m 1)])
(set! product (* product (matrix-ref M i i))))
product))]))
[(= m 1) (matrix-ref M 0 0)]
[(= m 2) (match-define (vector a b c d)
(mutable-array-data (array->mutable-array M)))
(- (* a d) (* b c))]
[(= m 3) (match-define (vector a b c d e f g h i)
(mutable-array-data (array->mutable-array M)))
(+ (* a (- (* e i) (* f h)))
(* (- b) (- (* d i) (* f g)))
(* c (- (* d h) (* e g))))]
[else
(matrix-determinant/row-reduction M)]))
(: matrix-column-space : (Matrix Number) -> (Listof (Matrix Number)))
; Returns
; 1) a list of column vectors spanning the column space
; 2) a list of column vectors spanning the null space
(define (matrix-column-space M)
(define-values (m n) (matrix-shape M))
(: M1 (Matrix Number))
(: cols-without-pivot (Listof Integer))
(define-values (M1 cols-without-pivot) (matrix-gauss-eliminate M #t))
(set! M1 (array->mutable-array M1))
(define: column-space : (Listof (Matrix Number))
(for/list:
([i : Index n]
#:when (not (member i cols-without-pivot)))
(matrix-col M1 i)))
column-space)
(: matrix-row-echelon-form :
(case-> ((Matrix Number) Boolean -> (Matrix Number))
((Matrix Number) Boolean -> (Matrix Number))
((Matrix Number) -> (Matrix Number))))
(define (matrix-row-echelon-form M [unitize-pivot-row? #f])
(let-values ([(M wp) (matrix-gauss-eliminate M unitize-pivot-row?)])
M))
(: matrix-gauss-jordan-eliminate :
(case-> ((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
(define (matrix-gauss-jordan-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
(define-values (m n) (matrix-shape M))
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
-> (Values (Matrix Number) (Listof Integer))))
(define (loop i j ; i from 0 to m
M
k ; count rows without pivot
without-pivot)
(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant/row-reduction M)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1])
(cond
[(or (= i m) (= j n)) (values M without-pivot)]
[else
; find row to become pivot
(define p
(if partial-pivoting?
; find element with maximal absolute value
(let: max-loop : (U False Integer)
([l : Integer i] ; i<=l<m
[max-current : Real -inf.0]
[max-index : Integer i])
(cond
[(= l m) max-index]
[else
(let ([v (magnitude (matrix-ref M l j))])
(if (> (magnitude (matrix-ref M l j)) max-current)
(max-loop (+ l 1) v l)
(max-loop (+ l 1) max-current max-index)))]))
; find non-zero element in column
(let: first-loop : (U False Integer)
([l : Integer i]) ; i<=l<m
(cond
[(= l m) #f]
[(not (zero? (matrix-ref M l j))) l]
[else (first-loop (+ l 1))]))))
[(i . fx< . m)
(define-values (p pivot) (find-partial-pivot rows m i i))
(cond
[(eq? p #f)
; no pivot found - this implies the matrix is singular (not invertible)
(loop i (+ j 1) M (+ k 1) (cons j without-pivot))]
[else
; swap if neccessary
(let* ([M (if (= i p) M (matrix-swap-rows M i p))]
; now we now (i,j) is a pivot
[M ; maybe scale row
(if unitize-pivot-row?
(let ([pivot (matrix-ref M i j)])
(if (zero? pivot)
M
(matrix-scale-row M i (/ pivot))))
M)])
(let ([pivot (matrix-ref M i j)])
; remove elements above and below pivot
(let l-loop ([l 0] [M M])
(cond
[(= l m) (loop (+ i 1) (+ j 1) M k without-pivot)]
[(= l i) (l-loop (+ l 1) M)]
[else
(let ([x_lj (matrix-ref M l j)])
(l-loop (+ l 1)
(if (zero? x_lj)
M
(matrix-add-scaled-row M l (- (/ x_lj pivot)) i))))]))))])]))
(let-values ([(M without) (loop 0 0 M 0 '())])
(values M without)))
[(zero? pivot) 0] ; no pivot means non-invertible matrix
[else
(vector-swap! rows i p) ; negates determinant if i != p
(elim-rows! rows m i i pivot (fx+ i 1)) ; doesn't change the determinant
(loop (fx+ i 1) (if (= i p) sign (* -1 sign)))])]
[else
(define prod (unsafe-vector2d-ref rows 0 0))
(let loop ([#{i : Nonnegative-Fixnum} 1] [prod prod])
(cond [(i . fx< . m)
(loop (fx+ i 1) (* prod (unsafe-vector2d-ref rows i i)))]
[else (* prod sign)]))])))
(: matrix-reduced-row-echelon-form :
(case-> ((Matrix Number) Boolean -> (Matrix Number))
((Matrix Number) Boolean -> (Matrix Number))
((Matrix Number) -> (Matrix Number))))
(define (matrix-reduced-row-echelon-form M [unitize-pivot-row? #f])
(let-values ([(M wp) (matrix-gauss-jordan-eliminate M unitize-pivot-row?)])
M))
;; ===================================================================================================
;; Inversion and solving linear systems
(: matrix-inverse : (Matrix Number) -> (Matrix Number))
(define (matrix-inverse M)
(define-values (m n) (matrix-shape M))
(unless (= m n) (error 'matrix-inverse "matrix not square"))
(let ([MI (matrix-augment (list M (identity-matrix m)))])
(define 2m (* 2 m))
(if (index? 2m)
(submatrix (matrix-reduced-row-echelon-form MI #t)
(in-range 0 m) (in-range m 2m))
(error 'matrix-inverse "internal error"))))
(: matrix-invertible? ((Matrix Number) -> Boolean))
(define (matrix-invertible? M)
(not (zero? (matrix-determinant M))))
(: matrix-solve : (Matrix Number) (Matrix Number) -> (Matrix Number))
; Return a column-vector x such that Mx = b.
; If no such vector exists return #f.
(define (matrix-solve M b)
(define-values (m n) (matrix-shape M))
(define-values (s t) (matrix-shape b))
(define m+1 (+ m 1))
(cond
[(not (= t 1)) (error 'matrix-solve "expected column vector (i.e. r x 1 - matrix), got: ~a " b)]
[(not (= m s)) (error 'matrix-solve "expected column vector with same number of rows as the matrix")]
[(index? m+1)
(submatrix
(matrix-reduced-row-echelon-form
(matrix-augment (list M b)) #t)
(in-range 0 m) (in-range m m+1))]
[else (error 'matrix-solve "internatl error")]))
(: make-invertible-fail (Symbol (Matrix Any) -> (-> Nothing)))
(define ((make-invertible-fail name M))
(raise-argument-error name "matrix-invertible?" M))
(: matrix-solve-many : (Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))
(define (matrix-solve-many M bs)
(define-values (m n) (matrix-shape M))
(define-values (s t) (matrix-shape (car bs)))
(define k (length bs))
(define m+1 (+ m 1))
(define m+k (+ m k))
(cond
[(not (= t 1)) (error 'matrix-solve-many "expected column vector (i.e. r x 1 - matrix), got: ~a " (car bs))]
[(not (= m s)) (error 'matrix-solve-many "expected column vectors with same number of rows as the matrix")]
[(and (index? m+1) (index? m+k))
(define bs-as-matrix (matrix-augment bs))
(define MB (matrix-augment (list M bs-as-matrix)))
(define reduced-MB (matrix-reduced-row-echelon-form MB #t))
(submatrix reduced-MB
(in-range 0 m+k)
(in-range m m+1))]
[else (error 'matrix-solve-many "internal error")]))
(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-inverse
(case-lambda
[(M) (matrix-inverse M (make-invertible-fail 'matrix-inverse M))]
[(M fail)
(define m (square-matrix-size M))
(define I (identity-matrix m))
(define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IM^-1 (::) (:: m #f))]
[else (fail)])]))
(: matrix-solve (All (A) (case->
((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) (Matrix Number) -> (Matrix Number))
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-solve
(case-lambda
[(M B) (matrix-solve M B (make-invertible-fail 'matrix-solve M))]
[(M B fail)
(define m (square-matrix-size M))
(define-values (s t) (matrix-shape B))
(cond [(= m s)
(define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IX (::) (:: m #f))]
[else (fail)])]
[else
(error 'matrix-solve
"matrices must have the same number of rows; given ~e and ~e"
M B)])]))
;;; LU Factorization
; Not all matrices can be LU-factored.
; If Gauss-elimination can be done without any row swaps,
; a LU-factorization is possible.
;; ===================================================================================================
;; LU Factorization
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
(: matrix-lu :
(Matrix Number) -> (U False (List (Matrix Number) (Matrix Number))))
@ -625,6 +574,9 @@
(take (sort (loop vs '() 0) norm>) r)
(error 'extend-span-to-basis "expected index as second argument, got ~a" r)))
;; ===================================================================================================
;; QR decomposition
(: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
(define (matrix-qr M)
; compute the QR-facorization

View File

@ -71,9 +71,9 @@
(: matrix-num-rows (All (A) ((Array A) -> Index)))
(define (matrix-num-rows a)
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
[else (raise-argument-error 'matrix-num-rows "matrix?" a)]))
(: matrix-num-cols (All (A) ((Array A) -> Index)))
(define (matrix-num-cols a)
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))
[else (raise-argument-error 'matrix-num-cols "matrix?" a)]))

View File

@ -20,7 +20,7 @@
(make-array #(0 1) 0)
(make-array #(0 0) 0)
(make-array #(1 1 1) 0)))
#|
;; ===================================================================================================
;; Literal syntax
@ -666,32 +666,136 @@
(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3]))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
|#
;; ===================================================================================================
;; Gaussian elimination
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
(define (gauss-eliminate M reduce? partial-pivot?)
(let-values ([(M wp) (matrix-gauss-eliminate M reduce? partial-pivot?)])
M))
(check-equal? (matrix-row-echelon (matrix [[2 4] [3 4]]) #f #f)
(matrix [[3 4] [0 4/3]]))
(check-equal? (gauss-eliminate (matrix [[1 2] [3 4]]) #f #f)
(matrix [[1 2] [0 -2]]))
(check-equal? (matrix-row-echelon (matrix [[2 4] [3 4]]) #f #t)
(matrix [[1 4/3] [0 1]]))
(check-equal? (gauss-eliminate (matrix [[2 4] [3 4]]) #t #f)
(matrix [[1 2] [0 1]]))
(check-equal? (gauss-eliminate (matrix [[2. 4.] [3. 4.]]) #t #t)
(matrix [[1. 1.3333333333333333] [0. 1.]]))
(check-equal? (gauss-eliminate (matrix [[1 4] [2 4]]) #t #t)
(matrix [[1 2] [0 1]]))
(check-equal? (gauss-eliminate (matrix [[1 2] [2 4]]) #f #t)
(check-equal? (matrix-row-echelon (matrix [[1 2] [2 4]]) #f #f)
(matrix [[2 4] [0 0]]))
(check-equal? (matrix-row-echelon (matrix [[1 4] [2 4]]) #f #t)
(matrix [[1 2] [0 1]]))
(check-equal? (matrix-row-echelon (matrix [[ 2 1 -1 8]
[-3 -1 2 -11]
[-2 1 2 -3]])
#f #t)
(matrix [[1 1/3 -2/3 11/3]
[0 1 2/5 13/5]
[0 0 1 -1]]))
(check-equal? (matrix-row-echelon (matrix [[ 2 1 -1 8]
[-3 -1 2 -11]
[-2 1 2 -3]])
#t)
(matrix [[1 0 0 2]
[0 1 0 3]
[0 0 1 -1]]))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (gauss-eliminate a #f #f))))
(check-exn exn:fail:contract? (λ () (matrix-row-echelon a))))
(check-equal? (matrix-rank (matrix [[0 0] [0 0]])) 0)
(check-equal? (matrix-rank (matrix [[1 0] [0 0]])) 1)
(check-equal? (matrix-rank (matrix [[1 0] [0 3]])) 2)
(check-equal? (matrix-rank (matrix [[1 2] [2 4]])) 1)
(check-equal? (matrix-rank (matrix [[1 2] [3 4]])) 2)
(check-equal? (matrix-rank (matrix [[1 2 3]])) 1)
(check-equal? (matrix-rank (matrix [[1 2 3] [2 3 5]])) 2)
(check-equal? (matrix-rank (matrix [[1 2 3] [2 3 5] [3 4 7]])) 2)
(check-equal? (matrix-rank (matrix [[1 2 3] [2 3 5] [3 4 7] [4 5 9]])) 2)
(check-equal? (matrix-rank (matrix [[1 2 3 5] [2 3 5 8]])) 2)
(check-equal? (matrix-rank (matrix [[1 5 2 3] [2 8 3 5]])) 2)
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-rank a))))
(check-equal? (matrix-nullity (matrix [[0 0] [0 0]])) 2)
(check-equal? (matrix-nullity (matrix [[1 0] [0 0]])) 1)
(check-equal? (matrix-nullity (matrix [[1 0] [0 3]])) 0)
(check-equal? (matrix-nullity (matrix [[1 2] [2 4]])) 1)
(check-equal? (matrix-nullity (matrix [[1 2] [3 4]])) 0)
(check-equal? (matrix-nullity (matrix [[1 2 3]])) 2)
(check-equal? (matrix-nullity (matrix [[1 2 3] [2 3 5]])) 1)
(check-equal? (matrix-nullity (matrix [[1 2 3] [2 3 5] [3 4 7]])) 1)
(check-equal? (matrix-nullity (matrix [[1 2 3] [2 3 5] [3 4 7] [4 5 9]])) 1)
(check-equal? (matrix-nullity (matrix [[1 2 3 5] [2 3 5 8]])) 2)
(check-equal? (matrix-nullity (matrix [[1 5 2 3] [2 8 3 5]])) 2)
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-nullity a))))
;; ===================================================================================================
;; Determinant
(check-equal? (matrix-determinant (matrix [[3]])) 3)
(check-equal? (matrix-determinant (matrix [[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
(check-equal? (matrix-determinant (matrix [[1 2 3] [4 5 6] [7 8 9]])) 0)
(check-equal? (matrix-determinant (matrix [[1 2 3] [4 -5 6] [7 8 9]])) 120)
(check-equal? (matrix-determinant (matrix [[1 2 3 4]
[-5 6 7 8]
[9 10 -11 12]
[13 14 15 16]]))
5280)
(for: ([_ (in-range 100)])
(define a (array- (random-matrix 3 3 7) (array 3)))
(check-equal? (matrix-determinant/row-reduction a)
(matrix-determinant a)))
(check-exn exn:fail:contract? (λ () (matrix-determinant (matrix [[1 2 3] [4 5 6]]))))
(check-exn exn:fail:contract? (λ () (matrix-determinant (matrix [[1 4] [2 5] [3 6]]))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-determinant a))))
;; ===================================================================================================
;; Solving linear systems
(for: ([_ (in-range 100)])
(define M (array- (random-matrix 3 3 7) (array 3)))
(define B (array- (random-matrix 3 (+ 1 (random 10)) 7) (array 3)))
(cond [(matrix-invertible? M)
(define X (matrix-solve M B))
(check-equal? (matrix* M X) B (format "M = ~a B = ~a" M B))]
[else
(check-false (matrix-solve M B (λ () #f))
(format "M = ~a B = ~a" M B))]))
(check-exn exn:fail? (λ () (matrix-solve (random-matrix 3 4) (random-matrix 3 1))))
(check-exn exn:fail? (λ () (matrix-solve (random-matrix 4 3) (random-matrix 4 1))))
(check-exn exn:fail:contract? (λ () (matrix-solve (random-matrix 3 4) (random-matrix 4 1))))
(check-exn exn:fail:contract? (λ () (matrix-solve (random-matrix 4 3) (random-matrix 3 1))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-solve a (matrix [[1]]))))
(check-exn exn:fail:contract? (λ () (matrix-solve (matrix [[1]]) a))))
;; ===================================================================================================
;; Inversion
(for: ([_ (in-range 100)])
(define a (array- (random-matrix 3 3 7) (array 3)))
(cond [(matrix-invertible? a)
(check-equal? (matrix* a (matrix-inverse a))
(identity-matrix 3)
(format "~a" a))
(check-equal? (matrix* (matrix-inverse a) a)
(identity-matrix 3)
(format "~a" a))]
[else
(check-false (matrix-inverse a (λ () #f))
(format "~a" a))]))
(check-exn exn:fail:contract? (λ () (matrix-inverse (random-matrix 3 4))))
(check-exn exn:fail:contract? (λ () (matrix-inverse (random-matrix 4 3))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-inverse a))))
#|
;; ===================================================================================================
@ -773,37 +877,6 @@
4 4 ((inst vector Number)
2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
0 0 0.0 4.440892098500626e-16 0 0 0 0.0))))))
(list 'matrix-solve
(let* ([M (list*->matrix '[[1 5] [2 3]])]
[b (list*->matrix '[[5] [5]])])
(equal? (matrix* M (matrix-solve M b)) b)))
(list 'matrix-inverse
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M)))
(identity-matrix 2))
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M))
(identity-matrix 2)))
(list 'matrix-determinant
(equal? (matrix-determinant (list*->matrix '[[3]])) 3)
(equal? (matrix-determinant (list*->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0)
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120)
(equal? (matrix-determinant (list*->matrix '[[1 2 3 4]
[-5 6 7 8]
[9 10 -11 12]
[13 14 15 16]]))
5280))
(list
'matrix-scale-row
(equal? (matrix-scale-row (identity-matrix 3) 0 2)
(list*->array '[[2 0 0] [0 1 0] [0 0 1]] real? )))
(list
'matrix-swap-rows
(equal? (matrix-swap-rows (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real? ) 0 1)
(list*->array '[[4 5 6] [1 2 3] [7 8 9]] real? )))
(list
'matrix-add-scaled-row
(equal? (matrix-add-scaled-row (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real? ) 0 2 1)
(list*->array '[[9 12 15] [4 5 6] [7 8 9]] real? )))
(let ()
(define M (list*->matrix '[[1 1 0 3]
[2 1 -1 1]
@ -828,39 +901,6 @@
[0 0 3 13]
[0 0 0 -13]]))
(equal? (matrix* L V) M)))))
(list
'matrix-rank
(equal? (matrix-rank (list*->matrix '[[0 0] [0 0]])) 0)
(equal? (matrix-rank (list*->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-rank (list*->matrix '[[1 0] [0 3]])) 2)
(equal? (matrix-rank (list*->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-rank (list*->matrix '[[1 2] [3 4]])) 2))
(list
'matrix-nullity
(equal? (matrix-nullity (list*->matrix '[[0 0] [0 0]])) 2)
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
#;
(let ()
(define-values (c1 n1)
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
(define-values (c2 n2)
(matrix-column+null-space (list*->matrix '[[1 2] [2 4]])))
(define-values (c3 n3)
(matrix-column+null-space (list*atrix '[[1 2] [2 5]])))
(list
'matrix-column+null-space
(equal? c1 '())
(equal? n1 (list (list*->matrix '[[0] [0]])
(list*->matrix '[[0] [0]])))
(equal? c2 (list (list*->matrix '[[1] [2]])))
;(equal? n2 '([0 0]))
(equal? c3 (list (list*->matrix '[[1] [2]])
(list*->matrix '[[2] [5]])))
(equal? n3 '()))))
#;
(begin
"matrix-2d.rkt"