Moar `math/matrix' review/refactoring

* Gram-Schmidt using vector type

* QR decomposition

* Operator 1-norm and maximum norm; stub for 2-norm and angle between
  subspaces (`matrix-basis-angle')

* `matrix-absolute-error' and `matrix-relative-error'; also predicates
  based on them, such as `matrix-identity?'

* Lots of shuffling code about

* Types that can have contracts, and an exhaustive test to make sure
  every value exported by `math/matrix' has a contract when used in
  untyped code

* Some more tests (still needs some)
This commit is contained in:
Neil Toronto 2012-12-31 14:13:36 -07:00
parent e06f31c94e
commit f5fa93572d
26 changed files with 1366 additions and 934 deletions

View File

@ -1,16 +1,133 @@
#lang typed/racket/base
#lang racket/base
(require typed/untyped-utils)
(require "private/matrix/matrix-arithmetic.rkt"
"private/matrix/matrix-constructors.rkt"
"private/matrix/matrix-conversion.rkt"
"private/matrix/matrix-syntax.rkt"
"private/matrix/matrix-basic.rkt"
"private/matrix/matrix-operations.rkt"
"private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt"
"private/matrix/matrix-2d.rkt"
"private/matrix/utils.rkt")
;;"private/matrix/matrix-gauss-elim.rkt" ; all use require/untyped-contract
(except-in "private/matrix/matrix-solve.rkt"
matrix-determinant
matrix-inverse
matrix-solve)
(except-in "private/matrix/matrix-constructors.rkt"
vandermonde-matrix)
(except-in "private/matrix/matrix-basic.rkt"
matrix-dot
matrix-angle
matrix-normalize
matrix-conjugate
matrix-hermitian
matrix-trace
matrix-normalize-rows
matrix-normalize-cols)
(except-in "private/matrix/matrix-subspace.rkt"
matrix-col-space)
(except-in "private/matrix/matrix-operator-norm.rkt"
matrix-basis-angle)
;;"private/matrix/matrix-qr.rkt" ; all use require/untyped-contract
;;"private/matrix/matrix-lu.rkt" ; all use require/untyped-contract
;;"private/matrix/matrix-gram-schmidt.rkt" ; all use require/untyped-contract
)
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-gauss-elim.rkt"
[matrix-gauss-elim
(case-> ((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index))))]
[matrix-row-echelon
(case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) Any -> (Matrix Number))
((Matrix Number) Any Any -> (Matrix Number)))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-solve.rkt"
[matrix-determinant
((Matrix Number) -> Number)]
[matrix-inverse
(All (A) (case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number)))))]
[matrix-solve
(All (A) (case->
((Matrix Number) (Matrix Number) -> (Matrix Number))
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-constructors.rkt"
[vandermonde-matrix ((Listof Number) Integer -> (Matrix Number))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-basic.rkt"
[matrix-dot
(case-> ((Matrix Number) -> Nonnegative-Real)
((Matrix Number) (Matrix Number) -> Number))]
[matrix-angle
((Matrix Number) (Matrix Number) -> Number)]
[matrix-normalize
(All (A) (case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))]
[matrix-conjugate
((Matrix Number) -> (Matrix Number))]
[matrix-hermitian
((Matrix Number) -> (Matrix Number))]
[matrix-trace
((Matrix Number) -> Number)]
[matrix-normalize-rows
(All (A) (case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))]
[matrix-normalize-cols
(All (A) (case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-subspace.rkt"
[matrix-col-space
(All (A) (case-> ((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number)))))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-operator-norm.rkt"
[matrix-basis-angle
((Matrix Number) (Matrix Number) -> Number)])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-qr.rkt"
[matrix-qr
(case-> ((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-lu.rkt"
[matrix-lu
(All (A) (case-> ((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))])
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"
"private/array/array-struct.rkt"))
"private/matrix/matrix-gram-schmidt.rkt"
[matrix-gram-schmidt
(case-> ((Matrix Number) -> (Array Number))
((Matrix Number) Any -> (Array Number))
((Matrix Number) Any Integer -> (Array Number)))]
[matrix-basis-extension
((Matrix Number) -> (Array Number))])
(provide (all-from-out
"private/matrix/matrix-arithmetic.rkt"
@ -18,8 +135,40 @@
"private/matrix/matrix-conversion.rkt"
"private/matrix/matrix-syntax.rkt"
"private/matrix/matrix-basic.rkt"
"private/matrix/matrix-operations.rkt"
"private/matrix/matrix-subspace.rkt"
"private/matrix/matrix-solve.rkt"
"private/matrix/matrix-operator-norm.rkt"
"private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt"
"private/matrix/matrix-2d.rkt"))
"private/matrix/matrix-2d.rkt")
;; matrix-gauss-elim.rkt
matrix-gauss-elim
matrix-row-echelon
;; matrix-solve.rkt
matrix-determinant
matrix-inverse
matrix-solve
;; matrix-constructors.rkt
vandermonde-matrix
;; matrix-basic.rkt
matrix-dot
matrix-angle
matrix-normalize
matrix-conjugate
matrix-hermitian
matrix-trace
matrix-normalize-rows
matrix-normalize-cols
;; matrix-subspace.rkt
matrix-col-space
;; matrix-operator-norm.rkt
matrix-basis-angle
;; matrix-qr.rkt
matrix-qr
;; matrix-lu.rkt
matrix-lu
;; matrix-gram-schmidt.rkt
matrix-gram-schmidt
matrix-basis-extension
)

View File

@ -25,8 +25,9 @@
(define-syntax-rule (define-array-op name op)
(define-syntax-rule (name arrs (... ...)) (array-map op arrs (... ...))))
(define-syntax-rule (array-scale arr x)
(inline-array-map (λ (y) (* x y)) arr))
(define-syntax-rule (array-scale arr x-expr)
(let ([x x-expr])
(inline-array-map (λ (y) (* x y)) arr)))
(define-array-op1 array-sqr sqr)
(define-array-op1 array-sqrt sqrt)

View File

@ -1,7 +1,6 @@
#lang typed/racket/base
(require math/array
"matrix-types.rkt"
(require "matrix-types.rkt"
"matrix-syntax.rkt")
(provide matrix-2d-rotation

View File

@ -1,13 +1,52 @@
#lang racket/base
(module untyped-arithmetic-defs typed/racket/base
(require "matrix-types.rkt"
(prefix-in typed: "typed-matrix-arithmetic.rkt"))
(provide (all-defined-out))
(: matrix* ((Matrix Number) (Matrix Number) * -> (Matrix Number)))
(define matrix* typed:matrix*)
(: matrix+ ((Matrix Number) (Matrix Number) * -> (Matrix Number)))
(define matrix+ typed:matrix+)
(: matrix- ((Matrix Number) (Matrix Number) * -> (Matrix Number)))
(define matrix- typed:matrix-)
(: matrix-scale ((Matrix Number) Number -> (Matrix Number)))
(define matrix-scale typed:matrix-scale)
(: matrix-sum ((Listof (Matrix Number)) -> (Matrix Number)))
(define matrix-sum typed:matrix-sum)
) ; module untyped-arithmetic-defs
(module arithmetic-defs racket/base
(require typed/untyped-utils
(prefix-in typed: "typed-matrix-arithmetic.rkt")
(prefix-in untyped: (submod ".." untyped-arithmetic-defs))
(rename-in "untyped-matrix-arithmetic.rkt"
[matrix-map untyped:matrix-map]))
(provide (all-defined-out))
(define-typed/untyped-identifier matrix-map typed:matrix-map untyped:matrix-map)
(define-typed/untyped-identifier matrix* typed:matrix* untyped:matrix*)
(define-typed/untyped-identifier matrix+ typed:matrix+ untyped:matrix+)
(define-typed/untyped-identifier matrix- typed:matrix- untyped:matrix-)
(define-typed/untyped-identifier matrix-scale typed:matrix-scale untyped:matrix-scale)
(define-typed/untyped-identifier matrix-sum typed:matrix-sum untyped:matrix-sum)
) ; module arithmetic-defs
(require (for-syntax racket/base)
typed/untyped-utils
(prefix-in typed: "typed-matrix-arithmetic.rkt")
(rename-in "untyped-matrix-arithmetic.rkt"
[matrix-map untyped:matrix-map]))
(define-typed/untyped-identifier matrix-map
typed:matrix-map untyped:matrix-map)
(prefix-in fun: (submod "." arithmetic-defs))
(except-in "untyped-matrix-arithmetic.rkt" matrix-map)
)
(define-syntax (define/inline-macro stx)
(syntax-case stx ()
@ -19,17 +58,16 @@
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
[_ (syntax/loc inner-stx typed:fun)])))]))
(define/inline-macro matrix* (a as ...) inline-matrix* typed:matrix*)
(define/inline-macro matrix+ (a as ...) inline-matrix+ typed:matrix+)
(define/inline-macro matrix- (a as ...) inline-matrix- typed:matrix-)
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
(define/inline-macro do-matrix-map (f a as ...) inline-matrix-map matrix-map)
(define/inline-macro matrix-map (f a as ...) inline-matrix-map fun:matrix-map)
(define/inline-macro matrix* (a as ...) inline-matrix* fun:matrix*)
(define/inline-macro matrix+ (a as ...) inline-matrix+ fun:matrix+)
(define/inline-macro matrix- (a as ...) inline-matrix- fun:matrix-)
(define/inline-macro matrix-scale (a x) inline-matrix-scale fun:matrix-scale)
(provide
(rename-out [do-matrix-map matrix-map]
[typed:matrix= matrix=]
[typed:matrix-sum matrix-sum])
(rename-out [typed:matrix= matrix=]
[fun:matrix-sum matrix-sum])
matrix-map
matrix*
matrix+
matrix-

View File

@ -1,41 +1,66 @@
#lang typed/racket
#lang typed/racket/base
(require racket/list
racket/fixnum
math/array
math/flonum
math/base
"matrix-types.rkt"
"matrix-arithmetic.rkt"
"matrix-constructors.rkt"
"matrix-conversion.rkt"
"utils.rkt"
"../unsafe.rkt")
"../unsafe.rkt"
"../array/array-struct.rkt"
"../array/array-indexing.rkt"
"../array/array-sequence.rkt"
"../array/array-transform.rkt"
"../array/array-fold.rkt"
"../array/array-pointwise.rkt"
"../array/array-convert.rkt"
"../array/utils.rkt"
"../vector/vector-mutate.rkt")
(provide
;; Extraction
matrix-ref
matrix-diagonal
submatrix
matrix-row
matrix-col
matrix-rows
matrix-cols
;; Predicates
matrix-zero?
matrix-diagonal
matrix-upper-triangle
matrix-lower-triangle
;; Embiggenment
matrix-augment
matrix-stack
;; Norm and inner product
;; Inner product space
matrix-1norm
matrix-2norm
matrix-inf-norm
matrix-norm
matrix-dot
matrix-angle
matrix-normalize
;; Simple operators
matrix-transpose
matrix-conjugate
matrix-hermitian
matrix-trace)
matrix-trace
;; Row/column operators
matrix-map-rows
matrix-map-cols
matrix-normalize-rows
matrix-normalize-cols
;; Predicates
matrix-zero?
matrix-rows-orthogonal?
matrix-cols-orthogonal?)
;; ===================================================================================================
;; Extraction
(: matrix-ref (All (A) (Array A) Integer Integer -> A))
(: matrix-ref (All (A) (Matrix A) Integer Integer -> A))
(define (matrix-ref a i j)
(define-values (m n) (matrix-shape a))
(cond [(or (i . < . 0) (i . >= . m))
@ -45,16 +70,6 @@
[else
(unsafe-array-ref a ((inst vector Index) i j))]))
(: matrix-diagonal (All (A) ((Array A) -> (Array A))))
(define (matrix-diagonal a)
(define m (square-matrix-size a))
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) m)
(λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0))
(proc ((inst vector Index) i i)))))
(: submatrix (All (A) (Matrix A) Slice-Spec Slice-Spec -> (Matrix A)))
(define (submatrix a row-range col-range)
(array-slice-ref (ensure-matrix 'submatrix a) (list row-range col-range)))
@ -89,42 +104,67 @@
(unsafe-vector-set! ij 1 0)
res))]))
(: matrix-rows (All (A) (Array A) -> (Listof (Array A))))
(: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-rows a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0))
(: matrix-cols (All (A) (Array A) -> (Listof (Array A))))
(: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-cols a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1))
;; ===================================================================================================
;; Predicates
(: matrix-diagonal (All (A) ((Matrix A) -> (Array A))))
(define (matrix-diagonal a)
(define m (square-matrix-size a))
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) m)
(λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0))
(proc ((inst vector Index) i i)))))
(: matrix-zero? ((Array Number) -> Boolean))
(define (matrix-zero? a)
(array-all-and (matrix-map zero? a)))
(: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0)))))
(define (matrix-upper-triangle M)
(define-values (m n) (matrix-shape M))
(define proc (unsafe-array-proc M))
(unsafe-build-array
((inst vector Index) m n)
(λ: ([ij : Indexes])
(define i (unsafe-vector-ref ij 0))
(define j (unsafe-vector-ref ij 1))
(if (i . fx<= . j) (proc ij) 0))))
(: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0)))))
(define (matrix-lower-triangle M)
(define-values (m n) (matrix-shape M))
(define proc (unsafe-array-proc M))
(unsafe-build-array
((inst vector Index) m n)
(λ: ([ij : Indexes])
(define i (unsafe-vector-ref ij 0))
(define j (unsafe-vector-ref ij 1))
(if (i . fx>= . j) (proc ij) 0))))
;; ===================================================================================================
;; Embiggenment (this is a perfectly cromulent word)
(: matrix-augment (All (A) (Listof (Array A)) -> (Array A)))
(: matrix-augment (All (A) (Listof (Matrix A)) -> (Matrix A)))
(define (matrix-augment as)
(cond [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)]
[else
(define m (matrix-num-rows (first as)))
(cond [(andmap (λ: ([a : (Array A)]) (= m (matrix-num-rows a))) (rest as))
(cond [(andmap (λ: ([a : (Matrix A)]) (= m (matrix-num-rows a))) (rest as))
(array-append* as 1)]
[else
(error 'matrix-augment
"matrices must have the same number of rows; given ~a"
(format-matrices/error as))])]))
(: matrix-stack (All (A) (Listof (Array A)) -> (Array A)))
(: matrix-stack (All (A) (Listof (Matrix A)) -> (Matrix A)))
(define (matrix-stack as)
(cond [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)]
[else
(define n (matrix-num-cols (first as)))
(cond [(andmap (λ: ([a : (Array A)]) (= n (matrix-num-cols a))) (rest as))
(cond [(andmap (λ: ([a : (Matrix A)]) (= n (matrix-num-cols a))) (rest as))
(array-append* as 0)]
[else
(error 'matrix-stack
@ -132,81 +172,223 @@
(format-matrices/error as))])]))
;; ===================================================================================================
;; Matrix norms and Frobenius inner product
;; Inner product space (entrywise norm)
(: maximum-norm ((Array Number) -> Real))
(define (maximum-norm a)
(array-all-max (array-magnitude a)))
(: taxicab-norm ((Array Number) -> Real))
(define (taxicab-norm a)
(: matrix-1norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-1norm a)
(array-all-sum (array-magnitude a)))
(: frobenius-norm ((Array Number) -> Real))
(define (frobenius-norm a)
(: matrix-2norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-2norm a)
(let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a))
(cond [(and (rational? mx) (positive? mx))
(assert
(* mx (sqrt (array-all-sum (inline-array-map (λ: ([x : Number]) (sqr (/ x mx))) a))))
real?)]
(* mx (sqrt (array-all-sum
(inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))]
[else mx])))
(: p-norm ((Array Number) Positive-Real -> Real))
(define (p-norm a p)
(: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-inf-norm a)
(array-all-max (array-magnitude a)))
(: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real))
(define (matrix-p-norm a p)
(let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a))
(cond [(and (rational? mx) (positive? mx))
(assert
(* mx (expt (array-all-sum (inline-array-map (λ: ([x : Real]) (expt (/ x mx) p)) a))
(* mx (expt (array-all-sum
(inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a))
(/ p)))
real?)]
(make-predicate Nonnegative-Real))]
[else mx])))
(: matrix-norm (case-> ((Array Number) -> Real)
((Array Number) Real -> Real)))
(: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real)
((Matrix Number) Real -> Nonnegative-Real)))
;; Computes the p norm of a matrix
(define (matrix-norm a [p 2])
(cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)]
[(p . = . 2) (frobenius-norm a)]
[(p . = . +inf.0) (maximum-norm a)]
[(p . = . 1) (taxicab-norm a)]
[(p . > . 1) (p-norm a p)]
[(p . = . 1) (matrix-1norm a)]
[(p . = . 2) (matrix-2norm a)]
[(p . = . +inf.0) (matrix-inf-norm a)]
[(p . > . 1) (matrix-p-norm a p)]
[else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)]))
(: matrix-dot (case-> ((Array Real) (Array Real) -> Real)
((Array Number) (Array Number) -> Number)))
;; Computes the Frobenius inner product of two matrices
(define (matrix-dot a b)
(define-values (m n) (matrix-shapes 'matrix-dot a b))
(define aproc (unsafe-array-proc a))
(define bproc (unsafe-array-proc b))
(array-all-sum
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes])
(* (aproc js) (conjugate (bproc js)))))))
(: matrix-dot (case-> ((Matrix Real) -> Nonnegative-Real)
((Matrix Real) (Matrix Real) -> Real)
((Matrix Number) -> Nonnegative-Real)
((Matrix Number) (Matrix Number) -> Number)))
;; Computes the Frobenius inner product of a matrix with itself or of two matrices
(define matrix-dot
(case-lambda
[(a)
(assert
(array-all-sum
(inline-array-map
(λ (x) (* x (conjugate x)))
(ensure-matrix 'matrix-dot a)))
(make-predicate Nonnegative-Real))]
[(a b)
(define-values (m n) (matrix-shapes 'matrix-dot a b))
(define aproc (unsafe-array-proc a))
(define bproc (unsafe-array-proc b))
(array-all-sum
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes])
(* (aproc js) (conjugate (bproc js))))))]))
(: matrix-angle (case-> ((Matrix Real) (Matrix Real) -> Real)
((Matrix Number) (Matrix Number) -> Number)))
(define (matrix-angle M N)
(acos (/ (matrix-dot M N) (* (matrix-2norm M) (matrix-2norm N)))))
(: matrix-normalize
(All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Real -> (Matrix Real))
((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize
(case-lambda
[(M) (matrix-normalize M 2)]
[(M p) (matrix-normalize M p (λ () (raise-argument-error
'matrix-normalize "nonzero matrix?" 0 M p)))]
[(M p fail)
(array-strict! M)
(define x (matrix-norm M p))
(cond [(and (zero? x) (exact? x)) (fail)]
[else (matrix-scale M (/ x))])]))
;; ===================================================================================================
;; Operators
(: matrix-transpose (All (A) (Array A) -> (Array A)))
(: matrix-transpose (All (A) (Matrix A) -> (Matrix A)))
(define (matrix-transpose a)
(array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1))
(: matrix-conjugate (case-> ((Array Real) -> (Array Real))
((Array Number) -> (Array Number))))
(: matrix-conjugate (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Number) -> (Matrix Number))))
(define (matrix-conjugate a)
(array-conjugate (ensure-matrix 'matrix-conjugate a)))
(: matrix-hermitian (case-> ((Array Real) -> (Array Real))
((Array Number) -> (Array Number))))
(: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Number) -> (Matrix Number))))
(define (matrix-hermitian a)
(array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1))
(: matrix-trace (case-> ((Array Real) -> Real)
((Array Number) -> Number)))
(: matrix-trace (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-trace a)
(array-all-sum (matrix-diagonal a)))
;; ===================================================================================================
;; Row/column operations
(: matrix-map-rows
(All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B))
(((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F)
-> (U F (Matrix B))))))
(define matrix-map-rows
(case-lambda
[(f M) (matrix-stack (map f (matrix-rows M)))]
[(f M fail)
(define ms (matrix-rows M))
(define n (f (first ms)))
(cond [n (let loop ([ms (rest ms)] [ns (list n)])
(cond [(empty? ms) (matrix-stack (reverse ns))]
[else (define n (f (first ms)))
(cond [n (loop (rest ms) (cons n ns))]
[else (fail)])]))]
[else (fail)])]))
(: matrix-map-cols
(All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B))
(((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F)
-> (U F (Matrix B))))))
(define matrix-map-cols
(case-lambda
[(f M) (matrix-augment (map f (matrix-cols M)))]
[(f M fail)
(define ms (matrix-cols M))
(define n (f (first ms)))
(cond [n (let loop ([ms (rest ms)] [ns (list n)])
(cond [(empty? ms) (matrix-augment (reverse ns))]
[else (define n (f (first ms)))
(cond [n (loop (rest ms) (cons n ns))]
[else (fail)])]))]
[else (fail)])]))
(: make-matrix-normalize (Real -> (case-> ((Matrix Real) -> (U #f (Matrix Real)))
((Matrix Number) -> (U #f (Matrix Number))))))
(define ((make-matrix-normalize p) M)
(matrix-normalize M p (λ () #f)))
(: matrix-normalize-rows
(All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Real -> (Matrix Real))
((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize-rows
(case-lambda
[(M) (matrix-normalize-rows M 2)]
[(M p)
(define (fail) (raise-argument-error 'matrix-normalize-rows "matrix? with nonzero rows" 0 M p))
(matrix-normalize-rows M p fail)]
[(M p fail)
(matrix-map-rows (make-matrix-normalize p) M fail)]))
(: matrix-normalize-cols
(All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Real -> (Matrix Real))
((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Real -> (Matrix Number))
((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize-cols
(case-lambda
[(M) (matrix-normalize-cols M 2)]
[(M p)
(define (fail)
(raise-argument-error 'matrix-normalize-cols "matrix? with nonzero columns" 0 M p))
(matrix-normalize-cols M p fail)]
[(M p fail)
(matrix-map-cols (make-matrix-normalize p) M fail)]))
;; ===================================================================================================
;; Robust predicates using entrywise norms
(: matrix-zero? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean)))
(define (matrix-zero? M [eps 0])
(cond [(negative? eps) (raise-argument-error 'matrix-zero? "Nonnegative-Real" 1 M eps)]
[else (<= (matrix-norm M +inf.0) eps)]))
(: rows-orthogonal? ((Matrix Number) Nonnegative-Real -> Boolean))
(define (rows-orthogonal? M eps)
(define rows (matrix->vector* M))
(define m (vector-length rows))
(let/ec: return : Boolean
(for*: ([i0 (in-range m)] [i1 (in-range (fx+ i0 1) m)])
(define r0 (unsafe-vector-ref rows i0))
(define r1 (unsafe-vector-ref rows i1))
(when ((sqrt (magnitude (vector-dot r0 r1))) . >= . eps) (return #f)))
#t))
(: matrix-rows-orthogonal? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean)))
(define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)])
(cond [(negative? eps) (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)]
[else (rows-orthogonal? M eps)]))
(: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean)))
(define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)])
(cond [(negative? eps) (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)]
[else (rows-orthogonal? (matrix-transpose M) eps)]))

View File

@ -1,112 +0,0 @@
#lang typed/racket
(require racket/fixnum
math/array
math/matrix
"matrix-column.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
)
(: col-matrix-project1 (case-> ((Matrix Real) (Matrix Real) Any -> (U #f (Matrix Real)))
((Matrix Number) (Matrix Number) Any -> (U #f (Matrix Number)))))
(define (col-matrix-project1 v b unit?)
(cond [unit? (matrix-scale b (matrix-dot v b))]
[else (define b.b (matrix-dot b b))
(cond [(and (zero? b.b) (exact? b.b)) #f]
[else (matrix-scale b (/ (matrix-dot v b) b.b))])]))
(: col-matrix-project
(All (A) (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Real) (Matrix Real) Any -> (U A (Matrix Real)))
((Matrix Real) (Matrix Real) Any (-> A) -> (U A (Matrix Real)))
((Matrix Number) (Matrix Number) -> (Matrix Number))
((Matrix Number) (Matrix Number) Any -> (U A (Matrix Number)))
((Matrix Number) (Matrix Number) Any (-> A) -> (U A (Matrix Number))))))
(define col-matrix-project
(case-lambda
[(v B) (col-matrix-project v B #f)]
[(v B unit?)
(col-matrix-project
v B unit?
(λ () (error 'col-matrix-project "expected basis with nonzero column vectors; given ~e" B)))]
[(v B unit? fail)
(unless (col-matrix? v) (raise-argument-error 'col-matrix-project "col-matrix?" v))
(define bs (matrix-cols (ensure-matrix 'col-matrix-project B)))
(define p (col-matrix-project1 v (first bs) unit?))
(cond [p (let loop ([bs (rest bs)] [p p])
(cond [(empty? bs) p]
[else (define q (col-matrix-project1 v (first bs) unit?))
(if q (loop (rest bs) (matrix+ p q)) (fail))]))]
[else (fail)])]))
(: find-nonzero-vector (case-> ((Vectorof (Vectorof Real)) -> (U #f Index))
((Vectorof (Vectorof Number)) -> (U #f Index))))
(define (find-nonzero-vector vss)
(define n (vector-length vss))
(cond [(= n 0) #f]
[else (let loop ([#{i : Nonnegative-Fixnum} 0])
(cond [(i . fx< . n)
(define vs (unsafe-vector-ref vss i))
(if (vector-zero? vs) (loop (fx+ i 1)) i)]
[else #f]))]))
(: subtract-projections!
(case-> ((Vectorof (Vectorof Real)) Index Index (Vectorof Real) Any -> Void)
((Vectorof (Vectorof Number)) Index Index (Vectorof Number) Any -> Void)))
(define (subtract-projections! cols n i ci unit?)
(let j-loop ([#{j : Nonnegative-Fixnum} (fx+ i 1)])
(when (j . fx< . n)
(vector-sub-proj! (unsafe-vector-ref cols j) ci unit?)
(j-loop (fx+ j 1)))))
(: matrix-gram-schmidt (All (A) (case-> ((Matrix Real) -> (Array Real))
((Matrix Real) Any -> (Array Real))
((Matrix Number) -> (Array Number))
((Matrix Number) Any -> (Array Number)))))
(define (matrix-gram-schmidt M [unit? #f])
(define rows (matrix->vector* M))
(define n (vector-length rows))
(define i (find-nonzero-vector rows))
(cond [i (define rowi (unsafe-vector-ref rows i))
(subtract-projections! rows n i rowi #f)
(when unit? (vector-normalize! rowi))
(let loop ([#{i : Nonnegative-Fixnum} (fx+ i 1)] [bs (list rowi)])
(cond [(i . fx< . n)
(define rowi (unsafe-vector-ref rows i))
(cond [(vector-zero? rowi) (loop (fx+ i 1) bs)]
[else (subtract-projections! rows n i rowi #f)
(when unit? (vector-normalize! rowi))
(loop (fx+ i 1) (cons rowi bs))])]
[else
(vector*->matrix (list->vector (reverse bs)))]))]
[else
(make-array (vector 0 (matrix-num-cols M)) 0)]))
#|
(define a (col-matrix [1 2 1]))
(define b (col-matrix [1 -2 2]))
(define basis
(gram-schmidt-orthogonal
(matrix-cols
(array #[#[2 1 0] #[2 2 1] #[0 2 0]]))))
(column-project a b)
(col-matrix-project a b)
(projection-on-orthogonal-basis a basis)
(col-matrix-project a (matrix-augment basis))
(projection-on-orthonormal-basis a basis)
(col-matrix-project a (matrix-augment basis) 'orthonormal)
(matrix-gram-schmidt
(matrix [[0 1 2]
[0 2 3]
[0 1 5]]))
(matrix-gram-schmidt
(matrix [[5 1 2]
[2 2 3]
[-3 1 5]]))
|#

View File

@ -1,126 +0,0 @@
#lang typed/racket/base
(require math/array
math/base
"matrix-types.rkt"
"matrix-conversion.rkt"
"matrix-arithmetic.rkt"
"../unsafe.rkt")
(provide unit-column
column-height
unsafe-column->vector
column-scale
column+
column-dot
column-norm
column-project
column-project/unit
column-normalize)
(: unit-column : Integer Integer -> (Result-Column Number))
(define (unit-column m i)
(cond
[(and (index? m) (index? i))
(define v (make-vector m 0))
(if (< i m)
(vector-set! v i 1)
(error 'unit-vector "dimension must be largest"))
(vector->matrix m 1 v)]
[else
(error 'unit-vector "expected two indices")]))
(: column-height : (Column Number) -> Index)
(define (column-height v)
(if (vector? v)
(vector-length v)
(matrix-num-rows v)))
(: unsafe-column->vector : (Column Number) -> (Vectorof Number))
(define (unsafe-column->vector v)
(if (vector? v) v
(let ()
(define-values (m n) (matrix-shape v))
(if (= n 1)
(mutable-array-data (array->mutable-array v))
(error 'unsafe-column->vector
"expected a column (vector or mx1 matrix), got ~a" v)))))
(: column-scale : (Column Number) Number -> (Result-Column Number))
(define (column-scale a s)
(if (vector? a)
(let*: ([n (vector-length a)]
[v : (Vectorof Number) (make-vector n 0)])
(for: ([i (in-range 0 n)]
[x : Number (in-vector a)])
(vector-set! v i (* s x)))
(->col-matrix v))
(matrix-scale a s)))
(: column+ : (Column Number) (Column Number) -> (Result-Column Number))
(define (column+ v w)
(cond [(and (vector? v) (vector? w))
(let ([n (vector-length v)]
[m (vector-length w)])
(unless (= m n)
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(define: v+w : (Vectorof Number) (make-vector n 0))
(for: ([i (in-range 0 n)]
[x : Number (in-vector v)]
[y : Number (in-vector w)])
(vector-set! v+w i (+ x y)))
(->col-matrix v+w))]
[else
(unless (= (column-height v) (column-height w))
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(array+ (->col-matrix v) (->col-matrix w))]))
(: column-dot : (Column Number) (Column Number) -> Number)
(define (column-dot c d)
(define v (unsafe-column->vector c))
(define w (unsafe-column->vector d))
(define m (column-height v))
(define s (column-height w))
(cond
[(not (= m s)) (error 'column-dot
"expected two mx1 matrices with same number of rows, got ~a and ~a"
c d)]
[else
(for/sum: : Number ([i (in-range 0 m)])
(assert i index?)
; Note: If d is a vector of reals,
; then the conjugate is a no-op
(* (unsafe-vector-ref v i)
(conjugate (unsafe-vector-ref w i))))]))
(: column-norm : (Column Number) -> Real)
(define (column-norm v)
(define norm (sqrt (column-dot v v)))
(assert norm real?))
(: column-project : (Column Number) (Column Number) -> (Result-Column Number))
; (column-project v w)
; Return the projection og vector v on vector w.
(define (column-project v w)
(let ([w.w (column-dot w w)])
(if (zero? w.w)
(error 'column-project "projection on the zero vector not defined")
(matrix-scale (->col-matrix w) (/ (column-dot v w) w.w)))))
(: column-project/unit : (Column Number) (Column Number) -> (Result-Column Number))
; (column-project-on-unit v w)
; Return the projection og vector v on a unit vector w.
(define (column-project/unit v w)
(matrix-scale (->col-matrix w) (column-dot v w)))
(: column-normalize : (Column Number) -> (Result-Column Number))
; (column-vector-normalize v)
; Return unit vector with same direction as v.
; If v is the zero vector, the zero vector is returned.
(define (column-normalize w)
(let ([norm (column-norm w)]
[w (->col-matrix w)])
(cond [(zero? norm) w]
[else (matrix-scale w (/ norm))])))

View File

@ -2,7 +2,7 @@
(require (for-syntax racket/base
syntax/parse)
math/array)
"../array/array-comprehension.rkt")
(provide for/matrix:
for*/matrix:
@ -12,7 +12,7 @@
(module typed-defs typed/racket/base
(require (for-syntax racket/base
syntax/parse)
math/array)
"../array/array-comprehension.rkt")
(provide (all-defined-out))

View File

@ -3,9 +3,12 @@
(require racket/fixnum
racket/list
racket/vector
math/array
"matrix-types.rkt"
"../unsafe.rkt")
"../unsafe.rkt"
"../array/array-struct.rkt"
"../array/array-constructors.rkt"
"../array/array-unfold.rkt"
"../array/utils.rkt")
(provide identity-matrix
make-matrix
@ -42,7 +45,7 @@
;; ===================================================================================================
;; Diagonal matrices
(: diagonal-matrix/zero (All (A) (Listof A) A -> (Array A)))
(: diagonal-matrix/zero (All (A) ((Listof A) A -> (Matrix A))))
(define (diagonal-matrix/zero xs zero)
(cond [(empty? xs)
(raise-argument-error 'diagonal-matrix "nonempty List" xs)]
@ -56,15 +59,14 @@
(cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)]
[else zero])))]))
(: diagonal-matrix (case-> ((Listof Real) -> (Array Real))
((Listof Number) -> (Array Number))))
(: diagonal-matrix (All (A) ((Listof A) -> (Matrix (U A 0)))))
(define (diagonal-matrix xs)
(diagonal-matrix/zero xs 0))
;; ===================================================================================================
;; Block diagonal matrices
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Matrix A)) A -> (Matrix A)))
(define (block-diagonal-matrix/zero* as zero)
(define num (vector-length as))
(define-values (ms ns)
@ -94,7 +96,7 @@
(vector-set! hs (unsafe-fx+ res-j j) k)
(vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
(define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
(define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as))
(unsafe-build-array
((inst vector Index) res-m res-n)
(λ: ([ij : Indexes])
@ -114,7 +116,7 @@
[else
zero]))))
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
(: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A))))
(define (block-diagonal-matrix/zero as zero)
(let ([as (list->vector as)])
(define num (vector-length as))
@ -125,8 +127,7 @@
[else
(block-diagonal-matrix/zero* as zero)])))
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
((Listof (Array Number)) -> (Array Number))))
(: block-diagonal-matrix (All (A) ((Listof (Matrix A)) -> (Matrix (U A 0)))))
(define (block-diagonal-matrix as)
(block-diagonal-matrix/zero as 0))
@ -140,8 +141,8 @@
(cond [(real? x) (assert (expt x n) real?)]
[else (expt x n)]))
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
((Listof Number) Integer -> (Array Number))))
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Matrix Real))
((Listof Number) Integer -> (Matrix Number))))
(define (vandermonde-matrix xs n)
(cond [(empty? xs)
(raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]

View File

@ -3,9 +3,13 @@
(require racket/fixnum
racket/list
racket/vector
math/array
"matrix-types.rkt"
"utils.rkt"
"../array/array-struct.rkt"
"../array/array-convert.rkt"
"../array/array-transform.rkt"
"../array/mutable-array.rkt"
"../array/array-fold.rkt"
"../array/utils.rkt"
"../unsafe.rkt")
@ -24,9 +28,9 @@
matrix->vector*)
;; ===================================================================================================
;; Flat conversion
;; Flat conversion to rectangular matrices
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Matrix A))))
(define (list->matrix m n xs)
(cond [(or (not (index? m)) (= m 0))
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
@ -34,7 +38,7 @@
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
[else (list->array (vector m n) xs)]))
(: matrix->list (All (A) ((Array A) -> (Listof A))))
(: matrix->list (All (A) ((Matrix A) -> (Listof A))))
(define (matrix->list a)
(array->list (ensure-matrix 'matrix->list a)))
@ -46,26 +50,18 @@
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
[else (vector->array (vector m n) v)]))
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
(: matrix->vector (All (A) ((Matrix A) -> (Vectorof A))))
(define (matrix->vector a)
(array->vector (ensure-matrix 'matrix->vector a)))
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
(define (list->row-matrix xs)
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
;; ===================================================================================================
;; Flat conversion to column and row matrices
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
(: list->col-matrix (All (A) ((Listof A) -> (Matrix A))))
(define (list->col-matrix xs)
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->row-matrix xs)
(define n (vector-length xs))
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
[else (vector->array ((inst vector Index) 1 n) xs)]))
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->col-matrix xs)
(define n (vector-length xs))
@ -80,40 +76,15 @@
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
[else (values 0 0)])))
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
(define (array->row-matrix arr)
(define (fail)
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
(define ds (array-shape arr))
(define dims (vector-length ds))
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
(cond [(zero? (array-size arr)) (fail)]
[(row-matrix? arr) arr]
[(= num-ones dims)
(define: js : (Vectorof Index) (make-vector dims 0))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) 1 1)
(λ: ([ij : Indexes]) (proc js)))]
[(= num-ones (- dims 1))
(define-values (k n) (find-nontrivial-axis ds))
(define js (make-thread-local-indexes dims))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) 1 n)
(λ: ([ij : Indexes])
(let ([js (js)])
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
(proc js))))]
[else (fail)]))
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
(: array->col-matrix (All (A) ((Array A) -> (Matrix A))))
(define (array->col-matrix arr)
(define (fail)
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
(raise-argument-error 'array->col-matrix
"nonempty Array with exactly one axis of length >= 1" arr))
(define ds (array-shape arr))
(define dims (vector-length ds))
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
(cond [(zero? (array-size arr)) (fail)]
[(col-matrix? arr) arr]
[(= num-ones dims)
(define: js : (Vectorof Index) (make-vector dims 0))
(define proc (unsafe-array-proc arr))
@ -130,17 +101,19 @@
(proc js))))]
[else (fail)]))
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(define (->row-matrix xs)
(cond [(list? xs) (list->row-matrix xs)]
[(array? xs) (array->row-matrix xs)]
[else (vector->row-matrix xs)]))
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A))))
(define (->col-matrix xs)
(cond [(list? xs) (list->col-matrix xs)]
[(array? xs) (array->col-matrix xs)]
[else (vector->col-matrix xs)]))
[(vector? xs) (vector->col-matrix xs)]
[(col-matrix? xs) xs]
[else (array->col-matrix xs)]))
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A))))
(define (->row-matrix xs)
(cond [(list? xs) (array-axis-swap (list->col-matrix xs) 0 1)]
[(vector? xs) (array-axis-swap (vector->col-matrix xs) 0 1)]
[(row-matrix? xs) xs]
[else (array-axis-swap (array->col-matrix xs) 0 1)]))
;; ===================================================================================================
;; Nested conversion

View File

@ -1,7 +1,6 @@
#lang typed/racket
#lang typed/racket/base
(require math/array
"matrix-types.rkt"
(require "matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-arithmetic.rkt")

View File

@ -0,0 +1,62 @@
#lang typed/racket/base
(require racket/fixnum
racket/list
"matrix-types.rkt"
"matrix-conversion.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt")
(provide
matrix-gauss-elim
matrix-row-echelon)
(: matrix-gauss-elim
(case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index)))))
(define (matrix-gauss-elim M [jordan? #f] [unitize-pivot? #f])
(define-values (m n) (matrix-shape M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0]
[#{j : Nonnegative-Fixnum} 0]
[#{without-pivot : (Listof Index)} empty])
(cond
[(j . fx>= . n)
(values (vector*->matrix rows)
(reverse without-pivot))]
[(i . fx>= . m)
(values (vector*->matrix rows)
;; None of the rest of the columns can have pivots
(let loop ([#{j : Nonnegative-Fixnum} j] [without-pivot without-pivot])
(cond [(j . fx< . n) (loop (fx+ j 1) (cons j without-pivot))]
[else (reverse without-pivot)])))]
[else
(define-values (p pivot) (find-partial-pivot rows m i j))
(cond
[(zero? pivot) (loop i (fx+ j 1) (cons j without-pivot))]
[else
;; Swap pivot row with current
(vector-swap! rows i p)
;; Possibly unitize the new current row
(let ([pivot (if unitize-pivot?
(begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
1)
pivot)])
(elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1)))
(loop (fx+ i 1) (fx+ j 1) without-pivot))])])))
(: matrix-row-echelon
(case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Any -> (Matrix Real))
((Matrix Real) Any Any -> (Matrix Real))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Any -> (Matrix Number))
((Matrix Number) Any Any -> (Matrix Number))))
(define (matrix-row-echelon M [jordan? #f] [unitize-pivot? jordan?])
(let-values ([(M _) (matrix-gauss-elim M jordan? unitize-pivot?)])
M))

View File

@ -0,0 +1,80 @@
#lang typed/racket/base
(require racket/fixnum
racket/list
"matrix-types.rkt"
"matrix-basic.rkt"
"matrix-conversion.rkt"
"matrix-constructors.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"../array/array-struct.rkt"
"../array/array-constructors.rkt"
"../array/array-indexing.rkt")
(provide matrix-gram-schmidt
matrix-basis-extension)
(: find-nonzero-vector (case-> ((Vectorof (Vectorof Real)) -> (U #f Index))
((Vectorof (Vectorof Number)) -> (U #f Index))))
(define (find-nonzero-vector vss)
(define n (vector-length vss))
(cond [(= n 0) #f]
[else (let loop ([#{i : Nonnegative-Fixnum} 0])
(cond [(i . fx< . n)
(define vs (unsafe-vector-ref vss i))
(if (vector-zero? vs) (loop (fx+ i 1)) i)]
[else #f]))]))
(: subtract-projections!
(case-> ((Vectorof (Vectorof Real)) Nonnegative-Fixnum Index (Vectorof Real) -> Void)
((Vectorof (Vectorof Number)) Nonnegative-Fixnum Index (Vectorof Number) -> Void)))
(define (subtract-projections! rows i m row)
(let loop ([#{i : Nonnegative-Fixnum} i])
(when (i . fx< . m)
(vector-sub-proj! (unsafe-vector-ref rows i) row #f)
(loop (fx+ i 1)))))
(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real))
((Matrix Real) Any -> (Array Real))
((Matrix Real) Any Integer -> (Array Real))
((Matrix Number) -> (Array Number))
((Matrix Number) Any -> (Array Number))
((Matrix Number) Any Integer -> (Array Number))))
;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already
;; orthogonal
(define (matrix-gram-schmidt M [normalize? #f] [start 0])
(define rows (matrix->vector* (matrix-transpose M)))
(define m (vector-length rows))
(define i (find-nonzero-vector rows))
(cond [(not (index? start))
(raise-argument-error 'matrix-gram-schmidt "Index" 2 M normalize? start)]
[i
(define rowi (unsafe-vector-ref rows i))
(subtract-projections! rows (fxmax start (fx+ i 1)) m rowi)
(when normalize? (vector-normalize! rowi))
(let loop ([#{i : Nonnegative-Fixnum} (fx+ i 1)] [bs (list rowi)])
(cond [(i . fx< . m)
(define rowi (unsafe-vector-ref rows i))
(cond [(vector-zero? rowi) (loop (fx+ i 1) bs)]
[else (subtract-projections! rows (fxmax start (fx+ i 1)) m rowi)
(when normalize? (vector-normalize! rowi))
(loop (fx+ i 1) (cons rowi bs))])]
[else
(matrix-transpose (vector*->matrix (list->vector (reverse bs))))]))]
[else
(make-array (vector (matrix-num-rows M) 0) 0)]))
(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real))
((Matrix Number) -> (Array Number))))
(define (matrix-basis-extension B)
(define-values (m n) (matrix-shape B))
(cond [(n . < . m)
(define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n))
(define R (submatrix S (::) (:: n #f)))
(matrix-augment (take (sort/key (matrix-cols R) > matrix-norm) (- m n)))]
[(n . = . m)
(make-array (vector m 0) 0)]
[else
(raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)]))

View File

@ -0,0 +1,59 @@
#lang typed/racket/base
(require racket/fixnum
"matrix-types.rkt"
"matrix-conversion.rkt"
"matrix-arithmetic.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"../array/mutable-array.rkt")
(provide matrix-lu)
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
(: matrix-lu
(All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
(define matrix-lu
(case-lambda
[(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
[(M fail)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
;; Construct L in a weird way to prove to TR that it has the right type
(define L (array->mutable-array (matrix-scale M (ann 0 Real))))
;; Going to fill in the lower triangle by banging values into `ys'
(define ys (mutable-array-data L))
(let loop ([#{i : Nonnegative-Fixnum} 0])
(cond
[(i . fx< . m)
;; Pivot must be on the diagonal
(define pivot (unsafe-vector2d-ref rows i i))
(cond
[(zero? pivot) (values (fail) M)]
[else
;; Zero out everything below the pivot
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond
[(l . fx< . m)
(define x_li (unsafe-vector2d-ref rows l i))
(define y_li (/ x_li pivot))
(unless (zero? x_li)
;; Fill in lower triangle of L
(unsafe-vector-set! ys (+ (* l m) i) y_li)
;; Add row i, scaled
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- y_li)))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1))]))])]
[else
;; L's lower triangle has been filled; now fill the diagonal with 1s
(for: ([i : Integer (in-range 0 m)])
(vector-set! ys (+ (* i m) i) 1))
(values L (vector*->matrix rows))]))]))

View File

@ -1,442 +0,0 @@
#lang typed/racket/base
(require racket/fixnum
racket/list
racket/match
math/array
(only-in typed/racket conjugate)
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-conversion.rkt"
"matrix-arithmetic.rkt"
"matrix-basic.rkt"
"matrix-column.rkt"
"utils.rkt"
(for-syntax racket))
(provide
;; Gaussian elimination
matrix-gauss-elim
matrix-row-echelon
;; Derived functions
matrix-rank
matrix-nullity
matrix-determinant
matrix-determinant/row-reduction ; for testing
;; Spaces
matrix-column-space
;; Solving
matrix-invertible?
matrix-inverse
matrix-solve
;; Projection
projection-on-orthogonal-basis
projection-on-orthonormal-basis
projection-on-subspace
gram-schmidt-orthogonal
gram-schmidt-orthonormal
;; Decomposition
matrix-lu
matrix-qr
)
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
(define (unsafe-vector2d-ref vss i j)
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
;; ===================================================================================================
;; Gaussian elimination
(: find-partial-pivot
(case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Find the element with maximum magnitude in a column
(define (find-partial-pivot rows m i j)
(define l (fx+ i 1))
(define pivot (unsafe-vector2d-ref rows i j))
(define mag-pivot (magnitude pivot))
(let loop ([#{l : Nonnegative-Fixnum} l] [#{p : Index} i] [pivot pivot] [mag-pivot mag-pivot])
(cond [(l . fx< . m)
(define new-pivot (unsafe-vector2d-ref rows l j))
(define mag-new-pivot (magnitude new-pivot))
(cond [(mag-new-pivot . > . mag-pivot) (loop (fx+ l 1) l new-pivot mag-new-pivot)]
[else (loop (fx+ l 1) p pivot mag-pivot)])]
[else (values p pivot)])))
(: elim-rows!
(case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void)
((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void)))
(define (elim-rows! rows m i j pivot start)
(let loop ([#{l : Nonnegative-Fixnum} start])
(when (l . fx< . m)
(unless (l . fx= . i)
(define x_lj (unsafe-vector2d-ref rows l j))
(unless (zero? x_lj)
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- (/ x_lj pivot)))))
(loop (fx+ l 1)))))
(: matrix-gauss-elim (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index)))
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index)))))
(define (matrix-gauss-elim M [jordan? #f] [unitize-pivot-row? #f])
(define-values (m n) (matrix-shape M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0]
[#{j : Nonnegative-Fixnum} 0]
[#{without-pivot : (Listof Index)} empty])
(cond
[(j . fx>= . n)
(values (vector*->matrix rows)
(reverse without-pivot))]
[(i . fx>= . m)
(values (vector*->matrix rows)
;; None of the rest of the columns can have pivots
(let loop ([#{j : Nonnegative-Fixnum} j] [without-pivot without-pivot])
(cond [(j . fx< . n) (loop (fx+ j 1) (cons j without-pivot))]
[else (reverse without-pivot)])))]
[else
(define-values (p pivot) (find-partial-pivot rows m i j))
(cond
[(zero? pivot) (loop i (fx+ j 1) (cons j without-pivot))]
[else
;; Swap pivot row with current
(vector-swap! rows i p)
;; Possibly unitize the new current row
(let ([pivot (if unitize-pivot-row?
(begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
1)
pivot)])
(elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1)))
(loop (fx+ i 1) (fx+ j 1) without-pivot))])])))
;; ===================================================================================================
;; Simple functions derived from Gaussian elimination
(: matrix-row-echelon
(case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) Any -> (Matrix Real))
((Matrix Real) Any Any -> (Matrix Real))
((Matrix Number) -> (Matrix Number))
((Matrix Number) Any -> (Matrix Number))
((Matrix Number) Any Any -> (Matrix Number))))
(define (matrix-row-echelon M [jordan? #f] [unitize-pivot-row? jordan?])
(let-values ([(M _) (matrix-gauss-elim M jordan? unitize-pivot-row?)])
M))
(: matrix-rank : (Matrix Number) -> Index)
;; Returns the dimension of the column space (equiv. row space) of M
(define (matrix-rank M)
(define n (matrix-num-cols M))
(define-values (_ cols-without-pivot) (matrix-gauss-elim M))
(assert (- n (length cols-without-pivot)) index?))
(: matrix-nullity : (Matrix Number) -> Index)
;; Returns the dimension of the null space of M
(define (matrix-nullity M)
(define-values (_ cols-without-pivot)
(matrix-gauss-elim (ensure-matrix 'matrix-nullity M)))
(length cols-without-pivot))
(: maybe-cons-submatrix (All (A) ((Matrix A) Nonnegative-Fixnum Nonnegative-Fixnum (Listof (Matrix A))
-> (Listof (Matrix A)))))
(define (maybe-cons-submatrix M j0 j1 Bs)
(cond [(= j0 j1) Bs]
[else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
(: matrix-column-space (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-column-space
(case-lambda
[(M) (matrix-column-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
[(M fail)
(define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M]
[(= (length wps) n) (fail)]
[else
(define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty))
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
(cond [(empty? wps)
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else
(define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])]))
;; ===================================================================================================
;; Determinant
(: matrix-determinant (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant M)
(define m (square-matrix-size M))
(cond
[(= m 1) (matrix-ref M 0 0)]
[(= m 2) (match-define (vector a b c d)
(mutable-array-data (array->mutable-array M)))
(- (* a d) (* b c))]
[(= m 3) (match-define (vector a b c d e f g h i)
(mutable-array-data (array->mutable-array M)))
(+ (* a (- (* e i) (* f h)))
(* (- b) (- (* d i) (* f g)))
(* c (- (* d h) (* e g))))]
[else
(matrix-determinant/row-reduction M)]))
(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant/row-reduction M)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1])
(cond
[(i . fx< . m)
(define-values (p pivot) (find-partial-pivot rows m i i))
(cond
[(zero? pivot) 0] ; no pivot means non-invertible matrix
[else
(vector-swap! rows i p) ; negates determinant if i != p
(elim-rows! rows m i i pivot (fx+ i 1)) ; doesn't change the determinant
(loop (fx+ i 1) (if (= i p) sign (* -1 sign)))])]
[else
(define prod (unsafe-vector2d-ref rows 0 0))
(let loop ([#{i : Nonnegative-Fixnum} 1] [prod prod])
(cond [(i . fx< . m)
(loop (fx+ i 1) (* prod (unsafe-vector2d-ref rows i i)))]
[else (* prod sign)]))])))
;; ===================================================================================================
;; Inversion and solving linear systems
(: matrix-invertible? ((Matrix Number) -> Boolean))
(define (matrix-invertible? M)
(not (zero? (matrix-determinant M))))
(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-inverse
(case-lambda
[(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))]
[(M fail)
(define m (square-matrix-size M))
(define I (identity-matrix m))
(define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IM^-1 (::) (:: m #f))]
[else (fail)])]))
(: matrix-solve (All (A) (case->
((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) (Matrix Number) -> (Matrix Number))
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-solve
(case-lambda
[(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))]
[(M B fail)
(define m (square-matrix-size M))
(define-values (s t) (matrix-shape B))
(cond [(= m s)
(define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IX (::) (:: m #f))]
[else (fail)])]
[else
(error 'matrix-solve
"matrices must have the same number of rows; given ~e and ~e"
M B)])]))
;; ===================================================================================================
;; LU Factorization
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
(: matrix-lu
(All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
(define matrix-lu
(case-lambda
[(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
[(M fail)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
;; Construct L in a weird way to prove to TR that it has the right type
(define L (array->mutable-array (matrix-scale M (ann 0 Real))))
;; Going to fill in the lower triangle by banging values into `ys'
(define ys (mutable-array-data L))
(let loop ([#{i : Nonnegative-Fixnum} 0])
(cond
[(i . fx< . m)
;; Pivot must be on the diagonal
(define pivot (unsafe-vector2d-ref rows i i))
(cond
[(zero? pivot) (values (fail) M)]
[else
;; Zero out everything below the pivot
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond
[(l . fx< . m)
(define x_li (unsafe-vector2d-ref rows l i))
(define y_li (/ x_li pivot))
(unless (zero? x_li)
;; Fill in lower triangle of L
(unsafe-vector-set! ys (+ (* l m) i) y_li)
;; Add row i, scaled
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- y_li)))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1))]))])]
[else
;; L's lower triangle has been filled; now fill the diagonal with 1s
(for: ([i : Integer (in-range 0 m)])
(vector-set! ys (+ (* i m) i) 1))
(values L (vector*->matrix rows))]))]))
;; ===================================================================================================
;; Projections and orthogonalization
(: projection-on-orthogonal-basis :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
; (projection-on-orthogonal-basis v bs)
; Project the vector v on the orthogonal basis vectors in bs.
; The basis bs must be either the column vectors of a matrix
; or a sequence of column-vectors.
(define (projection-on-orthogonal-basis v bs)
(if (null? bs)
(error 'projection-on-orthogonal-basis
"received empty list of basis vectors")
(matrix-sum (map (λ: ([b : (Column Number)])
(column-project v (->col-matrix b)))
bs))))
; (projection-on-orthonormal-basis v bs)
; Project the vector v on the orthonormal basis vectors in bs.
; The basis bs must be either the column vectors of a matrix
; or a sequence of column-vectors.
(: projection-on-orthonormal-basis :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
(define (projection-on-orthonormal-basis v bs)
#;(for/matrix-sum ([b bs]) (matrix-scale b (column-dot v b)))
(define: sum : (U False (Result-Column Number)) #f)
(for ([b1 (in-list bs)])
(define: b : (Result-Column Number) (->col-matrix b1))
(cond [(not sum) (set! sum (column-project/unit v b))]
[else (set! sum (array+ (assert sum) (column-project/unit v b)))]))
(cond [sum (assert sum)]
[else (error 'projection-on-orthonormal-basis
"received empty list of basis vectors")]))
(: gram-schmidt-orthogonal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
; (gram-schmidt-orthogonal ws)
; Given a list ws of column vectors, produce
; an orthogonal basis for the span of the
; vectors in ws.
(define (gram-schmidt-orthogonal ws1)
(define ws (map (λ: ([w : (Column Number)]) (->col-matrix w)) ws1))
(cond
[(null? ws) '()]
[(null? (cdr ws)) (list (car ws))]
[else
(: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number))
-> (Listof (Result-Column Number)))
(define (loop vs ws)
(cond [(null? ws) vs]
[else
(define w (car ws))
(let ([w-proj (projection-on-orthogonal-basis w vs)])
; Note: We project onto vs (not on the original ws)
; in order to get numerical stability.
(let ([w-minus-proj (array-strict (array- w w-proj))])
(if (matrix-zero? w-minus-proj)
(loop vs (cdr ws)) ; w in span{vs} => omit it
(loop (cons w-minus-proj vs) (cdr ws)))))]))
(reverse (loop (list (car ws)) (cdr ws)))]))
(: gram-schmidt-orthonormal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
; (gram-schmidt-orthonormal ws)
; Given a list ws of column vectors, produce
; an orthonormal basis for the span of the
; vectors in ws.
(define (gram-schmidt-orthonormal ws)
(map column-normalize (gram-schmidt-orthogonal ws)))
(: projection-on-subspace :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
; (projection-on-subspace v ws)
; Returns the projection of v on span{w_i}, w_i in ws.
(define (projection-on-subspace v ws)
(projection-on-orthogonal-basis v (gram-schmidt-orthogonal ws)))
(: extend-span-to-basis :
(Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
; Extend the basis in vs to r-dimensional basis
(define (extend-span-to-basis vs r)
(define-values (m n) (matrix-shape (car vs)))
(: loop : (Listof (Matrix Number)) (Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
(define (loop vs ws i)
(if (>= i m)
ws
(let ()
(define ei (unit-column m i))
(define pi (projection-on-subspace ei vs))
(if (matrix= ei pi)
(loop vs ws (+ i 1))
(let ([w (array- ei pi)])
(loop (cons w vs) (cons w ws) (+ i 1)))))))
(: norm> : (Matrix Number) (Matrix Number) -> Boolean)
(define (norm> v w)
(> (column-norm v) (column-norm w)))
(if (index? r)
(take (sort (loop vs '() 0) norm>) r)
(error 'extend-span-to-basis "expected index as second argument, got ~a" r)))
;; ===================================================================================================
;; QR decomposition
(: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
(define (matrix-qr M)
; compute the QR-facorization
; 1) QR = M
; 2) columns of Q is are orthonormal
; 3) R is upper-triangular
; Note: columnspace(A)=columnspace(Q) !
(define-values (m n) (matrix-shape M))
(let* ([basis-for-column-space
(gram-schmidt-orthonormal (matrix-cols M))]
[extension
(extend-span-to-basis
basis-for-column-space (- n (length basis-for-column-space)))]
[Q (matrix-augment
(append basis-for-column-space
(map column-normalize
extension)))]
[R
(let ()
(define v (make-vector (* n n) (ann 0 Number)))
(for*: ([i (in-range 0 n)]
[j (in-range 0 n)])
(if (> i j)
(void) ; v(i,j)=0 already
(let ()
(define: sum : Number 0)
(for: ([k (in-range m)])
(set! sum (+ sum (* (matrix-ref Q k i)
(matrix-ref M k j)))))
(vector-set! v (+ (* i n) j) sum))))
(vector->matrix n n v))])
(values Q R)))

View File

@ -0,0 +1,121 @@
#lang typed/racket/base
#|
Two of the functions defined here currently just raise an error: `matrix-op-2norm' and
`matrix-op-angle'. They need to compute, respectively, the maximum and minimum singular values of
their matrix argument.
See "How to Measure Errors" in the LAPACK manual for more details:
http://www.netlib.org/lapack/lug/node75.html
http://www.netlib.org/lapack/lug/node76.html
|#
(require racket/list
racket/fixnum
math/flonum
"matrix-types.rkt"
"matrix-arithmetic.rkt"
"matrix-constructors.rkt"
"matrix-basic.rkt"
"utils.rkt"
"../array/array-struct.rkt"
"../array/array-pointwise.rkt"
"../array/array-fold.rkt"
)
(provide
;; Operator norms
matrix-op-1norm
matrix-op-2norm
matrix-op-inf-norm
matrix-basis-angle
;; Error measurement
matrix-error-norm
matrix-absolute-error
matrix-relative-error
;; Approximate predicates
matrix-identity?
matrix-orthonormal?
)
(: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-1norm
(define (matrix-op-1norm M)
(assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?))
(: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-2norm
(define (matrix-op-2norm M)
;(matrix-max-singular-value M)
;(sqrt (matrix-max-eigenvalue M))
(error 'unimplemented))
(: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-inf-norm
(define (matrix-op-inf-norm M)
(assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?))
(: matrix-basis-angle (case-> ((Matrix Real) (Matrix Real) -> Real)
((Matrix Number) (Matrix Number) -> Number)))
;; Returns the angle between the two subspaces spanned by the two given sets of column vectors
(define (matrix-basis-angle M R)
;(acos (matrix-min-singular-value (matrix* (matrix-hermitian M) R)))
(error 'unimplemented))
;; ===================================================================================================
;; Error measurement
(: matrix-error-norm (Parameterof ((Matrix Number) -> Nonnegative-Real)))
(define matrix-error-norm (make-parameter matrix-op-inf-norm))
(: matrix-absolute-error
(case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real)
((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real)
-> Nonnegative-Real)))
(define (matrix-absolute-error M R [norm (matrix-error-norm)])
(define-values (m n) (matrix-shapes 'matrix-absolute-error M R))
(array-strict! M)
(array-strict! R)
(cond [(array-all-and (inline-array-map eqv? M R)) 0]
[(and (array-all-and (inline-array-map number-rational? M))
(array-all-and (inline-array-map number-rational? R)))
(norm (matrix- (inline-array-map inexact->exact M)
(inline-array-map inexact->exact R)))]
[else +inf.0]))
(: matrix-relative-error
(case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real)
((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real)
-> Nonnegative-Real)))
(define (matrix-relative-error M R [norm (matrix-error-norm)])
(define-values (m n) (matrix-shapes 'matrix-relative-error M R))
(array-strict! M)
(array-strict! R)
(cond [(array-all-and (inline-array-map eqv? M R)) 0]
[(and (array-all-and (inline-array-map number-rational? M))
(array-all-and (inline-array-map number-rational? R)))
(define num (norm (matrix- M R)))
(define den (norm R))
(cond [(and (zero? num) (zero? den)) 0]
[(zero? den) +inf.0]
[else (assert (/ num den) nonnegative?)])]
[else +inf.0]))
;; ===================================================================================================
;; Approximate predicates
(: matrix-identity? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean)))
(define (matrix-identity? M [eps (* 10 epsilon.0)])
(cond [(eps . < . 0) (raise-argument-error 'matrix-identity? "Nonnegative-Real" 1 M eps)]
[else (and (square-matrix? M)
(<= (matrix-relative-error M (identity-matrix (square-matrix-size M))) eps))]))
(: matrix-orthonormal? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean)))
(define (matrix-orthonormal? M [eps (* 10 epsilon.0)])
(cond [(eps . < . 0) (raise-argument-error 'matrix-orthonormal? "Nonnegative-Real" 1 M eps)]
[else (and (square-matrix? M)
(matrix-identity? (matrix* M (matrix-hermitian M)) eps))]))

View File

@ -0,0 +1,39 @@
#lang typed/racket/base
(require "matrix-types.rkt"
"matrix-basic.rkt"
"matrix-arithmetic.rkt"
"matrix-constructors.rkt"
"matrix-gram-schmidt.rkt"
"../array/array-transform.rkt")
(provide matrix-qr)
#|
QR decomposition currently does Gram-Schmidt twice, as suggested by
Luc Giraud, Julien Langou, Miroslav Rozloznik.
On the round-off error analysis of the Gram-Schmidt algorithm with reorthogonalization.
Technical Report, 2002.
It normalizes only the second time.
I've verified experimentally that, with random, square matrices (elements in [0,1]), doing so
produces matrices for which `matrix-orthogonal?' returns #t with eps <= 10*epsilon.0, apparently
independently of the matrix size.
|#
(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number)))))
(define (matrix-qr M [full? #t])
(define B (matrix-gram-schmidt M #f))
(define Q
(matrix-gram-schmidt
(cond [(or (square-matrix? B) (and (matrix? B) (not full?))) B]
[(matrix? B) (array-append* (list B (matrix-basis-extension B)) 1)]
[full? (identity-matrix (matrix-num-rows M))]
[else (matrix-col (identity-matrix (matrix-num-rows M)) 0)])
#t))
(values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M))))

View File

@ -0,0 +1,107 @@
#lang typed/racket/base
(require racket/fixnum
racket/match
racket/list
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-conversion.rkt"
"matrix-basic.rkt"
"matrix-gauss-elim.rkt"
"utils.rkt"
"../vector/vector-mutate.rkt"
"../array/array-indexing.rkt"
"../array/mutable-array.rkt")
(provide
matrix-determinant
matrix-determinant/row-reduction ; for testing
matrix-invertible?
matrix-inverse
matrix-solve)
;; ===================================================================================================
;; Determinant
(: matrix-determinant (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant M)
(define m (square-matrix-size M))
(cond
[(= m 1) (matrix-ref M 0 0)]
[(= m 2) (match-define (vector a b c d)
(mutable-array-data (array->mutable-array M)))
(- (* a d) (* b c))]
[(= m 3) (match-define (vector a b c d e f g h i)
(mutable-array-data (array->mutable-array M)))
(+ (* a (- (* e i) (* f h)))
(* (- b) (- (* d i) (* f g)))
(* c (- (* d h) (* e g))))]
[else
(matrix-determinant/row-reduction M)]))
(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number)))
(define (matrix-determinant/row-reduction M)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1])
(cond
[(i . fx< . m)
(define-values (p pivot) (find-partial-pivot rows m i i))
(cond
[(zero? pivot) 0] ; no pivot means non-invertible matrix
[else
(let ([sign (if (= i p) sign (begin (vector-swap! rows i p) ; swapping negates sign
(* -1 sign)))])
(elim-rows! rows m i i pivot (fx+ i 1)) ; adding scaled rows doesn't change it
(loop (fx+ i 1) sign))])]
[else
(define prod (unsafe-vector2d-ref rows 0 0))
(let loop ([#{i : Nonnegative-Fixnum} 1] [prod prod])
(cond [(i . fx< . m)
(loop (fx+ i 1) (* prod (unsafe-vector2d-ref rows i i)))]
[else (* prod sign)]))])))
;; ===================================================================================================
;; Inversion and solving linear systems
(: matrix-invertible? ((Matrix Number) -> Boolean))
(define (matrix-invertible? M)
(not (zero? (matrix-determinant M))))
(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-inverse
(case-lambda
[(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))]
[(M fail)
(define m (square-matrix-size M))
(define I (identity-matrix m))
(define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IM^-1 (::) (:: m #f))]
[else (fail)])]))
(: matrix-solve (All (A) (case->
((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) (Matrix Number) -> (Matrix Number))
((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-solve
(case-lambda
[(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))]
[(M B fail)
(define m (square-matrix-size M))
(define-values (s t) (matrix-shape B))
(cond [(= m s)
(define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t))
(cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IX (::) (:: m #f))]
[else (fail)])]
[else
(error 'matrix-solve
"matrices must have the same number of rows; given ~e and ~e"
M B)])]))

View File

@ -0,0 +1,57 @@
#lang typed/racket/base
(require racket/fixnum
racket/list
"matrix-types.rkt"
"matrix-basic.rkt"
"matrix-gauss-elim.rkt"
"utils.rkt"
"../array/array-indexing.rkt"
"../array/array-constructors.rkt")
(provide
matrix-rank
matrix-nullity
matrix-col-space)
(: matrix-rank : (Matrix Number) -> Index)
;; Returns the dimension of the column space (equiv. row space) of M
(define (matrix-rank M)
(define n (matrix-num-cols M))
(define-values (_ cols-without-pivot) (matrix-gauss-elim M))
(assert (- n (length cols-without-pivot)) index?))
(: matrix-nullity : (Matrix Number) -> Index)
;; Returns the dimension of the null space of M
(define (matrix-nullity M)
(define-values (_ cols-without-pivot)
(matrix-gauss-elim (ensure-matrix 'matrix-nullity M)))
(length cols-without-pivot))
(: maybe-cons-submatrix (All (A) ((Matrix A) Nonnegative-Fixnum Nonnegative-Fixnum (Listof (Matrix A))
-> (Listof (Matrix A)))))
(define (maybe-cons-submatrix M j0 j1 Bs)
(cond [(= j0 j1) Bs]
[else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
(: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-col-space
(case-lambda
[(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
[(M fail)
(define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M]
[(= (length wps) n) (fail)]
[else
(define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty))
(let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs])
(cond [(empty? wps)
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else
(define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])]))

View File

@ -3,7 +3,7 @@
(require (for-syntax racket/base
syntax/parse)
(only-in typed/racket/base :)
math/array)
"../array/array-struct.rkt")
(provide matrix row-matrix col-matrix)

View File

@ -1,10 +1,12 @@
#lang typed/racket/base
(require racket/list
math/array
"matrix-types.rkt"
"utils.rkt"
(except-in "untyped-matrix-arithmetic.rkt" matrix-map))
(except-in "untyped-matrix-arithmetic.rkt" matrix-map)
"../array/array-struct.rkt"
"../array/array-fold.rkt"
"../array/utils.rkt")
(provide matrix-map
matrix=
@ -14,16 +16,17 @@
matrix-scale
matrix-sum)
(: matrix-map (All (R A B T ...)
(case-> ((A -> R) (Array A) -> (Array R))
((A B T ... T -> R) (Array A) (Array B) (Array T) ... T -> (Array R)))))
(: matrix-map
(All (R A B T ...)
(case-> ((A -> R) (Matrix A) -> (Matrix R))
((A B T ... T -> R) (Matrix A) (Matrix B) (Matrix T) ... T -> (Matrix R)))))
(define matrix-map
(case-lambda:
[([f : (A -> R)] [arr : (Array A)])
[([f : (A -> R)] [arr : (Matrix A)])
(inline-matrix-map f arr)]
[([f : (A B -> R)] [arr0 : (Array A)] [arr1 : (Array B)])
[([f : (A B -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix B)])
(inline-matrix-map f arr0 arr1)]
[([f : (A B T ... T -> R)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T])
[([f : (A B T ... T -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix B)] . [arrs : (Matrix T) ... T])
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))
@ -33,7 +36,7 @@
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> T)]) (g js)) gs))))]))
(: matrix=? ((Array Number) (Array Number) -> Boolean))
(: matrix=? ((Matrix Number) (Matrix Number) -> Boolean))
(define (matrix=? arr0 arr1)
(define-values (m0 n0) (matrix-shape arr0))
(define-values (m1 n1) (matrix-shape arr1))
@ -46,35 +49,35 @@
(λ: ([js : Indexes])
(= (proc0 js) (proc1 js))))))))
(: matrix= (case-> ((Array Number) (Array Number) -> Boolean)
((Array Number) (Array Number) (Array Number) (Array Number) * -> Boolean)))
(: matrix= (case-> ((Matrix Number) (Matrix Number) -> Boolean)
((Matrix Number) (Matrix Number) (Matrix Number) (Matrix Number) * -> Boolean)))
(define matrix=
(case-lambda:
[([arr0 : (Array Number)] [arr1 : (Array Number)]) (matrix=? arr0 arr1)]
[([arr0 : (Array Number)] [arr1 : (Array Number)] . [arrs : (Array Number) *])
[([arr0 : (Matrix Number)] [arr1 : (Matrix Number)]) (matrix=? arr0 arr1)]
[([arr0 : (Matrix Number)] [arr1 : (Matrix Number)] . [arrs : (Matrix Number) *])
(and (matrix=? arr0 arr1)
(let: loop : Boolean ([arr1 : (Array Number) arr1]
[arrs : (Listof (Array Number)) arrs])
(let: loop : Boolean ([arr1 : (Matrix Number) arr1]
[arrs : (Listof (Matrix Number)) arrs])
(cond [(empty? arrs) #t]
[else (and (matrix=? arr1 (first arrs))
(loop (first arrs) (rest arrs)))])))]))
(: matrix* (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix* a . as)
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix* a (first as)) (rest as))])))
(: matrix+ (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix+ a . as)
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix+ a (first as)) (rest as))])))
(: matrix- (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix- a . as)
(cond [(empty? as) (inline-matrix- a)]
[else
@ -82,12 +85,12 @@
(cond [(empty? as) a]
[else (loop (inline-matrix- a (first as)) (rest as))]))]))
(: matrix-scale (case-> ((Array Real) Real -> (Array Real))
((Array Number) Number -> (Array Number))))
(: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real))
((Matrix Number) Number -> (Matrix Number))))
(define (matrix-scale a x) (inline-matrix-scale a x))
(: matrix-sum (case-> ((Listof (Array Real)) -> (Array Real))
((Listof (Array Number)) -> (Array Number))))
(: matrix-sum (case-> ((Listof (Matrix Real)) -> (Matrix Real))
((Listof (Matrix Number)) -> (Matrix Number))))
(define (matrix-sum lst)
(cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)]
[else (apply matrix+ lst)]))

View File

@ -10,13 +10,16 @@
(module syntax-defs racket/base
(require (for-syntax racket/base)
(only-in typed/racket/base λ: : inst Index)
math/array
"matrix-types.rkt"
"utils.rkt")
"utils.rkt"
"../array/array-struct.rkt"
"../array/array-fold.rkt"
"../array/array-transform.rkt"
"../array/utils.rkt")
(provide (all-defined-out))
;(: matrix-multiply ((Array Number) (Array Number) -> (Array Number)))
;(: matrix-multiply ((Matrix Number) (Matrix Number) -> (Matrix Number)))
;; This is a macro so the result can have as precise a type as possible
(define-syntax-rule (inline-matrix-multiply arr-expr brr-expr)
(let ([arr arr-expr]
@ -69,26 +72,31 @@
(define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...))
(define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...))
(define-syntax-rule (inline-matrix-scale arr x) (inline-matrix-map (λ (y) (* x y)) arr))
(define-syntax-rule (inline-matrix-scale arr x-expr)
(let ([x x-expr])
(inline-matrix-map (λ (y) (* x y)) arr)))
) ; module
(module untyped-defs typed/racket/base
(require math/array
(submod ".." syntax-defs)
"utils.rkt")
(require (submod ".." syntax-defs)
"matrix-types.rkt"
"utils.rkt"
"../array/array-struct.rkt"
"../array/utils.rkt")
(provide matrix-map)
(: matrix-map (All (R A) (case-> ((A -> R) (Array A) -> (Array R))
((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R)))))
(: matrix-map
(All (R A) (case-> ((A -> R) (Matrix A) -> (Matrix R))
((A A A * -> R) (Matrix A) (Matrix A) (Matrix A) * -> (Matrix R)))))
(define matrix-map
(case-lambda:
[([f : (A -> R)] [arr : (Array A)])
[([f : (A -> R)] [arr : (Matrix A)])
(inline-matrix-map f arr)]
[([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)])
[([f : (A A -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix A)])
(inline-matrix-map f arr0 arr1)]
[([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *])
[([f : (A A A * -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix A)] . [arrs : (Matrix A) *])
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))

View File

@ -1,9 +1,11 @@
#lang typed/racket/base
(require racket/match
racket/string
math/array
"matrix-types.rkt")
(require racket/string
racket/fixnum
"matrix-types.rkt"
"../unsafe.rkt"
"../array/array-struct.rkt"
"../vector/vector-mutate.rkt")
(provide (all-defined-out))
@ -36,3 +38,61 @@
(: ensure-matrix (All (A) Symbol (Array A) -> (Array A)))
(define (ensure-matrix name a)
(if (matrix? a) a (raise-argument-error name "matrix?" a)))
(: ensure-row-matrix (All (A) Symbol (Array A) -> (Array A)))
(define (ensure-row-matrix name a)
(if (row-matrix? a) a (raise-argument-error name "row-matrix?" a)))
(: ensure-col-matrix (All (A) Symbol (Array A) -> (Array A)))
(define (ensure-col-matrix name a)
(if (col-matrix? a) a (raise-argument-error name "col-matrix?" a)))
(: sort/key (All (A B) (case-> ((Listof A) (B B -> Boolean) (A -> B) -> (Listof A))
((Listof A) (B B -> Boolean) (A -> B) Boolean -> (Listof A)))))
;; Sometimes necessary because TR can't do inference with keyword arguments yet
(define (sort/key lst lt? key [cache-keys? #f])
((inst sort A B) lst lt? #:key key #:cache-keys? cache-keys?))
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
(define (unsafe-vector2d-ref vss i j)
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
;; Note: this accepts +nan.0
(define nonnegative?
(λ: ([x : Real]) (not (x . < . 0))))
(define number-rational?
(λ: ([z : Number])
(cond [(real? z) (rational? z)]
[else (and (rational? (real-part z))
(rational? (imag-part z)))])))
(: find-partial-pivot
(case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Find the element with maximum magnitude in a column
(define (find-partial-pivot rows m i j)
(define l (fx+ i 1))
(define pivot (unsafe-vector2d-ref rows i j))
(define mag-pivot (magnitude pivot))
(let loop ([#{l : Nonnegative-Fixnum} l] [#{p : Index} i] [pivot pivot] [mag-pivot mag-pivot])
(cond [(l . fx< . m)
(define new-pivot (unsafe-vector2d-ref rows l j))
(define mag-new-pivot (magnitude new-pivot))
(cond [(mag-new-pivot . > . mag-pivot) (loop (fx+ l 1) l new-pivot mag-new-pivot)]
[else (loop (fx+ l 1) p pivot mag-pivot)])]
[else (values p pivot)])))
(: elim-rows!
(case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void)
((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void)))
(define (elim-rows! rows m i j pivot start)
(let loop ([#{l : Nonnegative-Fixnum} start])
(when (l . fx< . m)
(unless (l . fx= . i)
(define x_lj (unsafe-vector2d-ref rows l j))
(unless (zero? x_lj)
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- (/ x_lj pivot)))))
(loop (fx+ l 1)))))

View File

@ -16,9 +16,7 @@
(: mag^2 (Number -> Nonnegative-Real))
(define (mag^2 x)
(define y (* x (conjugate x)))
(cond [(and (real? y) (y . >= . 0)) y]
[else (error 'impossible)]))
(max 0 (real-part (* x (conjugate x)))))
(: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void)))
(define (vector-swap! vs i0 i1)

View File

@ -4,14 +4,24 @@
math/base
math/flonum
math/matrix
"../private/matrix/matrix-column.rkt"
"test-utils.rkt")
(define-syntax (check-matrix=? stx)
(syntax-case stx ()
[(_ a b)
(syntax/loc stx (check-true (matrix=? a b) (format "(matrix=? ~v ~v)" a b)))]
[(_ a b eps)
(syntax/loc stx (check-true (matrix=? a b eps) (format "(matrix=? ~v ~v ~v)" a b eps)))]))
(: random-matrix (case-> (Integer Integer -> (Matrix Integer))
(Integer Integer Integer -> (Matrix Integer))))
(Integer Integer Integer -> (Matrix Integer))
(Integer Integer Integer Integer -> (Matrix Integer))))
;; Generates a random matrix with Natural elements < k. Useful to test properties.
(define (random-matrix m n [k 100])
(array-strict (build-array (vector m n) (λ (_) (random k)))))
(define random-matrix
(case-lambda
[(m n) (random-matrix m n 100)]
[(m n k) (array-strict (build-matrix m n (λ (i j) (random-natural k))))]
[(m n k0 k1) (array-strict (build-matrix m n (λ (i j) (random-integer k0 k1))))]))
(define nonmatrices
(list (make-array #() 0)
@ -21,6 +31,16 @@
(make-array #(0 0) 0)
(make-array #(1 1 1) 0)))
(: matrix-l ((Matrix Number) -> (Matrix Number)))
(define (matrix-l M)
(define-values (L U) (matrix-lu M))
L)
(: matrix-q ((Matrix Number) -> (Matrix Number)))
(define (matrix-q M)
(define-values (Q R) (matrix-qr M))
Q)
;; ===================================================================================================
;; Literal syntax
@ -74,13 +94,6 @@
(for: ([a (in-list nonmatrices)])
(check-false (col-matrix? a)))
(check-true (matrix-zero? (make-matrix 4 3 0)))
(check-true (matrix-zero? (make-matrix 4 3 0.0)))
(check-true (matrix-zero? (make-matrix 4 3 0+0.0i)))
(check-false (matrix-zero? (row-matrix [0 0 0 0 1])))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-zero? a))))
;; ===================================================================================================
;; Accessors
@ -425,17 +438,9 @@
;; ===================================================================================================
;; Comprehensions
;; for/matrix and friends are defined in terms of for/array and friends, so we only need to test that
;; it works for one case each, and that they properly raise exceptions when given zero-length axes
(check-equal?
(for/matrix 2 2 ([i (in-range 4)]) i)
(matrix [[0 1] [2 3]]))
#;; TR can't type this, but it's defined using exactly the same wrapper as `for/matrix'
(check-equal?
(for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
(matrix [[0 1] [1 2]]))
;; for:/matrix and friends are defined in terms of for:/array and friends, so we only need to test
;; that it works for one case each, and that they properly raise exceptions when given zero-length
;; axes
(check-equal?
(for/matrix: 2 2 ([i (in-range 4)]) i)
@ -445,11 +450,6 @@
(for*/matrix: 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
(matrix [[0 1] [1 2]]))
(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))
(check-exn exn:fail:contract? (λ () (for/matrix: 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for/matrix: 0 2 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix: 2 0 () 0)))
@ -531,6 +531,10 @@
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-cols a))))
;; TODO: matrix-upper-triangle
;; TODO: matrix-lower-triangle
;; ===================================================================================================
;; Embiggenment (it's a perfectly cromulent word)
@ -626,6 +630,10 @@
(check-exn exn:fail:contract? (λ () (matrix-dot a (matrix [[1]]))))
(check-exn exn:fail:contract? (λ () (matrix-dot (matrix [[1]]) a))))
;; TODO: matrix-angle
;; TODO: matrix-normalize
;; ===================================================================================================
;; Simple operators
@ -647,8 +655,8 @@
;; matrix-hermitian
(let ([a (array-make-rectangular (random-matrix 5 6)
(random-matrix 5 6))])
(let ([a (array-make-rectangular (random-matrix 5 6 -100 100)
(random-matrix 5 6 -100 100))])
(check-equal? (matrix-hermitian a)
(matrix-conjugate (matrix-transpose a)))
(check-equal? (matrix-hermitian a)
@ -667,6 +675,86 @@
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
;; ===================================================================================================
;; Row/column operators
;; TODO: matrix-map-rows
;; TODO: matrix-map-cols
;; TODO: matrix-normalize-rows
;; TODO: matrix-normalize-cols
;; ===================================================================================================
;; Operator norms
;; TODO: matrix-op-1norm
;; TODO: matrix-op-2norm (after it's implemented)
;; TODO: matrix-op-inf-norm
;; ===================================================================================================
;; Error
(for*: ([x (in-list '(-inf.0 -10.0 -1.0 -0.1 -0.0 0.0 0.1 1.0 10.0 +inf.0 +nan.0))]
[y (in-list '(-inf.0 -10.0 -1.0 -0.1 -0.0 0.0 0.1 1.0 10.0 +inf.0 +nan.0))])
(check-eqv? (fl (matrix-absolute-error (row-matrix [x])
(row-matrix [y])))
(fl (absolute-error x y))
(format "x = ~v y = ~v" x y))
(check-eqv? (fl (matrix-relative-error (row-matrix [x])
(row-matrix [y])))
(fl (relative-error x y))
(format "x = ~v y = ~v" x y)))
(check-equal? (matrix-absolute-error (row-matrix [1 2])
(row-matrix [1 2]))
0)
(check-equal? (matrix-absolute-error (row-matrix [1 2])
(row-matrix [2 2]))
1)
(check-equal? (matrix-absolute-error (row-matrix [1 2])
(row-matrix [2 +nan.0]))
+inf.0)
(check-equal? (matrix-relative-error (row-matrix [1 2])
(row-matrix [1 2]))
0)
(check-equal? (matrix-relative-error (row-matrix [1 2])
(row-matrix [2 2]))
(/ 1 (matrix-op-inf-norm (row-matrix [2 2]))))
(check-equal? (matrix-relative-error (row-matrix [1 2])
(row-matrix [2 +nan.0]))
+inf.0)
;; TODO: matrix-basis-angle
;; ===================================================================================================
;; Approximate predicates
;; matrix-zero? (TODO: approximations)
(check-true (matrix-zero? (make-matrix 4 3 0)))
(check-true (matrix-zero? (make-matrix 4 3 0.0)))
(check-true (matrix-zero? (make-matrix 4 3 0+0.0i)))
(check-false (matrix-zero? (row-matrix [0 0 0 0 1])))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-zero? a))))
;; TODO: matrix-rows-orthogonal?
;; TODO: matrix-cols-orthogonal?
;; TODO: matrix-identity?
;; TODO: matrix-orthonormal?
;; ===================================================================================================
;; Gaussian elimination
@ -743,7 +831,7 @@
5280)
(for: ([_ (in-range 100)])
(define a (array- (random-matrix 3 3 7) (array 3)))
(define a (random-matrix 3 3 -3 4))
(check-equal? (matrix-determinant/row-reduction a)
(matrix-determinant a)))
@ -756,8 +844,8 @@
;; Solving linear systems
(for: ([_ (in-range 100)])
(define M (array- (random-matrix 3 3 7) (array 3)))
(define B (array- (random-matrix 3 (+ 1 (random 10)) 7) (array 3)))
(define M (random-matrix 3 3 -3 4))
(define B (random-matrix 3 (+ 1 (random 10)) -3 4))
(cond [(matrix-invertible? M)
(define X (matrix-solve M B))
(check-equal? (matrix* M X) B (format "M = ~a B = ~a" M B))]
@ -779,7 +867,7 @@
;; Inversion
(for: ([_ (in-range 100)])
(define a (array- (random-matrix 3 3 7) (array 3)))
(define a (random-matrix 3 3 -3 4))
(cond [(matrix-invertible? a)
(check-equal? (matrix* a (matrix-inverse a))
(identity-matrix 3)
@ -815,11 +903,6 @@
[0 0 0 -13]]))
(check-equal? (matrix* L V) M))
(: matrix-l ((Matrix Number) -> Any))
(define (matrix-l M)
(define-values (L U) (matrix-lu M))
L)
(check-exn exn:fail? (λ () (matrix-l (matrix [[1 1 0 2]
[0 2 0 1]
[1 0 0 0]
@ -830,15 +913,57 @@
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-l a))))
;; ===================================================================================================
;; Gram-Schmidt
(check-equal? (matrix-gram-schmidt (matrix [[3 2] [1 2]]))
(matrix [[3 -2/5] [1 6/5]]))
(check-equal? (matrix-gram-schmidt (matrix [[3 2] [1 2]]) #t)
(matrix-scale (matrix [[3 -1] [1 3]]) (sqrt 1/10)))
(check-equal? (matrix-gram-schmidt (matrix [[12 -51 4]
[ 6 167 -68]
[-4 24 -41]])
#t)
(matrix [[ 6/7 -69/175 -58/175]
[ 3/7 158/175 6/175]
[-2/7 6/35 -33/35 ]]))
(check-equal? (matrix-gram-schmidt (matrix [[12 -51 4]
[ 6 167 -68]
[-4 24 -41]])
#t)
(matrix [[ 6/7 -69/175 -58/175]
[ 3/7 158/175 6/175]
[-2/7 6/35 -33/35 ]]))
(check-equal? (matrix-gram-schmidt (matrix [[12 -51]
[ 6 167]
[-4 24]])
#t)
(matrix [[ 6/7 -69/175]
[ 3/7 158/175]
[-2/7 6/35 ]]))
(check-equal? (matrix-gram-schmidt (col-matrix [12 6 -4]) #t)
(col-matrix [6/7 3/7 -2/7]))
(check-equal? (matrix-gram-schmidt (col-matrix [12 6 -4]) #f)
(col-matrix [12 6 -4]))
;; ===================================================================================================
;; QR decomposition
(check-true (matrix-orthonormal? (matrix-q (index-array #(100 1)))))
#|
;; ===================================================================================================
;; Tests not yet converted to rackunit
(matrix-gram-schmidt
(matrix [[2 1 2]
[2 2 3]
[5 1 5]])
#t)
;; A particularly tricky one used to demonstrate loss of orthogonality
(matrix-qr (matrix [[0.70000 0.70711]
[0.70001 0.70711]]))
(begin

View File

@ -0,0 +1,51 @@
#lang racket
(require (for-syntax racket/match)
rackunit
math/matrix)
;; ===================================================================================================
;; Contract tests
(begin-for-syntax
(define exceptions (list 'matrix 'col-matrix 'row-matrix
'matrix-determinant/row-reduction))
(define (looks-like-value? sym)
(define str (symbol->string sym))
(and (not (char-upper-case? (string-ref str 0)))
(not (regexp-match #rx"for/" str))
(not (regexp-match #rx"for\\*/" str))
(not (member sym exceptions))))
(define matrix-exports
(let ()
(match-define (list (list #f _ ...)
(list 1 _ ...)
(list 0 matrix-exports ...))
(syntax-local-module-exports #'math/matrix))
(filter looks-like-value? matrix-exports)))
)
(define-syntax (all-exports stx)
(with-syntax ([(matrix-exports ...) matrix-exports])
(syntax/loc stx
(begin (void matrix-exports) ...))))
(all-exports)
;; ===================================================================================================
;; Comprehensions
(check-equal?
(for/matrix 2 2 ([i (in-range 4)]) i)
(matrix [[0 1] [2 3]]))
(check-equal?
(for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
(matrix [[0 1] [1 2]]))
(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))