diff --git a/collects/math/matrix.rkt b/collects/math/matrix.rkt index 2447f34ffb..e3829832eb 100644 --- a/collects/math/matrix.rkt +++ b/collects/math/matrix.rkt @@ -1,16 +1,133 @@ -#lang typed/racket/base +#lang racket/base + +(require typed/untyped-utils) (require "private/matrix/matrix-arithmetic.rkt" - "private/matrix/matrix-constructors.rkt" "private/matrix/matrix-conversion.rkt" "private/matrix/matrix-syntax.rkt" - "private/matrix/matrix-basic.rkt" - "private/matrix/matrix-operations.rkt" "private/matrix/matrix-comprehension.rkt" "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" "private/matrix/matrix-2d.rkt" - "private/matrix/utils.rkt") + ;;"private/matrix/matrix-gauss-elim.rkt" ; all use require/untyped-contract + (except-in "private/matrix/matrix-solve.rkt" + matrix-determinant + matrix-inverse + matrix-solve) + (except-in "private/matrix/matrix-constructors.rkt" + vandermonde-matrix) + (except-in "private/matrix/matrix-basic.rkt" + matrix-dot + matrix-angle + matrix-normalize + matrix-conjugate + matrix-hermitian + matrix-trace + matrix-normalize-rows + matrix-normalize-cols) + (except-in "private/matrix/matrix-subspace.rkt" + matrix-col-space) + (except-in "private/matrix/matrix-operator-norm.rkt" + matrix-basis-angle) + ;;"private/matrix/matrix-qr.rkt" ; all use require/untyped-contract + ;;"private/matrix/matrix-lu.rkt" ; all use require/untyped-contract + ;;"private/matrix/matrix-gram-schmidt.rkt" ; all use require/untyped-contract + ) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-gauss-elim.rkt" + [matrix-gauss-elim + (case-> ((Matrix Number) -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Any -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index))))] + [matrix-row-echelon + (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Any -> (Matrix Number)) + ((Matrix Number) Any Any -> (Matrix Number)))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-solve.rkt" + [matrix-determinant + ((Matrix Number) -> Number)] + [matrix-inverse + (All (A) (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) (-> A) -> (U A (Matrix Number)))))] + [matrix-solve + (All (A) (case-> + ((Matrix Number) (Matrix Number) -> (Matrix Number)) + ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-constructors.rkt" + [vandermonde-matrix ((Listof Number) Integer -> (Matrix Number))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-basic.rkt" + [matrix-dot + (case-> ((Matrix Number) -> Nonnegative-Real) + ((Matrix Number) (Matrix Number) -> Number))] + [matrix-angle + ((Matrix Number) (Matrix Number) -> Number)] + [matrix-normalize + (All (A) (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))] + [matrix-conjugate + ((Matrix Number) -> (Matrix Number))] + [matrix-hermitian + ((Matrix Number) -> (Matrix Number))] + [matrix-trace + ((Matrix Number) -> Number)] + [matrix-normalize-rows + (All (A) (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))] + [matrix-normalize-cols + (All (A) (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-subspace.rkt" + [matrix-col-space + (All (A) (case-> ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) (-> A) -> (U A (Matrix Number)))))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-operator-norm.rkt" + [matrix-basis-angle + ((Matrix Number) (Matrix Number) -> Number)]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-qr.rkt" + [matrix-qr + (case-> ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-lu.rkt" + [matrix-lu + (All (A) (case-> ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))]) + +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt" + "private/array/array-struct.rkt")) + "private/matrix/matrix-gram-schmidt.rkt" + [matrix-gram-schmidt + (case-> ((Matrix Number) -> (Array Number)) + ((Matrix Number) Any -> (Array Number)) + ((Matrix Number) Any Integer -> (Array Number)))] + [matrix-basis-extension + ((Matrix Number) -> (Array Number))]) (provide (all-from-out "private/matrix/matrix-arithmetic.rkt" @@ -18,8 +135,40 @@ "private/matrix/matrix-conversion.rkt" "private/matrix/matrix-syntax.rkt" "private/matrix/matrix-basic.rkt" - "private/matrix/matrix-operations.rkt" + "private/matrix/matrix-subspace.rkt" + "private/matrix/matrix-solve.rkt" + "private/matrix/matrix-operator-norm.rkt" "private/matrix/matrix-comprehension.rkt" "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" - "private/matrix/matrix-2d.rkt")) + "private/matrix/matrix-2d.rkt") + ;; matrix-gauss-elim.rkt + matrix-gauss-elim + matrix-row-echelon + ;; matrix-solve.rkt + matrix-determinant + matrix-inverse + matrix-solve + ;; matrix-constructors.rkt + vandermonde-matrix + ;; matrix-basic.rkt + matrix-dot + matrix-angle + matrix-normalize + matrix-conjugate + matrix-hermitian + matrix-trace + matrix-normalize-rows + matrix-normalize-cols + ;; matrix-subspace.rkt + matrix-col-space + ;; matrix-operator-norm.rkt + matrix-basis-angle + ;; matrix-qr.rkt + matrix-qr + ;; matrix-lu.rkt + matrix-lu + ;; matrix-gram-schmidt.rkt + matrix-gram-schmidt + matrix-basis-extension + ) diff --git a/collects/math/private/array/array-pointwise.rkt b/collects/math/private/array/array-pointwise.rkt index f624f50890..bd4b923863 100644 --- a/collects/math/private/array/array-pointwise.rkt +++ b/collects/math/private/array/array-pointwise.rkt @@ -25,8 +25,9 @@ (define-syntax-rule (define-array-op name op) (define-syntax-rule (name arrs (... ...)) (array-map op arrs (... ...)))) -(define-syntax-rule (array-scale arr x) - (inline-array-map (λ (y) (* x y)) arr)) +(define-syntax-rule (array-scale arr x-expr) + (let ([x x-expr]) + (inline-array-map (λ (y) (* x y)) arr))) (define-array-op1 array-sqr sqr) (define-array-op1 array-sqrt sqrt) diff --git a/collects/math/private/matrix/matrix-2d.rkt b/collects/math/private/matrix/matrix-2d.rkt index 9f30616bcd..da16959d37 100644 --- a/collects/math/private/matrix/matrix-2d.rkt +++ b/collects/math/private/matrix/matrix-2d.rkt @@ -1,7 +1,6 @@ #lang typed/racket/base -(require math/array - "matrix-types.rkt" +(require "matrix-types.rkt" "matrix-syntax.rkt") (provide matrix-2d-rotation diff --git a/collects/math/private/matrix/matrix-arithmetic.rkt b/collects/math/private/matrix/matrix-arithmetic.rkt index cf65194af2..fb4027bfe1 100644 --- a/collects/math/private/matrix/matrix-arithmetic.rkt +++ b/collects/math/private/matrix/matrix-arithmetic.rkt @@ -1,13 +1,52 @@ #lang racket/base +(module untyped-arithmetic-defs typed/racket/base + (require "matrix-types.rkt" + (prefix-in typed: "typed-matrix-arithmetic.rkt")) + + (provide (all-defined-out)) + + (: matrix* ((Matrix Number) (Matrix Number) * -> (Matrix Number))) + (define matrix* typed:matrix*) + + (: matrix+ ((Matrix Number) (Matrix Number) * -> (Matrix Number))) + (define matrix+ typed:matrix+) + + (: matrix- ((Matrix Number) (Matrix Number) * -> (Matrix Number))) + (define matrix- typed:matrix-) + + (: matrix-scale ((Matrix Number) Number -> (Matrix Number))) + (define matrix-scale typed:matrix-scale) + + (: matrix-sum ((Listof (Matrix Number)) -> (Matrix Number))) + (define matrix-sum typed:matrix-sum) + + ) ; module untyped-arithmetic-defs + +(module arithmetic-defs racket/base + (require typed/untyped-utils + (prefix-in typed: "typed-matrix-arithmetic.rkt") + (prefix-in untyped: (submod ".." untyped-arithmetic-defs)) + (rename-in "untyped-matrix-arithmetic.rkt" + [matrix-map untyped:matrix-map])) + + (provide (all-defined-out)) + + (define-typed/untyped-identifier matrix-map typed:matrix-map untyped:matrix-map) + (define-typed/untyped-identifier matrix* typed:matrix* untyped:matrix*) + (define-typed/untyped-identifier matrix+ typed:matrix+ untyped:matrix+) + (define-typed/untyped-identifier matrix- typed:matrix- untyped:matrix-) + (define-typed/untyped-identifier matrix-scale typed:matrix-scale untyped:matrix-scale) + (define-typed/untyped-identifier matrix-sum typed:matrix-sum untyped:matrix-sum) + + ) ; module arithmetic-defs + (require (for-syntax racket/base) typed/untyped-utils (prefix-in typed: "typed-matrix-arithmetic.rkt") - (rename-in "untyped-matrix-arithmetic.rkt" - [matrix-map untyped:matrix-map])) - -(define-typed/untyped-identifier matrix-map - typed:matrix-map untyped:matrix-map) + (prefix-in fun: (submod "." arithmetic-defs)) + (except-in "untyped-matrix-arithmetic.rkt" matrix-map) + ) (define-syntax (define/inline-macro stx) (syntax-case stx () @@ -19,17 +58,16 @@ [(_ . es) (syntax/loc inner-stx (typed:fun . es))] [_ (syntax/loc inner-stx typed:fun)])))])) -(define/inline-macro matrix* (a as ...) inline-matrix* typed:matrix*) -(define/inline-macro matrix+ (a as ...) inline-matrix+ typed:matrix+) -(define/inline-macro matrix- (a as ...) inline-matrix- typed:matrix-) -(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale) - -(define/inline-macro do-matrix-map (f a as ...) inline-matrix-map matrix-map) +(define/inline-macro matrix-map (f a as ...) inline-matrix-map fun:matrix-map) +(define/inline-macro matrix* (a as ...) inline-matrix* fun:matrix*) +(define/inline-macro matrix+ (a as ...) inline-matrix+ fun:matrix+) +(define/inline-macro matrix- (a as ...) inline-matrix- fun:matrix-) +(define/inline-macro matrix-scale (a x) inline-matrix-scale fun:matrix-scale) (provide - (rename-out [do-matrix-map matrix-map] - [typed:matrix= matrix=] - [typed:matrix-sum matrix-sum]) + (rename-out [typed:matrix= matrix=] + [fun:matrix-sum matrix-sum]) + matrix-map matrix* matrix+ matrix- diff --git a/collects/math/private/matrix/matrix-basic.rkt b/collects/math/private/matrix/matrix-basic.rkt index b84ad8d55d..3fee71a986 100644 --- a/collects/math/private/matrix/matrix-basic.rkt +++ b/collects/math/private/matrix/matrix-basic.rkt @@ -1,41 +1,66 @@ -#lang typed/racket +#lang typed/racket/base (require racket/list racket/fixnum - math/array math/flonum + math/base "matrix-types.rkt" "matrix-arithmetic.rkt" + "matrix-constructors.rkt" + "matrix-conversion.rkt" "utils.rkt" - "../unsafe.rkt") + "../unsafe.rkt" + "../array/array-struct.rkt" + "../array/array-indexing.rkt" + "../array/array-sequence.rkt" + "../array/array-transform.rkt" + "../array/array-fold.rkt" + "../array/array-pointwise.rkt" + "../array/array-convert.rkt" + "../array/utils.rkt" + "../vector/vector-mutate.rkt") (provide ;; Extraction matrix-ref - matrix-diagonal submatrix matrix-row matrix-col matrix-rows matrix-cols - ;; Predicates - matrix-zero? + matrix-diagonal + matrix-upper-triangle + matrix-lower-triangle ;; Embiggenment matrix-augment matrix-stack - ;; Norm and inner product + ;; Inner product space + matrix-1norm + matrix-2norm + matrix-inf-norm matrix-norm matrix-dot + matrix-angle + matrix-normalize ;; Simple operators matrix-transpose matrix-conjugate matrix-hermitian - matrix-trace) + matrix-trace + ;; Row/column operators + matrix-map-rows + matrix-map-cols + matrix-normalize-rows + matrix-normalize-cols + ;; Predicates + matrix-zero? + matrix-rows-orthogonal? + matrix-cols-orthogonal?) ;; =================================================================================================== ;; Extraction -(: matrix-ref (All (A) (Array A) Integer Integer -> A)) +(: matrix-ref (All (A) (Matrix A) Integer Integer -> A)) (define (matrix-ref a i j) (define-values (m n) (matrix-shape a)) (cond [(or (i . < . 0) (i . >= . m)) @@ -45,16 +70,6 @@ [else (unsafe-array-ref a ((inst vector Index) i j))])) -(: matrix-diagonal (All (A) ((Array A) -> (Array A)))) -(define (matrix-diagonal a) - (define m (square-matrix-size a)) - (define proc (unsafe-array-proc a)) - (unsafe-build-array - ((inst vector Index) m) - (λ: ([js : Indexes]) - (define i (unsafe-vector-ref js 0)) - (proc ((inst vector Index) i i))))) - (: submatrix (All (A) (Matrix A) Slice-Spec Slice-Spec -> (Matrix A))) (define (submatrix a row-range col-range) (array-slice-ref (ensure-matrix 'submatrix a) (list row-range col-range))) @@ -89,42 +104,67 @@ (unsafe-vector-set! ij 1 0) res))])) -(: matrix-rows (All (A) (Array A) -> (Listof (Array A)))) +(: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A)))) (define (matrix-rows a) (array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)) -(: matrix-cols (All (A) (Array A) -> (Listof (Array A)))) +(: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A)))) (define (matrix-cols a) (array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)) -;; =================================================================================================== -;; Predicates +(: matrix-diagonal (All (A) ((Matrix A) -> (Array A)))) +(define (matrix-diagonal a) + (define m (square-matrix-size a)) + (define proc (unsafe-array-proc a)) + (unsafe-build-array + ((inst vector Index) m) + (λ: ([js : Indexes]) + (define i (unsafe-vector-ref js 0)) + (proc ((inst vector Index) i i))))) -(: matrix-zero? ((Array Number) -> Boolean)) -(define (matrix-zero? a) - (array-all-and (matrix-map zero? a))) +(: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) +(define (matrix-upper-triangle M) + (define-values (m n) (matrix-shape M)) + (define proc (unsafe-array-proc M)) + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx<= . j) (proc ij) 0)))) + +(: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) +(define (matrix-lower-triangle M) + (define-values (m n) (matrix-shape M)) + (define proc (unsafe-array-proc M)) + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx>= . j) (proc ij) 0)))) ;; =================================================================================================== ;; Embiggenment (this is a perfectly cromulent word) -(: matrix-augment (All (A) (Listof (Array A)) -> (Array A))) +(: matrix-augment (All (A) (Listof (Matrix A)) -> (Matrix A))) (define (matrix-augment as) (cond [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)] [else (define m (matrix-num-rows (first as))) - (cond [(andmap (λ: ([a : (Array A)]) (= m (matrix-num-rows a))) (rest as)) + (cond [(andmap (λ: ([a : (Matrix A)]) (= m (matrix-num-rows a))) (rest as)) (array-append* as 1)] [else (error 'matrix-augment "matrices must have the same number of rows; given ~a" (format-matrices/error as))])])) -(: matrix-stack (All (A) (Listof (Array A)) -> (Array A))) +(: matrix-stack (All (A) (Listof (Matrix A)) -> (Matrix A))) (define (matrix-stack as) (cond [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)] [else (define n (matrix-num-cols (first as))) - (cond [(andmap (λ: ([a : (Array A)]) (= n (matrix-num-cols a))) (rest as)) + (cond [(andmap (λ: ([a : (Matrix A)]) (= n (matrix-num-cols a))) (rest as)) (array-append* as 0)] [else (error 'matrix-stack @@ -132,81 +172,223 @@ (format-matrices/error as))])])) ;; =================================================================================================== -;; Matrix norms and Frobenius inner product +;; Inner product space (entrywise norm) -(: maximum-norm ((Array Number) -> Real)) -(define (maximum-norm a) - (array-all-max (array-magnitude a))) - -(: taxicab-norm ((Array Number) -> Real)) -(define (taxicab-norm a) +(: matrix-1norm ((Matrix Number) -> Nonnegative-Real)) +(define (matrix-1norm a) (array-all-sum (array-magnitude a))) -(: frobenius-norm ((Array Number) -> Real)) -(define (frobenius-norm a) +(: matrix-2norm ((Matrix Number) -> Nonnegative-Real)) +(define (matrix-2norm a) (let ([a (array-strict (array-magnitude a))]) ;; Compute this divided by the maximum to avoid underflow and overflow (define mx (array-all-max a)) (cond [(and (rational? mx) (positive? mx)) - (assert - (* mx (sqrt (array-all-sum (inline-array-map (λ: ([x : Number]) (sqr (/ x mx))) a)))) - real?)] + (* mx (sqrt (array-all-sum + (inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))] [else mx]))) -(: p-norm ((Array Number) Positive-Real -> Real)) -(define (p-norm a p) +(: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real)) +(define (matrix-inf-norm a) + (array-all-max (array-magnitude a))) + +(: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real)) +(define (matrix-p-norm a p) (let ([a (array-strict (array-magnitude a))]) ;; Compute this divided by the maximum to avoid underflow and overflow (define mx (array-all-max a)) (cond [(and (rational? mx) (positive? mx)) (assert - (* mx (expt (array-all-sum (inline-array-map (λ: ([x : Real]) (expt (/ x mx) p)) a)) + (* mx (expt (array-all-sum + (inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a)) (/ p))) - real?)] + (make-predicate Nonnegative-Real))] [else mx]))) -(: matrix-norm (case-> ((Array Number) -> Real) - ((Array Number) Real -> Real))) +(: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real) + ((Matrix Number) Real -> Nonnegative-Real))) ;; Computes the p norm of a matrix (define (matrix-norm a [p 2]) (cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)] - [(p . = . 2) (frobenius-norm a)] - [(p . = . +inf.0) (maximum-norm a)] - [(p . = . 1) (taxicab-norm a)] - [(p . > . 1) (p-norm a p)] + [(p . = . 1) (matrix-1norm a)] + [(p . = . 2) (matrix-2norm a)] + [(p . = . +inf.0) (matrix-inf-norm a)] + [(p . > . 1) (matrix-p-norm a p)] [else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)])) -(: matrix-dot (case-> ((Array Real) (Array Real) -> Real) - ((Array Number) (Array Number) -> Number))) -;; Computes the Frobenius inner product of two matrices -(define (matrix-dot a b) - (define-values (m n) (matrix-shapes 'matrix-dot a b)) - (define aproc (unsafe-array-proc a)) - (define bproc (unsafe-array-proc b)) - (array-all-sum - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) - (* (aproc js) (conjugate (bproc js))))))) +(: matrix-dot (case-> ((Matrix Real) -> Nonnegative-Real) + ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Number) -> Nonnegative-Real) + ((Matrix Number) (Matrix Number) -> Number))) +;; Computes the Frobenius inner product of a matrix with itself or of two matrices +(define matrix-dot + (case-lambda + [(a) + (assert + (array-all-sum + (inline-array-map + (λ (x) (* x (conjugate x))) + (ensure-matrix 'matrix-dot a))) + (make-predicate Nonnegative-Real))] + [(a b) + (define-values (m n) (matrix-shapes 'matrix-dot a b)) + (define aproc (unsafe-array-proc a)) + (define bproc (unsafe-array-proc b)) + (array-all-sum + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (* (aproc js) (conjugate (bproc js))))))])) + +(: matrix-angle (case-> ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Number) (Matrix Number) -> Number))) +(define (matrix-angle M N) + (acos (/ (matrix-dot M N) (* (matrix-2norm M) (matrix-2norm N))))) + +(: matrix-normalize + (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) Real -> (Matrix Real)) + ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) +(define matrix-normalize + (case-lambda + [(M) (matrix-normalize M 2)] + [(M p) (matrix-normalize M p (λ () (raise-argument-error + 'matrix-normalize "nonzero matrix?" 0 M p)))] + [(M p fail) + (array-strict! M) + (define x (matrix-norm M p)) + (cond [(and (zero? x) (exact? x)) (fail)] + [else (matrix-scale M (/ x))])])) ;; =================================================================================================== ;; Operators -(: matrix-transpose (All (A) (Array A) -> (Array A))) +(: matrix-transpose (All (A) (Matrix A) -> (Matrix A))) (define (matrix-transpose a) (array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1)) -(: matrix-conjugate (case-> ((Array Real) -> (Array Real)) - ((Array Number) -> (Array Number)))) +(: matrix-conjugate (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Number) -> (Matrix Number)))) (define (matrix-conjugate a) (array-conjugate (ensure-matrix 'matrix-conjugate a))) -(: matrix-hermitian (case-> ((Array Real) -> (Array Real)) - ((Array Number) -> (Array Number)))) +(: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Number) -> (Matrix Number)))) (define (matrix-hermitian a) (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)) -(: matrix-trace (case-> ((Array Real) -> Real) - ((Array Number) -> Number))) +(: matrix-trace (case-> ((Matrix Real) -> Real) + ((Matrix Number) -> Number))) (define (matrix-trace a) (array-all-sum (matrix-diagonal a))) + +;; =================================================================================================== +;; Row/column operations + +(: matrix-map-rows + (All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B)) + (((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F) + -> (U F (Matrix B)))))) +(define matrix-map-rows + (case-lambda + [(f M) (matrix-stack (map f (matrix-rows M)))] + [(f M fail) + (define ms (matrix-rows M)) + (define n (f (first ms))) + (cond [n (let loop ([ms (rest ms)] [ns (list n)]) + (cond [(empty? ms) (matrix-stack (reverse ns))] + [else (define n (f (first ms))) + (cond [n (loop (rest ms) (cons n ns))] + [else (fail)])]))] + [else (fail)])])) + +(: matrix-map-cols + (All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B)) + (((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F) + -> (U F (Matrix B)))))) +(define matrix-map-cols + (case-lambda + [(f M) (matrix-augment (map f (matrix-cols M)))] + [(f M fail) + (define ms (matrix-cols M)) + (define n (f (first ms))) + (cond [n (let loop ([ms (rest ms)] [ns (list n)]) + (cond [(empty? ms) (matrix-augment (reverse ns))] + [else (define n (f (first ms))) + (cond [n (loop (rest ms) (cons n ns))] + [else (fail)])]))] + [else (fail)])])) + +(: make-matrix-normalize (Real -> (case-> ((Matrix Real) -> (U #f (Matrix Real))) + ((Matrix Number) -> (U #f (Matrix Number)))))) +(define ((make-matrix-normalize p) M) + (matrix-normalize M p (λ () #f))) + +(: matrix-normalize-rows + (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) Real -> (Matrix Real)) + ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) +(define matrix-normalize-rows + (case-lambda + [(M) (matrix-normalize-rows M 2)] + [(M p) + (define (fail) (raise-argument-error 'matrix-normalize-rows "matrix? with nonzero rows" 0 M p)) + (matrix-normalize-rows M p fail)] + [(M p fail) + (matrix-map-rows (make-matrix-normalize p) M fail)])) + +(: matrix-normalize-cols + (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) Real -> (Matrix Real)) + ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Real -> (Matrix Number)) + ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) +(define matrix-normalize-cols + (case-lambda + [(M) (matrix-normalize-cols M 2)] + [(M p) + (define (fail) + (raise-argument-error 'matrix-normalize-cols "matrix? with nonzero columns" 0 M p)) + (matrix-normalize-cols M p fail)] + [(M p fail) + (matrix-map-cols (make-matrix-normalize p) M fail)])) + +;; =================================================================================================== +;; Robust predicates using entrywise norms + +(: matrix-zero? (case-> ((Matrix Number) -> Boolean) + ((Matrix Number) Real -> Boolean))) +(define (matrix-zero? M [eps 0]) + (cond [(negative? eps) (raise-argument-error 'matrix-zero? "Nonnegative-Real" 1 M eps)] + [else (<= (matrix-norm M +inf.0) eps)])) + +(: rows-orthogonal? ((Matrix Number) Nonnegative-Real -> Boolean)) +(define (rows-orthogonal? M eps) + (define rows (matrix->vector* M)) + (define m (vector-length rows)) + (let/ec: return : Boolean + (for*: ([i0 (in-range m)] [i1 (in-range (fx+ i0 1) m)]) + (define r0 (unsafe-vector-ref rows i0)) + (define r1 (unsafe-vector-ref rows i1)) + (when ((sqrt (magnitude (vector-dot r0 r1))) . >= . eps) (return #f))) + #t)) + +(: matrix-rows-orthogonal? (case-> ((Matrix Number) -> Boolean) + ((Matrix Number) Real -> Boolean))) +(define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)]) + (cond [(negative? eps) (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)] + [else (rows-orthogonal? M eps)])) + + +(: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean) + ((Matrix Number) Real -> Boolean))) +(define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)]) + (cond [(negative? eps) (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)] + [else (rows-orthogonal? (matrix-transpose M) eps)])) diff --git a/collects/math/private/matrix/matrix-basis.rkt b/collects/math/private/matrix/matrix-basis.rkt deleted file mode 100644 index 75bc4288ea..0000000000 --- a/collects/math/private/matrix/matrix-basis.rkt +++ /dev/null @@ -1,112 +0,0 @@ -#lang typed/racket - -(require racket/fixnum - math/array - math/matrix - "matrix-column.rkt" - "utils.rkt" - "../unsafe.rkt" - "../vector/vector-mutate.rkt" - ) - -(: col-matrix-project1 (case-> ((Matrix Real) (Matrix Real) Any -> (U #f (Matrix Real))) - ((Matrix Number) (Matrix Number) Any -> (U #f (Matrix Number))))) -(define (col-matrix-project1 v b unit?) - (cond [unit? (matrix-scale b (matrix-dot v b))] - [else (define b.b (matrix-dot b b)) - (cond [(and (zero? b.b) (exact? b.b)) #f] - [else (matrix-scale b (/ (matrix-dot v b) b.b))])])) - -(: col-matrix-project - (All (A) (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Real) (Matrix Real) Any -> (U A (Matrix Real))) - ((Matrix Real) (Matrix Real) Any (-> A) -> (U A (Matrix Real))) - ((Matrix Number) (Matrix Number) -> (Matrix Number)) - ((Matrix Number) (Matrix Number) Any -> (U A (Matrix Number))) - ((Matrix Number) (Matrix Number) Any (-> A) -> (U A (Matrix Number)))))) -(define col-matrix-project - (case-lambda - [(v B) (col-matrix-project v B #f)] - [(v B unit?) - (col-matrix-project - v B unit? - (λ () (error 'col-matrix-project "expected basis with nonzero column vectors; given ~e" B)))] - [(v B unit? fail) - (unless (col-matrix? v) (raise-argument-error 'col-matrix-project "col-matrix?" v)) - (define bs (matrix-cols (ensure-matrix 'col-matrix-project B))) - (define p (col-matrix-project1 v (first bs) unit?)) - (cond [p (let loop ([bs (rest bs)] [p p]) - (cond [(empty? bs) p] - [else (define q (col-matrix-project1 v (first bs) unit?)) - (if q (loop (rest bs) (matrix+ p q)) (fail))]))] - [else (fail)])])) - -(: find-nonzero-vector (case-> ((Vectorof (Vectorof Real)) -> (U #f Index)) - ((Vectorof (Vectorof Number)) -> (U #f Index)))) -(define (find-nonzero-vector vss) - (define n (vector-length vss)) - (cond [(= n 0) #f] - [else (let loop ([#{i : Nonnegative-Fixnum} 0]) - (cond [(i . fx< . n) - (define vs (unsafe-vector-ref vss i)) - (if (vector-zero? vs) (loop (fx+ i 1)) i)] - [else #f]))])) - -(: subtract-projections! - (case-> ((Vectorof (Vectorof Real)) Index Index (Vectorof Real) Any -> Void) - ((Vectorof (Vectorof Number)) Index Index (Vectorof Number) Any -> Void))) -(define (subtract-projections! cols n i ci unit?) - (let j-loop ([#{j : Nonnegative-Fixnum} (fx+ i 1)]) - (when (j . fx< . n) - (vector-sub-proj! (unsafe-vector-ref cols j) ci unit?) - (j-loop (fx+ j 1))))) - -(: matrix-gram-schmidt (All (A) (case-> ((Matrix Real) -> (Array Real)) - ((Matrix Real) Any -> (Array Real)) - ((Matrix Number) -> (Array Number)) - ((Matrix Number) Any -> (Array Number))))) -(define (matrix-gram-schmidt M [unit? #f]) - (define rows (matrix->vector* M)) - (define n (vector-length rows)) - (define i (find-nonzero-vector rows)) - (cond [i (define rowi (unsafe-vector-ref rows i)) - (subtract-projections! rows n i rowi #f) - (when unit? (vector-normalize! rowi)) - (let loop ([#{i : Nonnegative-Fixnum} (fx+ i 1)] [bs (list rowi)]) - (cond [(i . fx< . n) - (define rowi (unsafe-vector-ref rows i)) - (cond [(vector-zero? rowi) (loop (fx+ i 1) bs)] - [else (subtract-projections! rows n i rowi #f) - (when unit? (vector-normalize! rowi)) - (loop (fx+ i 1) (cons rowi bs))])] - [else - (vector*->matrix (list->vector (reverse bs)))]))] - [else - (make-array (vector 0 (matrix-num-cols M)) 0)])) -#| -(define a (col-matrix [1 2 1])) -(define b (col-matrix [1 -2 2])) - -(define basis - (gram-schmidt-orthogonal - (matrix-cols - (array #[#[2 1 0] #[2 2 1] #[0 2 0]])))) - -(column-project a b) -(col-matrix-project a b) - -(projection-on-orthogonal-basis a basis) -(col-matrix-project a (matrix-augment basis)) -(projection-on-orthonormal-basis a basis) -(col-matrix-project a (matrix-augment basis) 'orthonormal) - -(matrix-gram-schmidt - (matrix [[0 1 2] - [0 2 3] - [0 1 5]])) - -(matrix-gram-schmidt - (matrix [[5 1 2] - [2 2 3] - [-3 1 5]])) -|# diff --git a/collects/math/private/matrix/matrix-column.rkt b/collects/math/private/matrix/matrix-column.rkt deleted file mode 100644 index 9f452ac512..0000000000 --- a/collects/math/private/matrix/matrix-column.rkt +++ /dev/null @@ -1,126 +0,0 @@ -#lang typed/racket/base - -(require math/array - math/base - "matrix-types.rkt" - "matrix-conversion.rkt" - "matrix-arithmetic.rkt" - "../unsafe.rkt") - -(provide unit-column - column-height - unsafe-column->vector - column-scale - column+ - column-dot - column-norm - column-project - column-project/unit - column-normalize) - -(: unit-column : Integer Integer -> (Result-Column Number)) -(define (unit-column m i) - (cond - [(and (index? m) (index? i)) - (define v (make-vector m 0)) - (if (< i m) - (vector-set! v i 1) - (error 'unit-vector "dimension must be largest")) - (vector->matrix m 1 v)] - [else - (error 'unit-vector "expected two indices")])) - -(: column-height : (Column Number) -> Index) -(define (column-height v) - (if (vector? v) - (vector-length v) - (matrix-num-rows v))) - -(: unsafe-column->vector : (Column Number) -> (Vectorof Number)) -(define (unsafe-column->vector v) - (if (vector? v) v - (let () - (define-values (m n) (matrix-shape v)) - (if (= n 1) - (mutable-array-data (array->mutable-array v)) - (error 'unsafe-column->vector - "expected a column (vector or mx1 matrix), got ~a" v))))) - -(: column-scale : (Column Number) Number -> (Result-Column Number)) -(define (column-scale a s) - (if (vector? a) - (let*: ([n (vector-length a)] - [v : (Vectorof Number) (make-vector n 0)]) - (for: ([i (in-range 0 n)] - [x : Number (in-vector a)]) - (vector-set! v i (* s x))) - (->col-matrix v)) - (matrix-scale a s))) - -(: column+ : (Column Number) (Column Number) -> (Result-Column Number)) -(define (column+ v w) - (cond [(and (vector? v) (vector? w)) - (let ([n (vector-length v)] - [m (vector-length w)]) - (unless (= m n) - (error 'column+ - "expected two column vectors of the same length, got ~a and ~a" v w)) - (define: v+w : (Vectorof Number) (make-vector n 0)) - (for: ([i (in-range 0 n)] - [x : Number (in-vector v)] - [y : Number (in-vector w)]) - (vector-set! v+w i (+ x y))) - (->col-matrix v+w))] - [else - (unless (= (column-height v) (column-height w)) - (error 'column+ - "expected two column vectors of the same length, got ~a and ~a" v w)) - (array+ (->col-matrix v) (->col-matrix w))])) - -(: column-dot : (Column Number) (Column Number) -> Number) -(define (column-dot c d) - (define v (unsafe-column->vector c)) - (define w (unsafe-column->vector d)) - (define m (column-height v)) - (define s (column-height w)) - (cond - [(not (= m s)) (error 'column-dot - "expected two mx1 matrices with same number of rows, got ~a and ~a" - c d)] - [else - (for/sum: : Number ([i (in-range 0 m)]) - (assert i index?) - ; Note: If d is a vector of reals, - ; then the conjugate is a no-op - (* (unsafe-vector-ref v i) - (conjugate (unsafe-vector-ref w i))))])) - -(: column-norm : (Column Number) -> Real) -(define (column-norm v) - (define norm (sqrt (column-dot v v))) - (assert norm real?)) - -(: column-project : (Column Number) (Column Number) -> (Result-Column Number)) -; (column-project v w) -; Return the projection og vector v on vector w. -(define (column-project v w) - (let ([w.w (column-dot w w)]) - (if (zero? w.w) - (error 'column-project "projection on the zero vector not defined") - (matrix-scale (->col-matrix w) (/ (column-dot v w) w.w))))) - -(: column-project/unit : (Column Number) (Column Number) -> (Result-Column Number)) -; (column-project-on-unit v w) -; Return the projection og vector v on a unit vector w. -(define (column-project/unit v w) - (matrix-scale (->col-matrix w) (column-dot v w))) - -(: column-normalize : (Column Number) -> (Result-Column Number)) -; (column-vector-normalize v) -; Return unit vector with same direction as v. -; If v is the zero vector, the zero vector is returned. -(define (column-normalize w) - (let ([norm (column-norm w)] - [w (->col-matrix w)]) - (cond [(zero? norm) w] - [else (matrix-scale w (/ norm))]))) diff --git a/collects/math/private/matrix/matrix-comprehension.rkt b/collects/math/private/matrix/matrix-comprehension.rkt index ed07e4d971..f2a61ee90e 100644 --- a/collects/math/private/matrix/matrix-comprehension.rkt +++ b/collects/math/private/matrix/matrix-comprehension.rkt @@ -2,7 +2,7 @@ (require (for-syntax racket/base syntax/parse) - math/array) + "../array/array-comprehension.rkt") (provide for/matrix: for*/matrix: @@ -12,7 +12,7 @@ (module typed-defs typed/racket/base (require (for-syntax racket/base syntax/parse) - math/array) + "../array/array-comprehension.rkt") (provide (all-defined-out)) diff --git a/collects/math/private/matrix/matrix-constructors.rkt b/collects/math/private/matrix/matrix-constructors.rkt index e920990ac5..dd4f4f15e9 100644 --- a/collects/math/private/matrix/matrix-constructors.rkt +++ b/collects/math/private/matrix/matrix-constructors.rkt @@ -3,9 +3,12 @@ (require racket/fixnum racket/list racket/vector - math/array "matrix-types.rkt" - "../unsafe.rkt") + "../unsafe.rkt" + "../array/array-struct.rkt" + "../array/array-constructors.rkt" + "../array/array-unfold.rkt" + "../array/utils.rkt") (provide identity-matrix make-matrix @@ -42,7 +45,7 @@ ;; =================================================================================================== ;; Diagonal matrices -(: diagonal-matrix/zero (All (A) (Listof A) A -> (Array A))) +(: diagonal-matrix/zero (All (A) ((Listof A) A -> (Matrix A)))) (define (diagonal-matrix/zero xs zero) (cond [(empty? xs) (raise-argument-error 'diagonal-matrix "nonempty List" xs)] @@ -56,15 +59,14 @@ (cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)] [else zero])))])) -(: diagonal-matrix (case-> ((Listof Real) -> (Array Real)) - ((Listof Number) -> (Array Number)))) +(: diagonal-matrix (All (A) ((Listof A) -> (Matrix (U A 0))))) (define (diagonal-matrix xs) (diagonal-matrix/zero xs 0)) ;; =================================================================================================== ;; Block diagonal matrices -(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A))) +(: block-diagonal-matrix/zero* (All (A) (Vectorof (Matrix A)) A -> (Matrix A))) (define (block-diagonal-matrix/zero* as zero) (define num (vector-length as)) (define-values (ms ns) @@ -94,7 +96,7 @@ (vector-set! hs (unsafe-fx+ res-j j) k) (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) - (define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as)) + (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as)) (unsafe-build-array ((inst vector Index) res-m res-n) (λ: ([ij : Indexes]) @@ -114,7 +116,7 @@ [else zero])))) -(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A))) +(: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A)))) (define (block-diagonal-matrix/zero as zero) (let ([as (list->vector as)]) (define num (vector-length as)) @@ -125,8 +127,7 @@ [else (block-diagonal-matrix/zero* as zero)]))) -(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real)) - ((Listof (Array Number)) -> (Array Number)))) +(: block-diagonal-matrix (All (A) ((Listof (Matrix A)) -> (Matrix (U A 0))))) (define (block-diagonal-matrix as) (block-diagonal-matrix/zero as 0)) @@ -140,8 +141,8 @@ (cond [(real? x) (assert (expt x n) real?)] [else (expt x n)])) -(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real)) - ((Listof Number) Integer -> (Array Number)))) +(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Matrix Real)) + ((Listof Number) Integer -> (Matrix Number)))) (define (vandermonde-matrix xs n) (cond [(empty? xs) (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)] diff --git a/collects/math/private/matrix/matrix-conversion.rkt b/collects/math/private/matrix/matrix-conversion.rkt index 57c08c7281..a65ff6aaa8 100644 --- a/collects/math/private/matrix/matrix-conversion.rkt +++ b/collects/math/private/matrix/matrix-conversion.rkt @@ -3,9 +3,13 @@ (require racket/fixnum racket/list racket/vector - math/array "matrix-types.rkt" "utils.rkt" + "../array/array-struct.rkt" + "../array/array-convert.rkt" + "../array/array-transform.rkt" + "../array/mutable-array.rkt" + "../array/array-fold.rkt" "../array/utils.rkt" "../unsafe.rkt") @@ -24,9 +28,9 @@ matrix->vector*) ;; =================================================================================================== -;; Flat conversion +;; Flat conversion to rectangular matrices -(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A)))) +(: list->matrix (All (A) (Integer Integer (Listof A) -> (Matrix A)))) (define (list->matrix m n xs) (cond [(or (not (index? m)) (= m 0)) (raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)] @@ -34,7 +38,7 @@ (raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)] [else (list->array (vector m n) xs)])) -(: matrix->list (All (A) ((Array A) -> (Listof A)))) +(: matrix->list (All (A) ((Matrix A) -> (Listof A)))) (define (matrix->list a) (array->list (ensure-matrix 'matrix->list a))) @@ -46,26 +50,18 @@ (raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)] [else (vector->array (vector m n) v)])) -(: matrix->vector (All (A) ((Array A) -> (Vectorof A)))) +(: matrix->vector (All (A) ((Matrix A) -> (Vectorof A)))) (define (matrix->vector a) (array->vector (ensure-matrix 'matrix->vector a))) -(: list->row-matrix (All (A) ((Listof A) -> (Array A)))) -(define (list->row-matrix xs) - (cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)] - [else (list->array ((inst vector Index) 1 (length xs)) xs)])) +;; =================================================================================================== +;; Flat conversion to column and row matrices -(: list->col-matrix (All (A) ((Listof A) -> (Array A)))) +(: list->col-matrix (All (A) ((Listof A) -> (Matrix A)))) (define (list->col-matrix xs) (cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)] [else (list->array ((inst vector Index) (length xs) 1) xs)])) -(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) -(define (vector->row-matrix xs) - (define n (vector-length xs)) - (cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)] - [else (vector->array ((inst vector Index) 1 n) xs)])) - (: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) (define (vector->col-matrix xs) (define n (vector-length xs)) @@ -80,40 +76,15 @@ (if (dk . > . 1) (values k dk) (loop (fx+ k 1)))] [else (values 0 0)]))) -(: array->row-matrix (All (A) ((Array A) -> (Array A)))) -(define (array->row-matrix arr) - (define (fail) - (raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr)) - (define ds (array-shape arr)) - (define dims (vector-length ds)) - (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) - (cond [(zero? (array-size arr)) (fail)] - [(row-matrix? arr) arr] - [(= num-ones dims) - (define: js : (Vectorof Index) (make-vector dims 0)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 1) - (λ: ([ij : Indexes]) (proc js)))] - [(= num-ones (- dims 1)) - (define-values (k n) (find-nontrivial-axis ds)) - (define js (make-thread-local-indexes dims)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 n) - (λ: ([ij : Indexes]) - (let ([js (js)]) - (unsafe-vector-set! js k (unsafe-vector-ref ij 1)) - (proc js))))] - [else (fail)])) - -(: array->col-matrix (All (A) ((Array A) -> (Array A)))) +(: array->col-matrix (All (A) ((Array A) -> (Matrix A)))) (define (array->col-matrix arr) (define (fail) - (raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr)) + (raise-argument-error 'array->col-matrix + "nonempty Array with exactly one axis of length >= 1" arr)) (define ds (array-shape arr)) (define dims (vector-length ds)) (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) (cond [(zero? (array-size arr)) (fail)] - [(col-matrix? arr) arr] [(= num-ones dims) (define: js : (Vectorof Index) (make-vector dims 0)) (define proc (unsafe-array-proc arr)) @@ -130,17 +101,19 @@ (proc js))))] [else (fail)])) -(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) -(define (->row-matrix xs) - (cond [(list? xs) (list->row-matrix xs)] - [(array? xs) (array->row-matrix xs)] - [else (vector->row-matrix xs)])) - -(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) +(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A)))) (define (->col-matrix xs) (cond [(list? xs) (list->col-matrix xs)] - [(array? xs) (array->col-matrix xs)] - [else (vector->col-matrix xs)])) + [(vector? xs) (vector->col-matrix xs)] + [(col-matrix? xs) xs] + [else (array->col-matrix xs)])) + +(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A)))) +(define (->row-matrix xs) + (cond [(list? xs) (array-axis-swap (list->col-matrix xs) 0 1)] + [(vector? xs) (array-axis-swap (vector->col-matrix xs) 0 1)] + [(row-matrix? xs) xs] + [else (array-axis-swap (array->col-matrix xs) 0 1)])) ;; =================================================================================================== ;; Nested conversion diff --git a/collects/math/private/matrix/matrix-expt.rkt b/collects/math/private/matrix/matrix-expt.rkt index b31391fad8..a3ed5752f6 100644 --- a/collects/math/private/matrix/matrix-expt.rkt +++ b/collects/math/private/matrix/matrix-expt.rkt @@ -1,7 +1,6 @@ -#lang typed/racket +#lang typed/racket/base -(require math/array - "matrix-types.rkt" +(require "matrix-types.rkt" "matrix-constructors.rkt" "matrix-arithmetic.rkt") diff --git a/collects/math/private/matrix/matrix-gauss-elim.rkt b/collects/math/private/matrix/matrix-gauss-elim.rkt new file mode 100644 index 0000000000..b13e73d6c2 --- /dev/null +++ b/collects/math/private/matrix/matrix-gauss-elim.rkt @@ -0,0 +1,62 @@ +#lang typed/racket/base + +(require racket/fixnum + racket/list + "matrix-types.rkt" + "matrix-conversion.rkt" + "utils.rkt" + "../unsafe.rkt" + "../vector/vector-mutate.rkt") + +(provide + matrix-gauss-elim + matrix-row-echelon) + +(: matrix-gauss-elim + (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index))) + ((Matrix Real) Any -> (Values (Matrix Real) (Listof Index))) + ((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index))) + ((Matrix Number) -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Any -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index))))) +(define (matrix-gauss-elim M [jordan? #f] [unitize-pivot? #f]) + (define-values (m n) (matrix-shape M)) + (define rows (matrix->vector* M)) + (let loop ([#{i : Nonnegative-Fixnum} 0] + [#{j : Nonnegative-Fixnum} 0] + [#{without-pivot : (Listof Index)} empty]) + (cond + [(j . fx>= . n) + (values (vector*->matrix rows) + (reverse without-pivot))] + [(i . fx>= . m) + (values (vector*->matrix rows) + ;; None of the rest of the columns can have pivots + (let loop ([#{j : Nonnegative-Fixnum} j] [without-pivot without-pivot]) + (cond [(j . fx< . n) (loop (fx+ j 1) (cons j without-pivot))] + [else (reverse without-pivot)])))] + [else + (define-values (p pivot) (find-partial-pivot rows m i j)) + (cond + [(zero? pivot) (loop i (fx+ j 1) (cons j without-pivot))] + [else + ;; Swap pivot row with current + (vector-swap! rows i p) + ;; Possibly unitize the new current row + (let ([pivot (if unitize-pivot? + (begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot)) + 1) + pivot)]) + (elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1))) + (loop (fx+ i 1) (fx+ j 1) without-pivot))])]))) + +(: matrix-row-echelon + (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) Any -> (Matrix Real)) + ((Matrix Real) Any Any -> (Matrix Real)) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) Any -> (Matrix Number)) + ((Matrix Number) Any Any -> (Matrix Number)))) +(define (matrix-row-echelon M [jordan? #f] [unitize-pivot? jordan?]) + (let-values ([(M _) (matrix-gauss-elim M jordan? unitize-pivot?)]) + M)) diff --git a/collects/math/private/matrix/matrix-gram-schmidt.rkt b/collects/math/private/matrix/matrix-gram-schmidt.rkt new file mode 100644 index 0000000000..741a8161de --- /dev/null +++ b/collects/math/private/matrix/matrix-gram-schmidt.rkt @@ -0,0 +1,80 @@ +#lang typed/racket/base + +(require racket/fixnum + racket/list + "matrix-types.rkt" + "matrix-basic.rkt" + "matrix-conversion.rkt" + "matrix-constructors.rkt" + "utils.rkt" + "../unsafe.rkt" + "../vector/vector-mutate.rkt" + "../array/array-struct.rkt" + "../array/array-constructors.rkt" + "../array/array-indexing.rkt") + +(provide matrix-gram-schmidt + matrix-basis-extension) + +(: find-nonzero-vector (case-> ((Vectorof (Vectorof Real)) -> (U #f Index)) + ((Vectorof (Vectorof Number)) -> (U #f Index)))) +(define (find-nonzero-vector vss) + (define n (vector-length vss)) + (cond [(= n 0) #f] + [else (let loop ([#{i : Nonnegative-Fixnum} 0]) + (cond [(i . fx< . n) + (define vs (unsafe-vector-ref vss i)) + (if (vector-zero? vs) (loop (fx+ i 1)) i)] + [else #f]))])) + +(: subtract-projections! + (case-> ((Vectorof (Vectorof Real)) Nonnegative-Fixnum Index (Vectorof Real) -> Void) + ((Vectorof (Vectorof Number)) Nonnegative-Fixnum Index (Vectorof Number) -> Void))) +(define (subtract-projections! rows i m row) + (let loop ([#{i : Nonnegative-Fixnum} i]) + (when (i . fx< . m) + (vector-sub-proj! (unsafe-vector-ref rows i) row #f) + (loop (fx+ i 1))))) + +(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real)) + ((Matrix Real) Any -> (Array Real)) + ((Matrix Real) Any Integer -> (Array Real)) + ((Matrix Number) -> (Array Number)) + ((Matrix Number) Any -> (Array Number)) + ((Matrix Number) Any Integer -> (Array Number)))) +;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already +;; orthogonal +(define (matrix-gram-schmidt M [normalize? #f] [start 0]) + (define rows (matrix->vector* (matrix-transpose M))) + (define m (vector-length rows)) + (define i (find-nonzero-vector rows)) + (cond [(not (index? start)) + (raise-argument-error 'matrix-gram-schmidt "Index" 2 M normalize? start)] + [i + (define rowi (unsafe-vector-ref rows i)) + (subtract-projections! rows (fxmax start (fx+ i 1)) m rowi) + (when normalize? (vector-normalize! rowi)) + (let loop ([#{i : Nonnegative-Fixnum} (fx+ i 1)] [bs (list rowi)]) + (cond [(i . fx< . m) + (define rowi (unsafe-vector-ref rows i)) + (cond [(vector-zero? rowi) (loop (fx+ i 1) bs)] + [else (subtract-projections! rows (fxmax start (fx+ i 1)) m rowi) + (when normalize? (vector-normalize! rowi)) + (loop (fx+ i 1) (cons rowi bs))])] + [else + (matrix-transpose (vector*->matrix (list->vector (reverse bs))))]))] + [else + (make-array (vector (matrix-num-rows M) 0) 0)])) + +(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real)) + ((Matrix Number) -> (Array Number)))) +(define (matrix-basis-extension B) + (define-values (m n) (matrix-shape B)) + (cond [(n . < . m) + (define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n)) + (define R (submatrix S (::) (:: n #f))) + (matrix-augment (take (sort/key (matrix-cols R) > matrix-norm) (- m n)))] + [(n . = . m) + (make-array (vector m 0) 0)] + [else + (raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)])) diff --git a/collects/math/private/matrix/matrix-lu.rkt b/collects/math/private/matrix/matrix-lu.rkt new file mode 100644 index 0000000000..46b5f43f5e --- /dev/null +++ b/collects/math/private/matrix/matrix-lu.rkt @@ -0,0 +1,59 @@ +#lang typed/racket/base + +(require racket/fixnum + "matrix-types.rkt" + "matrix-conversion.rkt" + "matrix-arithmetic.rkt" + "utils.rkt" + "../unsafe.rkt" + "../vector/vector-mutate.rkt" + "../array/mutable-array.rkt") + +(provide matrix-lu) + +;; An LU factorization exists iff Gaussian elimination can be done without row swaps. + +(: matrix-lu + (All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real))) + ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))) +(define matrix-lu + (case-lambda + [(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))] + [(M fail) + (define m (square-matrix-size M)) + (define rows (matrix->vector* M)) + ;; Construct L in a weird way to prove to TR that it has the right type + (define L (array->mutable-array (matrix-scale M (ann 0 Real)))) + ;; Going to fill in the lower triangle by banging values into `ys' + (define ys (mutable-array-data L)) + (let loop ([#{i : Nonnegative-Fixnum} 0]) + (cond + [(i . fx< . m) + ;; Pivot must be on the diagonal + (define pivot (unsafe-vector2d-ref rows i i)) + (cond + [(zero? pivot) (values (fail) M)] + [else + ;; Zero out everything below the pivot + (let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)]) + (cond + [(l . fx< . m) + (define x_li (unsafe-vector2d-ref rows l i)) + (define y_li (/ x_li pivot)) + (unless (zero? x_li) + ;; Fill in lower triangle of L + (unsafe-vector-set! ys (+ (* l m) i) y_li) + ;; Add row i, scaled + (vector-scaled-add! (unsafe-vector-ref rows l) + (unsafe-vector-ref rows i) + (- y_li))) + (l-loop (fx+ l 1))] + [else + (loop (fx+ i 1))]))])] + [else + ;; L's lower triangle has been filled; now fill the diagonal with 1s + (for: ([i : Integer (in-range 0 m)]) + (vector-set! ys (+ (* i m) i) 1)) + (values L (vector*->matrix rows))]))])) diff --git a/collects/math/private/matrix/matrix-operations.rkt b/collects/math/private/matrix/matrix-operations.rkt deleted file mode 100644 index a2a994e394..0000000000 --- a/collects/math/private/matrix/matrix-operations.rkt +++ /dev/null @@ -1,442 +0,0 @@ -#lang typed/racket/base - -(require racket/fixnum - racket/list - racket/match - math/array - (only-in typed/racket conjugate) - "../unsafe.rkt" - "../vector/vector-mutate.rkt" - "matrix-types.rkt" - "matrix-constructors.rkt" - "matrix-conversion.rkt" - "matrix-arithmetic.rkt" - "matrix-basic.rkt" - "matrix-column.rkt" - "utils.rkt" - (for-syntax racket)) - -(provide - ;; Gaussian elimination - matrix-gauss-elim - matrix-row-echelon - ;; Derived functions - matrix-rank - matrix-nullity - matrix-determinant - matrix-determinant/row-reduction ; for testing - ;; Spaces - matrix-column-space - ;; Solving - matrix-invertible? - matrix-inverse - matrix-solve - ;; Projection - projection-on-orthogonal-basis - projection-on-orthonormal-basis - projection-on-subspace - gram-schmidt-orthogonal - gram-schmidt-orthonormal - ;; Decomposition - matrix-lu - matrix-qr - ) - -(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A))) -(define (unsafe-vector2d-ref vss i j) - (unsafe-vector-ref (unsafe-vector-ref vss i) j)) - -;; =================================================================================================== -;; Gaussian elimination - -(: find-partial-pivot - (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) - ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number)))) -;; Find the element with maximum magnitude in a column -(define (find-partial-pivot rows m i j) - (define l (fx+ i 1)) - (define pivot (unsafe-vector2d-ref rows i j)) - (define mag-pivot (magnitude pivot)) - (let loop ([#{l : Nonnegative-Fixnum} l] [#{p : Index} i] [pivot pivot] [mag-pivot mag-pivot]) - (cond [(l . fx< . m) - (define new-pivot (unsafe-vector2d-ref rows l j)) - (define mag-new-pivot (magnitude new-pivot)) - (cond [(mag-new-pivot . > . mag-pivot) (loop (fx+ l 1) l new-pivot mag-new-pivot)] - [else (loop (fx+ l 1) p pivot mag-pivot)])] - [else (values p pivot)]))) - -(: elim-rows! - (case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void) - ((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void))) -(define (elim-rows! rows m i j pivot start) - (let loop ([#{l : Nonnegative-Fixnum} start]) - (when (l . fx< . m) - (unless (l . fx= . i) - (define x_lj (unsafe-vector2d-ref rows l j)) - (unless (zero? x_lj) - (vector-scaled-add! (unsafe-vector-ref rows l) - (unsafe-vector-ref rows i) - (- (/ x_lj pivot))))) - (loop (fx+ l 1))))) - -(: matrix-gauss-elim (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index))) - ((Matrix Real) Any -> (Values (Matrix Real) (Listof Index))) - ((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index))) - ((Matrix Number) -> (Values (Matrix Number) (Listof Index))) - ((Matrix Number) Any -> (Values (Matrix Number) (Listof Index))) - ((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index))))) -(define (matrix-gauss-elim M [jordan? #f] [unitize-pivot-row? #f]) - (define-values (m n) (matrix-shape M)) - (define rows (matrix->vector* M)) - (let loop ([#{i : Nonnegative-Fixnum} 0] - [#{j : Nonnegative-Fixnum} 0] - [#{without-pivot : (Listof Index)} empty]) - (cond - [(j . fx>= . n) - (values (vector*->matrix rows) - (reverse without-pivot))] - [(i . fx>= . m) - (values (vector*->matrix rows) - ;; None of the rest of the columns can have pivots - (let loop ([#{j : Nonnegative-Fixnum} j] [without-pivot without-pivot]) - (cond [(j . fx< . n) (loop (fx+ j 1) (cons j without-pivot))] - [else (reverse without-pivot)])))] - [else - (define-values (p pivot) (find-partial-pivot rows m i j)) - (cond - [(zero? pivot) (loop i (fx+ j 1) (cons j without-pivot))] - [else - ;; Swap pivot row with current - (vector-swap! rows i p) - ;; Possibly unitize the new current row - (let ([pivot (if unitize-pivot-row? - (begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot)) - 1) - pivot)]) - (elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1))) - (loop (fx+ i 1) (fx+ j 1) without-pivot))])]))) - -;; =================================================================================================== -;; Simple functions derived from Gaussian elimination - -(: matrix-row-echelon - (case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Real) Any -> (Matrix Real)) - ((Matrix Real) Any Any -> (Matrix Real)) - ((Matrix Number) -> (Matrix Number)) - ((Matrix Number) Any -> (Matrix Number)) - ((Matrix Number) Any Any -> (Matrix Number)))) -(define (matrix-row-echelon M [jordan? #f] [unitize-pivot-row? jordan?]) - (let-values ([(M _) (matrix-gauss-elim M jordan? unitize-pivot-row?)]) - M)) - -(: matrix-rank : (Matrix Number) -> Index) -;; Returns the dimension of the column space (equiv. row space) of M -(define (matrix-rank M) - (define n (matrix-num-cols M)) - (define-values (_ cols-without-pivot) (matrix-gauss-elim M)) - (assert (- n (length cols-without-pivot)) index?)) - -(: matrix-nullity : (Matrix Number) -> Index) -;; Returns the dimension of the null space of M -(define (matrix-nullity M) - (define-values (_ cols-without-pivot) - (matrix-gauss-elim (ensure-matrix 'matrix-nullity M))) - (length cols-without-pivot)) - -(: maybe-cons-submatrix (All (A) ((Matrix A) Nonnegative-Fixnum Nonnegative-Fixnum (Listof (Matrix A)) - -> (Listof (Matrix A))))) -(define (maybe-cons-submatrix M j0 j1 Bs) - (cond [(= j0 j1) Bs] - [else (cons (submatrix M (::) (:: j0 j1)) Bs)])) - -(: matrix-column-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Real) (-> A) -> (U A (Matrix Real))) - ((Matrix Number) -> (Matrix Number)) - ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) -(define matrix-column-space - (case-lambda - [(M) (matrix-column-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))] - [(M fail) - (define n (matrix-num-cols M)) - (define-values (_ wps) (matrix-gauss-elim M)) - (cond [(empty? wps) M] - [(= (length wps) n) (fail)] - [else - (define next-j (first wps)) - (define Bs (maybe-cons-submatrix M 0 next-j empty)) - (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) - (cond [(empty? wps) - (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] - [else - (define next-j (first wps)) - (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])])) - -;; =================================================================================================== -;; Determinant - -(: matrix-determinant (case-> ((Matrix Real) -> Real) - ((Matrix Number) -> Number))) -(define (matrix-determinant M) - (define m (square-matrix-size M)) - (cond - [(= m 1) (matrix-ref M 0 0)] - [(= m 2) (match-define (vector a b c d) - (mutable-array-data (array->mutable-array M))) - (- (* a d) (* b c))] - [(= m 3) (match-define (vector a b c d e f g h i) - (mutable-array-data (array->mutable-array M))) - (+ (* a (- (* e i) (* f h))) - (* (- b) (- (* d i) (* f g))) - (* c (- (* d h) (* e g))))] - [else - (matrix-determinant/row-reduction M)])) - -(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real) - ((Matrix Number) -> Number))) -(define (matrix-determinant/row-reduction M) - (define m (square-matrix-size M)) - (define rows (matrix->vector* M)) - (let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1]) - (cond - [(i . fx< . m) - (define-values (p pivot) (find-partial-pivot rows m i i)) - (cond - [(zero? pivot) 0] ; no pivot means non-invertible matrix - [else - (vector-swap! rows i p) ; negates determinant if i != p - (elim-rows! rows m i i pivot (fx+ i 1)) ; doesn't change the determinant - (loop (fx+ i 1) (if (= i p) sign (* -1 sign)))])] - [else - (define prod (unsafe-vector2d-ref rows 0 0)) - (let loop ([#{i : Nonnegative-Fixnum} 1] [prod prod]) - (cond [(i . fx< . m) - (loop (fx+ i 1) (* prod (unsafe-vector2d-ref rows i i)))] - [else (* prod sign)]))]))) - -;; =================================================================================================== -;; Inversion and solving linear systems - -(: matrix-invertible? ((Matrix Number) -> Boolean)) -(define (matrix-invertible? M) - (not (zero? (matrix-determinant M)))) - -(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Real) (-> A) -> (U A (Matrix Real))) - ((Matrix Number) -> (Matrix Number)) - ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) -(define matrix-inverse - (case-lambda - [(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))] - [(M fail) - (define m (square-matrix-size M)) - (define I (identity-matrix m)) - (define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t)) - (cond [(and (not (empty? wps)) (= (first wps) m)) - (submatrix IM^-1 (::) (:: m #f))] - [else (fail)])])) - -(: matrix-solve (All (A) (case-> - ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real))) - ((Matrix Number) (Matrix Number) -> (Matrix Number)) - ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))) -(define matrix-solve - (case-lambda - [(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))] - [(M B fail) - (define m (square-matrix-size M)) - (define-values (s t) (matrix-shape B)) - (cond [(= m s) - (define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t)) - (cond [(and (not (empty? wps)) (= (first wps) m)) - (submatrix IX (::) (:: m #f))] - [else (fail)])] - [else - (error 'matrix-solve - "matrices must have the same number of rows; given ~e and ~e" - M B)])])) - -;; =================================================================================================== -;; LU Factorization - -;; An LU factorization exists iff Gaussian elimination can be done without row swaps. - -(: matrix-lu - (All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) - ((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real))) - ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) - ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))) -(define matrix-lu - (case-lambda - [(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))] - [(M fail) - (define m (square-matrix-size M)) - (define rows (matrix->vector* M)) - ;; Construct L in a weird way to prove to TR that it has the right type - (define L (array->mutable-array (matrix-scale M (ann 0 Real)))) - ;; Going to fill in the lower triangle by banging values into `ys' - (define ys (mutable-array-data L)) - (let loop ([#{i : Nonnegative-Fixnum} 0]) - (cond - [(i . fx< . m) - ;; Pivot must be on the diagonal - (define pivot (unsafe-vector2d-ref rows i i)) - (cond - [(zero? pivot) (values (fail) M)] - [else - ;; Zero out everything below the pivot - (let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)]) - (cond - [(l . fx< . m) - (define x_li (unsafe-vector2d-ref rows l i)) - (define y_li (/ x_li pivot)) - (unless (zero? x_li) - ;; Fill in lower triangle of L - (unsafe-vector-set! ys (+ (* l m) i) y_li) - ;; Add row i, scaled - (vector-scaled-add! (unsafe-vector-ref rows l) - (unsafe-vector-ref rows i) - (- y_li))) - (l-loop (fx+ l 1))] - [else - (loop (fx+ i 1))]))])] - [else - ;; L's lower triangle has been filled; now fill the diagonal with 1s - (for: ([i : Integer (in-range 0 m)]) - (vector-set! ys (+ (* i m) i) 1)) - (values L (vector*->matrix rows))]))])) - -;; =================================================================================================== -;; Projections and orthogonalization - -(: projection-on-orthogonal-basis : - (Column Number) (Listof (Column Number)) -> (Result-Column Number)) -; (projection-on-orthogonal-basis v bs) -; Project the vector v on the orthogonal basis vectors in bs. -; The basis bs must be either the column vectors of a matrix -; or a sequence of column-vectors. -(define (projection-on-orthogonal-basis v bs) - (if (null? bs) - (error 'projection-on-orthogonal-basis - "received empty list of basis vectors") - (matrix-sum (map (λ: ([b : (Column Number)]) - (column-project v (->col-matrix b))) - bs)))) - -; (projection-on-orthonormal-basis v bs) -; Project the vector v on the orthonormal basis vectors in bs. -; The basis bs must be either the column vectors of a matrix -; or a sequence of column-vectors. -(: projection-on-orthonormal-basis : - (Column Number) (Listof (Column Number)) -> (Result-Column Number)) -(define (projection-on-orthonormal-basis v bs) - #;(for/matrix-sum ([b bs]) (matrix-scale b (column-dot v b))) - (define: sum : (U False (Result-Column Number)) #f) - (for ([b1 (in-list bs)]) - (define: b : (Result-Column Number) (->col-matrix b1)) - (cond [(not sum) (set! sum (column-project/unit v b))] - [else (set! sum (array+ (assert sum) (column-project/unit v b)))])) - (cond [sum (assert sum)] - [else (error 'projection-on-orthonormal-basis - "received empty list of basis vectors")])) - -(: gram-schmidt-orthogonal : (Listof (Column Number)) -> (Listof (Result-Column Number))) -; (gram-schmidt-orthogonal ws) -; Given a list ws of column vectors, produce -; an orthogonal basis for the span of the -; vectors in ws. -(define (gram-schmidt-orthogonal ws1) - (define ws (map (λ: ([w : (Column Number)]) (->col-matrix w)) ws1)) - (cond - [(null? ws) '()] - [(null? (cdr ws)) (list (car ws))] - [else - (: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number)) - -> (Listof (Result-Column Number))) - (define (loop vs ws) - (cond [(null? ws) vs] - [else - (define w (car ws)) - (let ([w-proj (projection-on-orthogonal-basis w vs)]) - ; Note: We project onto vs (not on the original ws) - ; in order to get numerical stability. - (let ([w-minus-proj (array-strict (array- w w-proj))]) - (if (matrix-zero? w-minus-proj) - (loop vs (cdr ws)) ; w in span{vs} => omit it - (loop (cons w-minus-proj vs) (cdr ws)))))])) - (reverse (loop (list (car ws)) (cdr ws)))])) - -(: gram-schmidt-orthonormal : (Listof (Column Number)) -> (Listof (Result-Column Number))) -; (gram-schmidt-orthonormal ws) -; Given a list ws of column vectors, produce -; an orthonormal basis for the span of the -; vectors in ws. -(define (gram-schmidt-orthonormal ws) - (map column-normalize (gram-schmidt-orthogonal ws))) - -(: projection-on-subspace : - (Column Number) (Listof (Column Number)) -> (Result-Column Number)) -; (projection-on-subspace v ws) -; Returns the projection of v on span{w_i}, w_i in ws. -(define (projection-on-subspace v ws) - (projection-on-orthogonal-basis v (gram-schmidt-orthogonal ws))) - -(: extend-span-to-basis : - (Listof (Matrix Number)) Integer -> (Listof (Matrix Number))) -; Extend the basis in vs to r-dimensional basis -(define (extend-span-to-basis vs r) - (define-values (m n) (matrix-shape (car vs))) - (: loop : (Listof (Matrix Number)) (Listof (Matrix Number)) Integer -> (Listof (Matrix Number))) - (define (loop vs ws i) - (if (>= i m) - ws - (let () - (define ei (unit-column m i)) - (define pi (projection-on-subspace ei vs)) - (if (matrix= ei pi) - (loop vs ws (+ i 1)) - (let ([w (array- ei pi)]) - (loop (cons w vs) (cons w ws) (+ i 1))))))) - (: norm> : (Matrix Number) (Matrix Number) -> Boolean) - (define (norm> v w) - (> (column-norm v) (column-norm w))) - (if (index? r) - (take (sort (loop vs '() 0) norm>) r) - (error 'extend-span-to-basis "expected index as second argument, got ~a" r))) - -;; =================================================================================================== -;; QR decomposition - -(: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number))) -(define (matrix-qr M) - ; compute the QR-facorization - ; 1) QR = M - ; 2) columns of Q is are orthonormal - ; 3) R is upper-triangular - ; Note: columnspace(A)=columnspace(Q) ! - (define-values (m n) (matrix-shape M)) - (let* ([basis-for-column-space - (gram-schmidt-orthonormal (matrix-cols M))] - [extension - (extend-span-to-basis - basis-for-column-space (- n (length basis-for-column-space)))] - [Q (matrix-augment - (append basis-for-column-space - (map column-normalize - extension)))] - [R - (let () - (define v (make-vector (* n n) (ann 0 Number))) - (for*: ([i (in-range 0 n)] - [j (in-range 0 n)]) - (if (> i j) - (void) ; v(i,j)=0 already - (let () - (define: sum : Number 0) - (for: ([k (in-range m)]) - (set! sum (+ sum (* (matrix-ref Q k i) - (matrix-ref M k j))))) - (vector-set! v (+ (* i n) j) sum)))) - (vector->matrix n n v))]) - (values Q R))) diff --git a/collects/math/private/matrix/matrix-operator-norm.rkt b/collects/math/private/matrix/matrix-operator-norm.rkt new file mode 100644 index 0000000000..4ecb8a6d5e --- /dev/null +++ b/collects/math/private/matrix/matrix-operator-norm.rkt @@ -0,0 +1,121 @@ +#lang typed/racket/base + +#| +Two of the functions defined here currently just raise an error: `matrix-op-2norm' and +`matrix-op-angle'. They need to compute, respectively, the maximum and minimum singular values of +their matrix argument. + +See "How to Measure Errors" in the LAPACK manual for more details: + + http://www.netlib.org/lapack/lug/node75.html + http://www.netlib.org/lapack/lug/node76.html +|# + +(require racket/list + racket/fixnum + math/flonum + "matrix-types.rkt" + "matrix-arithmetic.rkt" + "matrix-constructors.rkt" + "matrix-basic.rkt" + "utils.rkt" + "../array/array-struct.rkt" + "../array/array-pointwise.rkt" + "../array/array-fold.rkt" + ) + + +(provide + ;; Operator norms + matrix-op-1norm + matrix-op-2norm + matrix-op-inf-norm + matrix-basis-angle + ;; Error measurement + matrix-error-norm + matrix-absolute-error + matrix-relative-error + ;; Approximate predicates + matrix-identity? + matrix-orthonormal? + ) + +(: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real)) +;; When M is a column matrix, this is equivalent to matrix-1norm +(define (matrix-op-1norm M) + (assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?)) + +(: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real)) +;; When M is a column matrix, this is equivalent to matrix-2norm +(define (matrix-op-2norm M) + ;(matrix-max-singular-value M) + ;(sqrt (matrix-max-eigenvalue M)) + (error 'unimplemented)) + +(: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real)) +;; When M is a column matrix, this is equivalent to matrix-inf-norm +(define (matrix-op-inf-norm M) + (assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?)) + +(: matrix-basis-angle (case-> ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Number) (Matrix Number) -> Number))) +;; Returns the angle between the two subspaces spanned by the two given sets of column vectors +(define (matrix-basis-angle M R) + ;(acos (matrix-min-singular-value (matrix* (matrix-hermitian M) R))) + (error 'unimplemented)) + +;; =================================================================================================== +;; Error measurement + +(: matrix-error-norm (Parameterof ((Matrix Number) -> Nonnegative-Real))) +(define matrix-error-norm (make-parameter matrix-op-inf-norm)) + +(: matrix-absolute-error + (case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real) + ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) + -> Nonnegative-Real))) +(define (matrix-absolute-error M R [norm (matrix-error-norm)]) + (define-values (m n) (matrix-shapes 'matrix-absolute-error M R)) + (array-strict! M) + (array-strict! R) + (cond [(array-all-and (inline-array-map eqv? M R)) 0] + [(and (array-all-and (inline-array-map number-rational? M)) + (array-all-and (inline-array-map number-rational? R))) + (norm (matrix- (inline-array-map inexact->exact M) + (inline-array-map inexact->exact R)))] + [else +inf.0])) + +(: matrix-relative-error + (case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real) + ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) + -> Nonnegative-Real))) +(define (matrix-relative-error M R [norm (matrix-error-norm)]) + (define-values (m n) (matrix-shapes 'matrix-relative-error M R)) + (array-strict! M) + (array-strict! R) + (cond [(array-all-and (inline-array-map eqv? M R)) 0] + [(and (array-all-and (inline-array-map number-rational? M)) + (array-all-and (inline-array-map number-rational? R))) + (define num (norm (matrix- M R))) + (define den (norm R)) + (cond [(and (zero? num) (zero? den)) 0] + [(zero? den) +inf.0] + [else (assert (/ num den) nonnegative?)])] + [else +inf.0])) + +;; =================================================================================================== +;; Approximate predicates + +(: matrix-identity? (case-> ((Matrix Number) -> Boolean) + ((Matrix Number) Real -> Boolean))) +(define (matrix-identity? M [eps (* 10 epsilon.0)]) + (cond [(eps . < . 0) (raise-argument-error 'matrix-identity? "Nonnegative-Real" 1 M eps)] + [else (and (square-matrix? M) + (<= (matrix-relative-error M (identity-matrix (square-matrix-size M))) eps))])) + +(: matrix-orthonormal? (case-> ((Matrix Number) -> Boolean) + ((Matrix Number) Real -> Boolean))) +(define (matrix-orthonormal? M [eps (* 10 epsilon.0)]) + (cond [(eps . < . 0) (raise-argument-error 'matrix-orthonormal? "Nonnegative-Real" 1 M eps)] + [else (and (square-matrix? M) + (matrix-identity? (matrix* M (matrix-hermitian M)) eps))])) diff --git a/collects/math/private/matrix/matrix-qr.rkt b/collects/math/private/matrix/matrix-qr.rkt new file mode 100644 index 0000000000..530e803182 --- /dev/null +++ b/collects/math/private/matrix/matrix-qr.rkt @@ -0,0 +1,39 @@ +#lang typed/racket/base + +(require "matrix-types.rkt" + "matrix-basic.rkt" + "matrix-arithmetic.rkt" + "matrix-constructors.rkt" + "matrix-gram-schmidt.rkt" + "../array/array-transform.rkt") + +(provide matrix-qr) + +#| +QR decomposition currently does Gram-Schmidt twice, as suggested by + + Luc Giraud, Julien Langou, Miroslav Rozloznik. + On the round-off error analysis of the Gram-Schmidt algorithm with reorthogonalization. + Technical Report, 2002. + +It normalizes only the second time. + +I've verified experimentally that, with random, square matrices (elements in [0,1]), doing so +produces matrices for which `matrix-orthogonal?' returns #t with eps <= 10*epsilon.0, apparently +independently of the matrix size. +|# + +(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) +(define (matrix-qr M [full? #t]) + (define B (matrix-gram-schmidt M #f)) + (define Q + (matrix-gram-schmidt + (cond [(or (square-matrix? B) (and (matrix? B) (not full?))) B] + [(matrix? B) (array-append* (list B (matrix-basis-extension B)) 1)] + [full? (identity-matrix (matrix-num-rows M))] + [else (matrix-col (identity-matrix (matrix-num-rows M)) 0)]) + #t)) + (values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M)))) diff --git a/collects/math/private/matrix/matrix-solve.rkt b/collects/math/private/matrix/matrix-solve.rkt new file mode 100644 index 0000000000..b1b0313636 --- /dev/null +++ b/collects/math/private/matrix/matrix-solve.rkt @@ -0,0 +1,107 @@ +#lang typed/racket/base + +(require racket/fixnum + racket/match + racket/list + "matrix-types.rkt" + "matrix-constructors.rkt" + "matrix-conversion.rkt" + "matrix-basic.rkt" + "matrix-gauss-elim.rkt" + "utils.rkt" + "../vector/vector-mutate.rkt" + "../array/array-indexing.rkt" + "../array/mutable-array.rkt") + +(provide + matrix-determinant + matrix-determinant/row-reduction ; for testing + matrix-invertible? + matrix-inverse + matrix-solve) + +;; =================================================================================================== +;; Determinant + +(: matrix-determinant (case-> ((Matrix Real) -> Real) + ((Matrix Number) -> Number))) +(define (matrix-determinant M) + (define m (square-matrix-size M)) + (cond + [(= m 1) (matrix-ref M 0 0)] + [(= m 2) (match-define (vector a b c d) + (mutable-array-data (array->mutable-array M))) + (- (* a d) (* b c))] + [(= m 3) (match-define (vector a b c d e f g h i) + (mutable-array-data (array->mutable-array M))) + (+ (* a (- (* e i) (* f h))) + (* (- b) (- (* d i) (* f g))) + (* c (- (* d h) (* e g))))] + [else + (matrix-determinant/row-reduction M)])) + +(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real) + ((Matrix Number) -> Number))) +(define (matrix-determinant/row-reduction M) + (define m (square-matrix-size M)) + (define rows (matrix->vector* M)) + (let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1]) + (cond + [(i . fx< . m) + (define-values (p pivot) (find-partial-pivot rows m i i)) + (cond + [(zero? pivot) 0] ; no pivot means non-invertible matrix + [else + (let ([sign (if (= i p) sign (begin (vector-swap! rows i p) ; swapping negates sign + (* -1 sign)))]) + (elim-rows! rows m i i pivot (fx+ i 1)) ; adding scaled rows doesn't change it + (loop (fx+ i 1) sign))])] + [else + (define prod (unsafe-vector2d-ref rows 0 0)) + (let loop ([#{i : Nonnegative-Fixnum} 1] [prod prod]) + (cond [(i . fx< . m) + (loop (fx+ i 1) (* prod (unsafe-vector2d-ref rows i i)))] + [else (* prod sign)]))]))) + +;; =================================================================================================== +;; Inversion and solving linear systems + +(: matrix-invertible? ((Matrix Number) -> Boolean)) +(define (matrix-invertible? M) + (not (zero? (matrix-determinant M)))) + +(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) +(define matrix-inverse + (case-lambda + [(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))] + [(M fail) + (define m (square-matrix-size M)) + (define I (identity-matrix m)) + (define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t)) + (cond [(and (not (empty? wps)) (= (first wps) m)) + (submatrix IM^-1 (::) (:: m #f))] + [else (fail)])])) + +(: matrix-solve (All (A) (case-> + ((Matrix Real) (Matrix Real) -> (Matrix Real)) + ((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Number) (Matrix Number) -> (Matrix Number)) + ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))) +(define matrix-solve + (case-lambda + [(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))] + [(M B fail) + (define m (square-matrix-size M)) + (define-values (s t) (matrix-shape B)) + (cond [(= m s) + (define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t)) + (cond [(and (not (empty? wps)) (= (first wps) m)) + (submatrix IX (::) (:: m #f))] + [else (fail)])] + [else + (error 'matrix-solve + "matrices must have the same number of rows; given ~e and ~e" + M B)])])) diff --git a/collects/math/private/matrix/matrix-subspace.rkt b/collects/math/private/matrix/matrix-subspace.rkt new file mode 100644 index 0000000000..f89c4c8f47 --- /dev/null +++ b/collects/math/private/matrix/matrix-subspace.rkt @@ -0,0 +1,57 @@ +#lang typed/racket/base + +(require racket/fixnum + racket/list + "matrix-types.rkt" + "matrix-basic.rkt" + "matrix-gauss-elim.rkt" + "utils.rkt" + "../array/array-indexing.rkt" + "../array/array-constructors.rkt") + +(provide + matrix-rank + matrix-nullity + matrix-col-space) + +(: matrix-rank : (Matrix Number) -> Index) +;; Returns the dimension of the column space (equiv. row space) of M +(define (matrix-rank M) + (define n (matrix-num-cols M)) + (define-values (_ cols-without-pivot) (matrix-gauss-elim M)) + (assert (- n (length cols-without-pivot)) index?)) + +(: matrix-nullity : (Matrix Number) -> Index) +;; Returns the dimension of the null space of M +(define (matrix-nullity M) + (define-values (_ cols-without-pivot) + (matrix-gauss-elim (ensure-matrix 'matrix-nullity M))) + (length cols-without-pivot)) + +(: maybe-cons-submatrix (All (A) ((Matrix A) Nonnegative-Fixnum Nonnegative-Fixnum (Listof (Matrix A)) + -> (Listof (Matrix A))))) +(define (maybe-cons-submatrix M j0 j1 Bs) + (cond [(= j0 j1) Bs] + [else (cons (submatrix M (::) (:: j0 j1)) Bs)])) + +(: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + ((Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Number) -> (Matrix Number)) + ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) +(define matrix-col-space + (case-lambda + [(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))] + [(M fail) + (define n (matrix-num-cols M)) + (define-values (_ wps) (matrix-gauss-elim M)) + (cond [(empty? wps) M] + [(= (length wps) n) (fail)] + [else + (define next-j (first wps)) + (define Bs (maybe-cons-submatrix M 0 next-j empty)) + (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) + (cond [(empty? wps) + (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] + [else + (define next-j (first wps)) + (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])])) diff --git a/collects/math/private/matrix/matrix-syntax.rkt b/collects/math/private/matrix/matrix-syntax.rkt index 62ffc52d8c..48e3924cdb 100644 --- a/collects/math/private/matrix/matrix-syntax.rkt +++ b/collects/math/private/matrix/matrix-syntax.rkt @@ -3,7 +3,7 @@ (require (for-syntax racket/base syntax/parse) (only-in typed/racket/base :) - math/array) + "../array/array-struct.rkt") (provide matrix row-matrix col-matrix) diff --git a/collects/math/private/matrix/typed-matrix-arithmetic.rkt b/collects/math/private/matrix/typed-matrix-arithmetic.rkt index 0f53c02252..f2810d559a 100644 --- a/collects/math/private/matrix/typed-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/typed-matrix-arithmetic.rkt @@ -1,10 +1,12 @@ #lang typed/racket/base (require racket/list - math/array "matrix-types.rkt" "utils.rkt" - (except-in "untyped-matrix-arithmetic.rkt" matrix-map)) + (except-in "untyped-matrix-arithmetic.rkt" matrix-map) + "../array/array-struct.rkt" + "../array/array-fold.rkt" + "../array/utils.rkt") (provide matrix-map matrix= @@ -14,16 +16,17 @@ matrix-scale matrix-sum) -(: matrix-map (All (R A B T ...) - (case-> ((A -> R) (Array A) -> (Array R)) - ((A B T ... T -> R) (Array A) (Array B) (Array T) ... T -> (Array R))))) +(: matrix-map + (All (R A B T ...) + (case-> ((A -> R) (Matrix A) -> (Matrix R)) + ((A B T ... T -> R) (Matrix A) (Matrix B) (Matrix T) ... T -> (Matrix R))))) (define matrix-map (case-lambda: - [([f : (A -> R)] [arr : (Array A)]) + [([f : (A -> R)] [arr : (Matrix A)]) (inline-matrix-map f arr)] - [([f : (A B -> R)] [arr0 : (Array A)] [arr1 : (Array B)]) + [([f : (A B -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix B)]) (inline-matrix-map f arr0 arr1)] - [([f : (A B T ... T -> R)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T]) + [([f : (A B T ... T -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix B)] . [arrs : (Matrix T) ... T]) (define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs)) (define g0 (unsafe-array-proc arr0)) (define g1 (unsafe-array-proc arr1)) @@ -33,7 +36,7 @@ (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) (map (λ: ([g : (Indexes -> T)]) (g js)) gs))))])) -(: matrix=? ((Array Number) (Array Number) -> Boolean)) +(: matrix=? ((Matrix Number) (Matrix Number) -> Boolean)) (define (matrix=? arr0 arr1) (define-values (m0 n0) (matrix-shape arr0)) (define-values (m1 n1) (matrix-shape arr1)) @@ -46,35 +49,35 @@ (λ: ([js : Indexes]) (= (proc0 js) (proc1 js)))))))) -(: matrix= (case-> ((Array Number) (Array Number) -> Boolean) - ((Array Number) (Array Number) (Array Number) (Array Number) * -> Boolean))) +(: matrix= (case-> ((Matrix Number) (Matrix Number) -> Boolean) + ((Matrix Number) (Matrix Number) (Matrix Number) (Matrix Number) * -> Boolean))) (define matrix= (case-lambda: - [([arr0 : (Array Number)] [arr1 : (Array Number)]) (matrix=? arr0 arr1)] - [([arr0 : (Array Number)] [arr1 : (Array Number)] . [arrs : (Array Number) *]) + [([arr0 : (Matrix Number)] [arr1 : (Matrix Number)]) (matrix=? arr0 arr1)] + [([arr0 : (Matrix Number)] [arr1 : (Matrix Number)] . [arrs : (Matrix Number) *]) (and (matrix=? arr0 arr1) - (let: loop : Boolean ([arr1 : (Array Number) arr1] - [arrs : (Listof (Array Number)) arrs]) + (let: loop : Boolean ([arr1 : (Matrix Number) arr1] + [arrs : (Listof (Matrix Number)) arrs]) (cond [(empty? arrs) #t] [else (and (matrix=? arr1 (first arrs)) (loop (first arrs) (rest arrs)))])))])) -(: matrix* (case-> ((Array Real) (Array Real) * -> (Array Real)) - ((Array Number) (Array Number) * -> (Array Number)))) +(: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix* a . as) (let loop ([a a] [as as]) (cond [(empty? as) a] [else (loop (inline-matrix* a (first as)) (rest as))]))) -(: matrix+ (case-> ((Array Real) (Array Real) * -> (Array Real)) - ((Array Number) (Array Number) * -> (Array Number)))) +(: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix+ a . as) (let loop ([a a] [as as]) (cond [(empty? as) a] [else (loop (inline-matrix+ a (first as)) (rest as))]))) -(: matrix- (case-> ((Array Real) (Array Real) * -> (Array Real)) - ((Array Number) (Array Number) * -> (Array Number)))) +(: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix- a . as) (cond [(empty? as) (inline-matrix- a)] [else @@ -82,12 +85,12 @@ (cond [(empty? as) a] [else (loop (inline-matrix- a (first as)) (rest as))]))])) -(: matrix-scale (case-> ((Array Real) Real -> (Array Real)) - ((Array Number) Number -> (Array Number)))) +(: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real)) + ((Matrix Number) Number -> (Matrix Number)))) (define (matrix-scale a x) (inline-matrix-scale a x)) -(: matrix-sum (case-> ((Listof (Array Real)) -> (Array Real)) - ((Listof (Array Number)) -> (Array Number)))) +(: matrix-sum (case-> ((Listof (Matrix Real)) -> (Matrix Real)) + ((Listof (Matrix Number)) -> (Matrix Number)))) (define (matrix-sum lst) (cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)] [else (apply matrix+ lst)])) diff --git a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt index f7cd81d5b9..8db143233a 100644 --- a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt @@ -10,13 +10,16 @@ (module syntax-defs racket/base (require (for-syntax racket/base) (only-in typed/racket/base λ: : inst Index) - math/array "matrix-types.rkt" - "utils.rkt") + "utils.rkt" + "../array/array-struct.rkt" + "../array/array-fold.rkt" + "../array/array-transform.rkt" + "../array/utils.rkt") (provide (all-defined-out)) - ;(: matrix-multiply ((Array Number) (Array Number) -> (Array Number))) + ;(: matrix-multiply ((Matrix Number) (Matrix Number) -> (Matrix Number))) ;; This is a macro so the result can have as precise a type as possible (define-syntax-rule (inline-matrix-multiply arr-expr brr-expr) (let ([arr arr-expr] @@ -69,26 +72,31 @@ (define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...)) (define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...)) - (define-syntax-rule (inline-matrix-scale arr x) (inline-matrix-map (λ (y) (* x y)) arr)) + (define-syntax-rule (inline-matrix-scale arr x-expr) + (let ([x x-expr]) + (inline-matrix-map (λ (y) (* x y)) arr))) ) ; module (module untyped-defs typed/racket/base - (require math/array - (submod ".." syntax-defs) - "utils.rkt") + (require (submod ".." syntax-defs) + "matrix-types.rkt" + "utils.rkt" + "../array/array-struct.rkt" + "../array/utils.rkt") (provide matrix-map) - (: matrix-map (All (R A) (case-> ((A -> R) (Array A) -> (Array R)) - ((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R))))) + (: matrix-map + (All (R A) (case-> ((A -> R) (Matrix A) -> (Matrix R)) + ((A A A * -> R) (Matrix A) (Matrix A) (Matrix A) * -> (Matrix R))))) (define matrix-map (case-lambda: - [([f : (A -> R)] [arr : (Array A)]) + [([f : (A -> R)] [arr : (Matrix A)]) (inline-matrix-map f arr)] - [([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)]) + [([f : (A A -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix A)]) (inline-matrix-map f arr0 arr1)] - [([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *]) + [([f : (A A A * -> R)] [arr0 : (Matrix A)] [arr1 : (Matrix A)] . [arrs : (Matrix A) *]) (define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs)) (define g0 (unsafe-array-proc arr0)) (define g1 (unsafe-array-proc arr1)) diff --git a/collects/math/private/matrix/utils.rkt b/collects/math/private/matrix/utils.rkt index 016e1ff0a7..c7c14089bc 100644 --- a/collects/math/private/matrix/utils.rkt +++ b/collects/math/private/matrix/utils.rkt @@ -1,9 +1,11 @@ #lang typed/racket/base -(require racket/match - racket/string - math/array - "matrix-types.rkt") +(require racket/string + racket/fixnum + "matrix-types.rkt" + "../unsafe.rkt" + "../array/array-struct.rkt" + "../vector/vector-mutate.rkt") (provide (all-defined-out)) @@ -36,3 +38,61 @@ (: ensure-matrix (All (A) Symbol (Array A) -> (Array A))) (define (ensure-matrix name a) (if (matrix? a) a (raise-argument-error name "matrix?" a))) + +(: ensure-row-matrix (All (A) Symbol (Array A) -> (Array A))) +(define (ensure-row-matrix name a) + (if (row-matrix? a) a (raise-argument-error name "row-matrix?" a))) + +(: ensure-col-matrix (All (A) Symbol (Array A) -> (Array A))) +(define (ensure-col-matrix name a) + (if (col-matrix? a) a (raise-argument-error name "col-matrix?" a))) + +(: sort/key (All (A B) (case-> ((Listof A) (B B -> Boolean) (A -> B) -> (Listof A)) + ((Listof A) (B B -> Boolean) (A -> B) Boolean -> (Listof A))))) +;; Sometimes necessary because TR can't do inference with keyword arguments yet +(define (sort/key lst lt? key [cache-keys? #f]) + ((inst sort A B) lst lt? #:key key #:cache-keys? cache-keys?)) + +(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A))) +(define (unsafe-vector2d-ref vss i j) + (unsafe-vector-ref (unsafe-vector-ref vss i) j)) + +;; Note: this accepts +nan.0 +(define nonnegative? + (λ: ([x : Real]) (not (x . < . 0)))) + +(define number-rational? + (λ: ([z : Number]) + (cond [(real? z) (rational? z)] + [else (and (rational? (real-part z)) + (rational? (imag-part z)))]))) + +(: find-partial-pivot + (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) + ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number)))) +;; Find the element with maximum magnitude in a column +(define (find-partial-pivot rows m i j) + (define l (fx+ i 1)) + (define pivot (unsafe-vector2d-ref rows i j)) + (define mag-pivot (magnitude pivot)) + (let loop ([#{l : Nonnegative-Fixnum} l] [#{p : Index} i] [pivot pivot] [mag-pivot mag-pivot]) + (cond [(l . fx< . m) + (define new-pivot (unsafe-vector2d-ref rows l j)) + (define mag-new-pivot (magnitude new-pivot)) + (cond [(mag-new-pivot . > . mag-pivot) (loop (fx+ l 1) l new-pivot mag-new-pivot)] + [else (loop (fx+ l 1) p pivot mag-pivot)])] + [else (values p pivot)]))) + +(: elim-rows! + (case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void) + ((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void))) +(define (elim-rows! rows m i j pivot start) + (let loop ([#{l : Nonnegative-Fixnum} start]) + (when (l . fx< . m) + (unless (l . fx= . i) + (define x_lj (unsafe-vector2d-ref rows l j)) + (unless (zero? x_lj) + (vector-scaled-add! (unsafe-vector-ref rows l) + (unsafe-vector-ref rows i) + (- (/ x_lj pivot))))) + (loop (fx+ l 1))))) diff --git a/collects/math/private/vector/vector-mutate.rkt b/collects/math/private/vector/vector-mutate.rkt index 9e4356c419..eea5b46ce2 100644 --- a/collects/math/private/vector/vector-mutate.rkt +++ b/collects/math/private/vector/vector-mutate.rkt @@ -16,9 +16,7 @@ (: mag^2 (Number -> Nonnegative-Real)) (define (mag^2 x) - (define y (* x (conjugate x))) - (cond [(and (real? y) (y . >= . 0)) y] - [else (error 'impossible)])) + (max 0 (real-part (* x (conjugate x))))) (: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void))) (define (vector-swap! vs i0 i1) diff --git a/collects/math/tests/matrix-tests.rkt b/collects/math/tests/matrix-tests.rkt index 368b69a12e..55340d7434 100644 --- a/collects/math/tests/matrix-tests.rkt +++ b/collects/math/tests/matrix-tests.rkt @@ -4,14 +4,24 @@ math/base math/flonum math/matrix - "../private/matrix/matrix-column.rkt" "test-utils.rkt") +(define-syntax (check-matrix=? stx) + (syntax-case stx () + [(_ a b) + (syntax/loc stx (check-true (matrix=? a b) (format "(matrix=? ~v ~v)" a b)))] + [(_ a b eps) + (syntax/loc stx (check-true (matrix=? a b eps) (format "(matrix=? ~v ~v ~v)" a b eps)))])) + (: random-matrix (case-> (Integer Integer -> (Matrix Integer)) - (Integer Integer Integer -> (Matrix Integer)))) + (Integer Integer Integer -> (Matrix Integer)) + (Integer Integer Integer Integer -> (Matrix Integer)))) ;; Generates a random matrix with Natural elements < k. Useful to test properties. -(define (random-matrix m n [k 100]) - (array-strict (build-array (vector m n) (λ (_) (random k))))) +(define random-matrix + (case-lambda + [(m n) (random-matrix m n 100)] + [(m n k) (array-strict (build-matrix m n (λ (i j) (random-natural k))))] + [(m n k0 k1) (array-strict (build-matrix m n (λ (i j) (random-integer k0 k1))))])) (define nonmatrices (list (make-array #() 0) @@ -21,6 +31,16 @@ (make-array #(0 0) 0) (make-array #(1 1 1) 0))) +(: matrix-l ((Matrix Number) -> (Matrix Number))) +(define (matrix-l M) + (define-values (L U) (matrix-lu M)) + L) + +(: matrix-q ((Matrix Number) -> (Matrix Number))) +(define (matrix-q M) + (define-values (Q R) (matrix-qr M)) + Q) + ;; =================================================================================================== ;; Literal syntax @@ -74,13 +94,6 @@ (for: ([a (in-list nonmatrices)]) (check-false (col-matrix? a))) -(check-true (matrix-zero? (make-matrix 4 3 0))) -(check-true (matrix-zero? (make-matrix 4 3 0.0))) -(check-true (matrix-zero? (make-matrix 4 3 0+0.0i))) -(check-false (matrix-zero? (row-matrix [0 0 0 0 1]))) -(for: ([a (in-list nonmatrices)]) - (check-exn exn:fail:contract? (λ () (matrix-zero? a)))) - ;; =================================================================================================== ;; Accessors @@ -425,17 +438,9 @@ ;; =================================================================================================== ;; Comprehensions -;; for/matrix and friends are defined in terms of for/array and friends, so we only need to test that -;; it works for one case each, and that they properly raise exceptions when given zero-length axes - -(check-equal? - (for/matrix 2 2 ([i (in-range 4)]) i) - (matrix [[0 1] [2 3]])) - -#;; TR can't type this, but it's defined using exactly the same wrapper as `for/matrix' -(check-equal? - (for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j)) - (matrix [[0 1] [1 2]])) +;; for:/matrix and friends are defined in terms of for:/array and friends, so we only need to test +;; that it works for one case each, and that they properly raise exceptions when given zero-length +;; axes (check-equal? (for/matrix: 2 2 ([i (in-range 4)]) i) @@ -445,11 +450,6 @@ (for*/matrix: 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j)) (matrix [[0 1] [1 2]])) -(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0))) -(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0))) -(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0))) -(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0))) - (check-exn exn:fail:contract? (λ () (for/matrix: 2 0 () 0))) (check-exn exn:fail:contract? (λ () (for/matrix: 0 2 () 0))) (check-exn exn:fail:contract? (λ () (for*/matrix: 2 0 () 0))) @@ -531,6 +531,10 @@ (for: ([a (in-list nonmatrices)]) (check-exn exn:fail:contract? (λ () (matrix-cols a)))) +;; TODO: matrix-upper-triangle + +;; TODO: matrix-lower-triangle + ;; =================================================================================================== ;; Embiggenment (it's a perfectly cromulent word) @@ -626,6 +630,10 @@ (check-exn exn:fail:contract? (λ () (matrix-dot a (matrix [[1]])))) (check-exn exn:fail:contract? (λ () (matrix-dot (matrix [[1]]) a)))) +;; TODO: matrix-angle + +;; TODO: matrix-normalize + ;; =================================================================================================== ;; Simple operators @@ -647,8 +655,8 @@ ;; matrix-hermitian -(let ([a (array-make-rectangular (random-matrix 5 6) - (random-matrix 5 6))]) +(let ([a (array-make-rectangular (random-matrix 5 6 -100 100) + (random-matrix 5 6 -100 100))]) (check-equal? (matrix-hermitian a) (matrix-conjugate (matrix-transpose a))) (check-equal? (matrix-hermitian a) @@ -667,6 +675,86 @@ (for: ([a (in-list nonmatrices)]) (check-exn exn:fail:contract? (λ () (matrix-trace a)))) +;; =================================================================================================== +;; Row/column operators + +;; TODO: matrix-map-rows + +;; TODO: matrix-map-cols + +;; TODO: matrix-normalize-rows + +;; TODO: matrix-normalize-cols + +;; =================================================================================================== +;; Operator norms + +;; TODO: matrix-op-1norm + +;; TODO: matrix-op-2norm (after it's implemented) + +;; TODO: matrix-op-inf-norm + +;; =================================================================================================== +;; Error + +(for*: ([x (in-list '(-inf.0 -10.0 -1.0 -0.1 -0.0 0.0 0.1 1.0 10.0 +inf.0 +nan.0))] + [y (in-list '(-inf.0 -10.0 -1.0 -0.1 -0.0 0.0 0.1 1.0 10.0 +inf.0 +nan.0))]) + (check-eqv? (fl (matrix-absolute-error (row-matrix [x]) + (row-matrix [y]))) + (fl (absolute-error x y)) + (format "x = ~v y = ~v" x y)) + (check-eqv? (fl (matrix-relative-error (row-matrix [x]) + (row-matrix [y]))) + (fl (relative-error x y)) + (format "x = ~v y = ~v" x y))) + +(check-equal? (matrix-absolute-error (row-matrix [1 2]) + (row-matrix [1 2])) + 0) + +(check-equal? (matrix-absolute-error (row-matrix [1 2]) + (row-matrix [2 2])) + 1) + +(check-equal? (matrix-absolute-error (row-matrix [1 2]) + (row-matrix [2 +nan.0])) + +inf.0) + +(check-equal? (matrix-relative-error (row-matrix [1 2]) + (row-matrix [1 2])) + 0) + +(check-equal? (matrix-relative-error (row-matrix [1 2]) + (row-matrix [2 2])) + (/ 1 (matrix-op-inf-norm (row-matrix [2 2])))) + +(check-equal? (matrix-relative-error (row-matrix [1 2]) + (row-matrix [2 +nan.0])) + +inf.0) + +;; TODO: matrix-basis-angle + +;; =================================================================================================== +;; Approximate predicates + +;; matrix-zero? (TODO: approximations) + +(check-true (matrix-zero? (make-matrix 4 3 0))) +(check-true (matrix-zero? (make-matrix 4 3 0.0))) +(check-true (matrix-zero? (make-matrix 4 3 0+0.0i))) +(check-false (matrix-zero? (row-matrix [0 0 0 0 1]))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-zero? a)))) + +;; TODO: matrix-rows-orthogonal? + +;; TODO: matrix-cols-orthogonal? + +;; TODO: matrix-identity? + +;; TODO: matrix-orthonormal? + ;; =================================================================================================== ;; Gaussian elimination @@ -743,7 +831,7 @@ 5280) (for: ([_ (in-range 100)]) - (define a (array- (random-matrix 3 3 7) (array 3))) + (define a (random-matrix 3 3 -3 4)) (check-equal? (matrix-determinant/row-reduction a) (matrix-determinant a))) @@ -756,8 +844,8 @@ ;; Solving linear systems (for: ([_ (in-range 100)]) - (define M (array- (random-matrix 3 3 7) (array 3))) - (define B (array- (random-matrix 3 (+ 1 (random 10)) 7) (array 3))) + (define M (random-matrix 3 3 -3 4)) + (define B (random-matrix 3 (+ 1 (random 10)) -3 4)) (cond [(matrix-invertible? M) (define X (matrix-solve M B)) (check-equal? (matrix* M X) B (format "M = ~a B = ~a" M B))] @@ -779,7 +867,7 @@ ;; Inversion (for: ([_ (in-range 100)]) - (define a (array- (random-matrix 3 3 7) (array 3))) + (define a (random-matrix 3 3 -3 4)) (cond [(matrix-invertible? a) (check-equal? (matrix* a (matrix-inverse a)) (identity-matrix 3) @@ -815,11 +903,6 @@ [0 0 0 -13]])) (check-equal? (matrix* L V) M)) -(: matrix-l ((Matrix Number) -> Any)) -(define (matrix-l M) - (define-values (L U) (matrix-lu M)) - L) - (check-exn exn:fail? (λ () (matrix-l (matrix [[1 1 0 2] [0 2 0 1] [1 0 0 0] @@ -830,15 +913,57 @@ (for: ([a (in-list nonmatrices)]) (check-exn exn:fail:contract? (λ () (matrix-l a)))) +;; =================================================================================================== +;; Gram-Schmidt + +(check-equal? (matrix-gram-schmidt (matrix [[3 2] [1 2]])) + (matrix [[3 -2/5] [1 6/5]])) + +(check-equal? (matrix-gram-schmidt (matrix [[3 2] [1 2]]) #t) + (matrix-scale (matrix [[3 -1] [1 3]]) (sqrt 1/10))) + +(check-equal? (matrix-gram-schmidt (matrix [[12 -51 4] + [ 6 167 -68] + [-4 24 -41]]) + #t) + (matrix [[ 6/7 -69/175 -58/175] + [ 3/7 158/175 6/175] + [-2/7 6/35 -33/35 ]])) + +(check-equal? (matrix-gram-schmidt (matrix [[12 -51 4] + [ 6 167 -68] + [-4 24 -41]]) + #t) + (matrix [[ 6/7 -69/175 -58/175] + [ 3/7 158/175 6/175] + [-2/7 6/35 -33/35 ]])) + +(check-equal? (matrix-gram-schmidt (matrix [[12 -51] + [ 6 167] + [-4 24]]) + #t) + (matrix [[ 6/7 -69/175] + [ 3/7 158/175] + [-2/7 6/35 ]])) + +(check-equal? (matrix-gram-schmidt (col-matrix [12 6 -4]) #t) + (col-matrix [6/7 3/7 -2/7])) + +(check-equal? (matrix-gram-schmidt (col-matrix [12 6 -4]) #f) + (col-matrix [12 6 -4])) + +;; =================================================================================================== +;; QR decomposition + +(check-true (matrix-orthonormal? (matrix-q (index-array #(100 1))))) + #| ;; =================================================================================================== ;; Tests not yet converted to rackunit -(matrix-gram-schmidt - (matrix [[2 1 2] - [2 2 3] - [5 1 5]]) - #t) +;; A particularly tricky one used to demonstrate loss of orthogonality +(matrix-qr (matrix [[0.70000 0.70711] + [0.70001 0.70711]])) (begin diff --git a/collects/math/tests/matrix-untyped-tests.rkt b/collects/math/tests/matrix-untyped-tests.rkt new file mode 100644 index 0000000000..11d72324ef --- /dev/null +++ b/collects/math/tests/matrix-untyped-tests.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require (for-syntax racket/match) + rackunit + math/matrix) + +;; =================================================================================================== +;; Contract tests + +(begin-for-syntax + (define exceptions (list 'matrix 'col-matrix 'row-matrix + 'matrix-determinant/row-reduction)) + + (define (looks-like-value? sym) + (define str (symbol->string sym)) + (and (not (char-upper-case? (string-ref str 0))) + (not (regexp-match #rx"for/" str)) + (not (regexp-match #rx"for\\*/" str)) + (not (member sym exceptions)))) + + (define matrix-exports + (let () + (match-define (list (list #f _ ...) + (list 1 _ ...) + (list 0 matrix-exports ...)) + (syntax-local-module-exports #'math/matrix)) + (filter looks-like-value? matrix-exports))) + ) + +(define-syntax (all-exports stx) + (with-syntax ([(matrix-exports ...) matrix-exports]) + (syntax/loc stx + (begin (void matrix-exports) ...)))) + +(all-exports) + +;; =================================================================================================== +;; Comprehensions + +(check-equal? + (for/matrix 2 2 ([i (in-range 4)]) i) + (matrix [[0 1] [2 3]])) + +(check-equal? + (for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j)) + (matrix [[0 1] [1 2]])) + +(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))