From a1aa97c1fd9eb8e76bde526c395bd9e15a58b96a Mon Sep 17 00:00:00 2001 From: Neil Toronto Date: Wed, 16 Jan 2013 16:39:46 -0700 Subject: [PATCH] `math/matrix' fixes; please merge to 5.3.2. * Fixed type of `matrix-expt' * Made matrix functions respect `array-strictness' parameter (mostly wrapping functions with `parameterize' and return values with `array-default-strictness'; reindentation makes changes look larger) * Added strictness tests (cherry picked from commit f40ad2ca9deeb03749ed0ae2c572ff992766c1a4) --- collects/math/matrix.rkt | 10 +- collects/math/private/matrix/matrix-basic.rkt | 154 ++++++++++------- .../private/matrix/matrix-constructors.rkt | 50 +++--- .../math/private/matrix/matrix-conversion.rkt | 22 ++- collects/math/private/matrix/matrix-expt.rkt | 31 ++-- .../private/matrix/matrix-gram-schmidt.rkt | 30 ++-- collects/math/private/matrix/matrix-lu.rkt | 9 +- .../private/matrix/matrix-operator-norm.rkt | 50 +++--- collects/math/private/matrix/matrix-qr.rkt | 21 ++- collects/math/private/matrix/matrix-solve.rkt | 9 +- .../math/private/matrix/matrix-subspace.rkt | 38 ++-- .../matrix/typed-matrix-arithmetic.rkt | 57 +++--- .../matrix/untyped-matrix-arithmetic.rkt | 33 ++-- collects/math/private/matrix/utils.rkt | 13 +- .../math/tests/matrix-strictness-tests.rkt | 163 ++++++++++++++++++ collects/math/tests/matrix-untyped-tests.rkt | 15 ++ 16 files changed, 499 insertions(+), 206 deletions(-) create mode 100644 collects/math/tests/matrix-strictness-tests.rkt diff --git a/collects/math/matrix.rkt b/collects/math/matrix.rkt index 964ad20616..b57d64f3e1 100644 --- a/collects/math/matrix.rkt +++ b/collects/math/matrix.rkt @@ -6,9 +6,9 @@ "private/matrix/matrix-conversion.rkt" "private/matrix/matrix-syntax.rkt" "private/matrix/matrix-comprehension.rkt" - "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" "private/matrix/matrix-2d.rkt" + ;;"private/matrix/matrix-expt.rkt" ; all use require/untyped-contract ;;"private/matrix/matrix-gauss-elim.rkt" ; all use require/untyped-contract (except-in "private/matrix/matrix-solve.rkt" matrix-determinant @@ -36,6 +36,11 @@ ;;"private/matrix/matrix-gram-schmidt.rkt" ; all use require/untyped-contract ) +(require/untyped-contract + (begin (require "private/matrix/matrix-types.rkt")) + "private/matrix/matrix-expt.rkt" + [matrix-expt ((Matrix Number) Integer -> (Matrix Number))]) + (require/untyped-contract (begin (require "private/matrix/matrix-types.rkt" "private/matrix/matrix-gauss-elim.rkt")) @@ -148,9 +153,10 @@ "private/matrix/matrix-solve.rkt" "private/matrix/matrix-operator-norm.rkt" "private/matrix/matrix-comprehension.rkt" - "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" "private/matrix/matrix-2d.rkt") + ;; matrix/matrix-expt.rkt + matrix-expt ;; matrix-gauss-elim.rkt matrix-gauss-elim matrix-row-echelon diff --git a/collects/math/private/matrix/matrix-basic.rkt b/collects/math/private/matrix/matrix-basic.rkt index d1669e7764..5ed6a2f49f 100644 --- a/collects/math/private/matrix/matrix-basic.rkt +++ b/collects/math/private/matrix/matrix-basic.rkt @@ -81,13 +81,14 @@ (raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)] [else (define proc (unsafe-array-proc a)) - (unsafe-build-array - ((inst vector Index) 1 n) - (λ: ([ij : Indexes]) - (unsafe-vector-set! ij 0 i) - (define res (proc ij)) - (unsafe-vector-set! ij 0 0) - res))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) 1 n) + (λ: ([ij : Indexes]) + (unsafe-vector-set! ij 0 i) + (define res (proc ij)) + (unsafe-vector-set! ij 0 0) + res)))])) (: matrix-col (All (A) (Matrix A) Integer -> (Matrix A))) (define (matrix-col a j) @@ -96,53 +97,61 @@ (raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)] [else (define proc (unsafe-array-proc a)) - (unsafe-build-array - ((inst vector Index) m 1) - (λ: ([ij : Indexes]) - (unsafe-vector-set! ij 1 j) - (define res (proc ij)) - (unsafe-vector-set! ij 1 0) - res))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m 1) + (λ: ([ij : Indexes]) + (unsafe-vector-set! ij 1 j) + (define res (proc ij)) + (unsafe-vector-set! ij 1 0) + res)))])) (: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A)))) (define (matrix-rows a) - (array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)) + (map (λ: ([a : (Matrix A)]) (array-default-strict a)) + (parameterize ([array-strictness #f]) + (array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)))) (: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A)))) (define (matrix-cols a) - (array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)) + (map (λ: ([a : (Matrix A)]) (array-default-strict a)) + (parameterize ([array-strictness #f]) + (array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)))) (: matrix-diagonal (All (A) ((Matrix A) -> (Array A)))) (define (matrix-diagonal a) (define m (square-matrix-size a)) (define proc (unsafe-array-proc a)) - (unsafe-build-array - ((inst vector Index) m) - (λ: ([js : Indexes]) - (define i (unsafe-vector-ref js 0)) - (proc ((inst vector Index) i i))))) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m) + (λ: ([js : Indexes]) + (define i (unsafe-vector-ref js 0)) + (proc ((inst vector Index) i i)))))) (: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) (define (matrix-upper-triangle M) (define-values (m n) (matrix-shape M)) (define proc (unsafe-array-proc M)) - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (if (i . fx<= . j) (proc ij) 0)))) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx<= . j) (proc ij) 0))))) (: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) (define (matrix-lower-triangle M) (define-values (m n) (matrix-shape M)) (define proc (unsafe-array-proc M)) - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (if (i . fx>= . j) (proc ij) 0)))) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx>= . j) (proc ij) 0))))) ;; =================================================================================================== ;; Embiggenment (this is a perfectly cromulent word) @@ -176,34 +185,38 @@ (: matrix-1norm ((Matrix Number) -> Nonnegative-Real)) (define (matrix-1norm a) - (array-all-sum (array-magnitude a))) + (parameterize ([array-strictness #f]) + (array-all-sum (array-magnitude a)))) (: matrix-2norm ((Matrix Number) -> Nonnegative-Real)) (define (matrix-2norm a) - (let ([a (array-strict (array-magnitude a))]) - ;; Compute this divided by the maximum to avoid underflow and overflow - (define mx (array-all-max a)) - (cond [(and (rational? mx) (positive? mx)) - (* mx (sqrt (array-all-sum - (inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))] - [else mx]))) + (parameterize ([array-strictness #f]) + (let ([a (array-strict (array-magnitude a))]) + ;; Compute this divided by the maximum to avoid underflow and overflow + (define mx (array-all-max a)) + (cond [(and (rational? mx) (positive? mx)) + (* mx (sqrt (array-all-sum + (inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))] + [else mx])))) (: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real)) (define (matrix-inf-norm a) - (array-all-max (array-magnitude a))) + (parameterize ([array-strictness #f]) + (array-all-max (array-magnitude a)))) (: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real)) (define (matrix-p-norm a p) - (let ([a (array-strict (array-magnitude a))]) - ;; Compute this divided by the maximum to avoid underflow and overflow - (define mx (array-all-max a)) - (cond [(and (rational? mx) (positive? mx)) - (assert - (* mx (expt (array-all-sum - (inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a)) - (/ p))) - (make-predicate Nonnegative-Real))] - [else mx]))) + (parameterize ([array-strictness #f]) + (let ([a (array-strict (array-magnitude a))]) + ;; Compute this divided by the maximum to avoid underflow and overflow + (define mx (array-all-max a)) + (cond [(and (rational? mx) (positive? mx)) + (assert + (* mx (expt (array-all-sum + (inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a)) + (/ p))) + (make-predicate Nonnegative-Real))] + [else mx])))) (: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) Real -> Nonnegative-Real))) @@ -224,21 +237,23 @@ (define matrix-dot (case-lambda [(a) - (assert - (array-all-sum - (inline-array-map - (λ (x) (* x (conjugate x))) - (ensure-matrix 'matrix-dot a))) - (make-predicate Nonnegative-Real))] + (parameterize ([array-strictness #f]) + (assert + (array-all-sum + (inline-array-map + (λ (x) (* x (conjugate x))) + (ensure-matrix 'matrix-dot a))) + (make-predicate Nonnegative-Real)))] [(a b) (define-values (m n) (matrix-shapes 'matrix-dot a b)) (define aproc (unsafe-array-proc a)) (define bproc (unsafe-array-proc b)) - (array-all-sum - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) - (* (aproc js) (conjugate (bproc js))))))])) + (parameterize ([array-strictness #f]) + (array-all-sum + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (* (aproc js) (conjugate (bproc js)))))))])) (: matrix-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) ((Matrix Number) (Matrix Number) -> Number))) @@ -283,12 +298,15 @@ (: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real)) ((Matrix Number) -> (Matrix Number)))) (define (matrix-hermitian a) - (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)) + (array-default-strict + (parameterize ([array-strictness #f]) + (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)))) (: matrix-trace (case-> ((Matrix Real) -> Real) ((Matrix Number) -> Number))) (define (matrix-trace a) - (array-all-sum (matrix-diagonal a))) + (parameterize ([array-strictness #f]) + (array-all-sum (matrix-diagonal a)))) ;; =================================================================================================== ;; Row/column operations @@ -382,11 +400,13 @@ ((Matrix Number) Real -> Boolean))) (define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)]) (cond [(negative? eps) (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)] - [else (pairwise-orthogonal? (matrix-rows M) eps)])) + [else (parameterize ([array-strictness #f]) + (pairwise-orthogonal? (matrix-rows M) eps))])) (: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean) ((Matrix Number) Real -> Boolean))) (define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)]) (cond [(negative? eps) (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)] - [else (pairwise-orthogonal? (matrix-cols M) eps)])) + [else (parameterize ([array-strictness #f]) + (pairwise-orthogonal? (matrix-cols M) eps))])) diff --git a/collects/math/private/matrix/matrix-constructors.rkt b/collects/math/private/matrix/matrix-constructors.rkt index dd4f4f15e9..a666230e86 100644 --- a/collects/math/private/matrix/matrix-constructors.rkt +++ b/collects/math/private/matrix/matrix-constructors.rkt @@ -36,11 +36,12 @@ [(or (not (index? n)) (= n 0)) (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] [else - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) - (proc (unsafe-vector-ref js 0) - (unsafe-vector-ref js 1))))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (proc (unsafe-vector-ref js 0) + (unsafe-vector-ref js 1)))))])) ;; =================================================================================================== ;; Diagonal matrices @@ -52,7 +53,7 @@ [else (define vs (list->vector xs)) (define m (vector-length vs)) - (unsafe-build-array + (unsafe-build-simple-array ((inst vector Index) m m) (λ: ([js : Indexes]) (define i (unsafe-vector-ref js 0)) @@ -97,24 +98,25 @@ (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as)) - (unsafe-build-array - ((inst vector Index) res-m res-n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (define v (unsafe-vector-ref vs i)) - (cond [(fx= v (vector-ref hs j)) - (define proc (unsafe-vector-ref procs v)) - (define iv (unsafe-vector-ref is i)) - (define jv (unsafe-vector-ref js j)) - (unsafe-vector-set! ij 0 iv) - (unsafe-vector-set! ij 1 jv) - (define res (proc ij)) - (unsafe-vector-set! ij 0 i) - (unsafe-vector-set! ij 1 j) - res] - [else - zero])))) + (array-default-strict + (unsafe-build-array + ((inst vector Index) res-m res-n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (define v (unsafe-vector-ref vs i)) + (cond [(fx= v (vector-ref hs j)) + (define proc (unsafe-vector-ref procs v)) + (define iv (unsafe-vector-ref is i)) + (define jv (unsafe-vector-ref js j)) + (unsafe-vector-set! ij 0 iv) + (unsafe-vector-set! ij 1 jv) + (define res (proc ij)) + (unsafe-vector-set! ij 0 i) + (unsafe-vector-set! ij 1 j) + res] + [else + zero]))))) (: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A)))) (define (block-diagonal-matrix/zero as zero) diff --git a/collects/math/private/matrix/matrix-conversion.rkt b/collects/math/private/matrix/matrix-conversion.rkt index a65ff6aaa8..283a8b0ff8 100644 --- a/collects/math/private/matrix/matrix-conversion.rkt +++ b/collects/math/private/matrix/matrix-conversion.rkt @@ -88,17 +88,19 @@ [(= num-ones dims) (define: js : (Vectorof Index) (make-vector dims 0)) (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 1) - (λ: ([ij : Indexes]) (proc js)))] + (array-default-strict + (unsafe-build-array ((inst vector Index) 1 1) + (λ: ([ij : Indexes]) (proc js))))] [(= num-ones (- dims 1)) (define-values (k m) (find-nontrivial-axis ds)) (define js (make-thread-local-indexes dims)) (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) m 1) - (λ: ([ij : Indexes]) - (let ([js (js)]) - (unsafe-vector-set! js k (unsafe-vector-ref ij 0)) - (proc js))))] + (array-default-strict + (unsafe-build-array ((inst vector Index) m 1) + (λ: ([ij : Indexes]) + (let ([js (js)]) + (unsafe-vector-set! js k (unsafe-vector-ref ij 0)) + (proc js)))))] [else (fail)])) (: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A)))) @@ -157,7 +159,8 @@ (: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A)))) (define (matrix->list* a) - (cond [(matrix? a) (array->list (array->list-array a 1))] + (cond [(matrix? a) (parameterize ([array-strictness #f]) + (array->list (array->list-array a 1)))] [else (raise-argument-error 'matrix->list* "matrix?" a)])) (: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A))) @@ -171,5 +174,6 @@ (: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A)))) (define (matrix->vector* a) - (cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))] + (cond [(matrix? a) (parameterize ([array-strictness #f]) + (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector)))] [else (raise-argument-error 'matrix->vector* "matrix?" a)])) diff --git a/collects/math/private/matrix/matrix-expt.rkt b/collects/math/private/matrix/matrix-expt.rkt index a3ed5752f6..42491c1ec9 100644 --- a/collects/math/private/matrix/matrix-expt.rkt +++ b/collects/math/private/matrix/matrix-expt.rkt @@ -2,19 +2,30 @@ (require "matrix-types.rkt" "matrix-constructors.rkt" - "matrix-arithmetic.rkt") + "matrix-arithmetic.rkt" + "utils.rkt") (provide matrix-expt) -(: matrix-expt : (Matrix Number) Integer -> (Matrix Number)) +(: matrix-expt/ns (case-> ((Matrix Real) Positive-Integer -> (Matrix Real)) + ((Matrix Number) Positive-Integer -> (Matrix Number)))) +(define (matrix-expt/ns a n) + (define n/2 (quotient n 2)) + (if (zero? n/2) + ;; n = 1 + a + (let ([m (* n/2 2)]) + (if (= n m) + ;; n is even + (let ([a^n/2 (matrix-expt/ns a n/2)]) + (matrix* a^n/2 a^n/2)) + ;; m = n - 1 + (matrix* a (matrix-expt/ns a m)))))) + +(: matrix-expt (case-> ((Matrix Real) Integer -> (Matrix Real)) + ((Matrix Number) Integer -> (Matrix Number)))) (define (matrix-expt a n) (cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)] [(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)] - [(zero? n) (identity-matrix (square-matrix-size a))] - [else - (let: loop : (Matrix Number) ([n : Positive-Integer n]) - (cond [(= n 1) a] - [(= n 2) (matrix* a a)] - [(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))]) - (matrix* a^n/2 a^n/2))] - [else (matrix* a (matrix-expt a (sub1 n)))]))])) + [(zero? n) (identity-matrix (square-matrix-size a))] + [else (call/ns (λ () (matrix-expt/ns a n)))])) diff --git a/collects/math/private/matrix/matrix-gram-schmidt.rkt b/collects/math/private/matrix/matrix-gram-schmidt.rkt index 741a8161de..5905b73b83 100644 --- a/collects/math/private/matrix/matrix-gram-schmidt.rkt +++ b/collects/math/private/matrix/matrix-gram-schmidt.rkt @@ -36,15 +36,11 @@ (vector-sub-proj! (unsafe-vector-ref rows i) row #f) (loop (fx+ i 1))))) -(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real)) - ((Matrix Real) Any -> (Array Real)) - ((Matrix Real) Any Integer -> (Array Real)) - ((Matrix Number) -> (Array Number)) - ((Matrix Number) Any -> (Array Number)) - ((Matrix Number) Any Integer -> (Array Number)))) +(: matrix-gram-schmidt/ns (case-> ((Matrix Real) Any Integer -> (Array Real)) + ((Matrix Number) Any Integer -> (Array Number)))) ;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already ;; orthogonal -(define (matrix-gram-schmidt M [normalize? #f] [start 0]) +(define (matrix-gram-schmidt/ns M normalize? start) (define rows (matrix->vector* (matrix-transpose M))) (define m (vector-length rows)) (define i (find-nonzero-vector rows)) @@ -66,9 +62,18 @@ [else (make-array (vector (matrix-num-rows M) 0) 0)])) -(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real)) - ((Matrix Number) -> (Array Number)))) -(define (matrix-basis-extension B) +(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real)) + ((Matrix Real) Any -> (Array Real)) + ((Matrix Real) Any Integer -> (Array Real)) + ((Matrix Number) -> (Array Number)) + ((Matrix Number) Any -> (Array Number)) + ((Matrix Number) Any Integer -> (Array Number)))) +(define (matrix-gram-schmidt M [normalize? #f] [start 0]) + (call/ns (λ () (matrix-gram-schmidt/ns M normalize? start)))) + +(: matrix-basis-extension/ns (case-> ((Matrix Real) -> (Array Real)) + ((Matrix Number) -> (Array Number)))) +(define (matrix-basis-extension/ns B) (define-values (m n) (matrix-shape B)) (cond [(n . < . m) (define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n)) @@ -78,3 +83,8 @@ (make-array (vector m 0) 0)] [else (raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)])) + +(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real)) + ((Matrix Number) -> (Array Number)))) +(define (matrix-basis-extension B) + (call/ns (λ () (matrix-basis-extension/ns B)))) diff --git a/collects/math/private/matrix/matrix-lu.rkt b/collects/math/private/matrix/matrix-lu.rkt index 46b5f43f5e..50ad6a7893 100644 --- a/collects/math/private/matrix/matrix-lu.rkt +++ b/collects/math/private/matrix/matrix-lu.rkt @@ -7,7 +7,8 @@ "utils.rkt" "../unsafe.rkt" "../vector/vector-mutate.rkt" - "../array/mutable-array.rkt") + "../array/mutable-array.rkt" + "../array/array-struct.rkt") (provide matrix-lu) @@ -24,8 +25,10 @@ [(M fail) (define m (square-matrix-size M)) (define rows (matrix->vector* M)) - ;; Construct L in a weird way to prove to TR that it has the right type - (define L (array->mutable-array (matrix-scale M (ann 0 Real)))) + (define L + (parameterize ([array-strictness #f]) + ;; Construct L in a weird way to prove to TR that it has the right type + (array->mutable-array (matrix-scale M (ann 0 Real))))) ;; Going to fill in the lower triangle by banging values into `ys' (define ys (mutable-array-data L)) (let loop ([#{i : Nonnegative-Fixnum} 0]) diff --git a/collects/math/private/matrix/matrix-operator-norm.rkt b/collects/math/private/matrix/matrix-operator-norm.rkt index 6e300f7255..8a34eab489 100644 --- a/collects/math/private/matrix/matrix-operator-norm.rkt +++ b/collects/math/private/matrix/matrix-operator-norm.rkt @@ -45,7 +45,8 @@ See "How to Measure Errors" in the LAPACK manual for more details: (: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real)) ;; When M is a column matrix, this is equivalent to matrix-1norm (define (matrix-op-1norm M) - (assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?)) + (parameterize ([array-strictness #f]) + (assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?))) (: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real)) ;; When M is a column matrix, this is equivalent to matrix-2norm @@ -57,7 +58,8 @@ See "How to Measure Errors" in the LAPACK manual for more details: (: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real)) ;; When M is a column matrix, this is equivalent to matrix-inf-norm (define (matrix-op-inf-norm M) - (assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?)) + (parameterize ([array-strictness #f]) + (assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?))) (: matrix-basis-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) ((Matrix Number) (Matrix Number) -> Number))) @@ -83,33 +85,35 @@ See "How to Measure Errors" in the LAPACK manual for more details: ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) -> Nonnegative-Real))) (define (matrix-absolute-error M R [norm (matrix-error-norm)]) - (define-values (m n) (matrix-shapes 'matrix-absolute-error M R)) - (array-strict! M) - (array-strict! R) - (cond [(array-all-and (inline-array-map eqv? M R)) 0] - [(and (array-all-and (inline-array-map number-rational? M)) - (array-all-and (inline-array-map number-rational? R))) - (norm (matrix- (inline-array-map inexact->exact M) - (inline-array-map inexact->exact R)))] - [else +inf.0])) + (parameterize ([array-strictness #f]) + (define-values (m n) (matrix-shapes 'matrix-absolute-error M R)) + (array-strict! M) + (array-strict! R) + (cond [(array-all-and (inline-array-map eqv? M R)) 0] + [(and (array-all-and (inline-array-map number-rational? M)) + (array-all-and (inline-array-map number-rational? R))) + (norm (matrix- (inline-array-map inexact->exact M) + (inline-array-map inexact->exact R)))] + [else +inf.0]))) (: matrix-relative-error (case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real) ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) -> Nonnegative-Real))) (define (matrix-relative-error M R [norm (matrix-error-norm)]) - (define-values (m n) (matrix-shapes 'matrix-relative-error M R)) - (array-strict! M) - (array-strict! R) - (cond [(array-all-and (inline-array-map eqv? M R)) 0] - [(and (array-all-and (inline-array-map number-rational? M)) - (array-all-and (inline-array-map number-rational? R))) - (define num (norm (matrix- M R))) - (define den (norm R)) - (cond [(and (zero? num) (zero? den)) 0] - [(zero? den) +inf.0] - [else (assert (/ num den) nonnegative?)])] - [else +inf.0])) + (parameterize ([array-strictness #f]) + (define-values (m n) (matrix-shapes 'matrix-relative-error M R)) + (array-strict! M) + (array-strict! R) + (cond [(array-all-and (inline-array-map eqv? M R)) 0] + [(and (array-all-and (inline-array-map number-rational? M)) + (array-all-and (inline-array-map number-rational? R))) + (define num (norm (matrix- M R))) + (define den (norm R)) + (cond [(and (zero? num) (zero? den)) 0] + [(zero? den) +inf.0] + [else (assert (/ num den) nonnegative?)])] + [else +inf.0]))) ;; =================================================================================================== ;; Approximate predicates diff --git a/collects/math/private/matrix/matrix-qr.rkt b/collects/math/private/matrix/matrix-qr.rkt index 530e803182..e711793d47 100644 --- a/collects/math/private/matrix/matrix-qr.rkt +++ b/collects/math/private/matrix/matrix-qr.rkt @@ -5,7 +5,8 @@ "matrix-arithmetic.rkt" "matrix-constructors.rkt" "matrix-gram-schmidt.rkt" - "../array/array-transform.rkt") + "../array/array-transform.rkt" + "../array/array-struct.rkt") (provide matrix-qr) @@ -23,11 +24,9 @@ produces matrices for which `matrix-orthogonal?' returns #t with eps <= 10*epsil independently of the matrix size. |# -(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) - ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) - ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) - ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) -(define (matrix-qr M [full? #t]) +(: matrix-qr/ns (case-> ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) +(define (matrix-qr/ns M full?) (define B (matrix-gram-schmidt M #f)) (define Q (matrix-gram-schmidt @@ -37,3 +36,13 @@ independently of the matrix size. [else (matrix-col (identity-matrix (matrix-num-rows M)) 0)]) #t)) (values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M)))) + +(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) + ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) +(define (matrix-qr M [full? #t]) + (define-values (Q R) (parameterize ([array-strictness #f]) + (matrix-qr/ns M full?))) + (values (array-default-strict Q) + (array-default-strict R))) diff --git a/collects/math/private/matrix/matrix-solve.rkt b/collects/math/private/matrix/matrix-solve.rkt index c8dd1b0fc1..f5a7d66542 100644 --- a/collects/math/private/matrix/matrix-solve.rkt +++ b/collects/math/private/matrix/matrix-solve.rkt @@ -11,7 +11,8 @@ "utils.rkt" "../vector/vector-mutate.rkt" "../array/array-indexing.rkt" - "../array/mutable-array.rkt") + "../array/mutable-array.rkt" + "../array/array-struct.rkt") (provide matrix-determinant @@ -80,7 +81,8 @@ [(M fail) (define m (square-matrix-size M)) (define I (identity-matrix m)) - (define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t)) + (define-values (IM^-1 wps) (parameterize ([array-strictness #f]) + (matrix-gauss-elim (matrix-augment (list M I)) #t #t))) (cond [(and (not (empty? wps)) (= (first wps) m)) (submatrix IM^-1 (::) (:: m #f))] [else (fail)])])) @@ -100,7 +102,8 @@ (define m (square-matrix-size M)) (define-values (s t) (matrix-shape B)) (cond [(= m s) - (define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t)) + (define-values (IX wps) (parameterize ([array-strictness #f]) + (matrix-gauss-elim (matrix-augment (list M B)) #t #t))) (cond [(and (not (empty? wps)) (= (first wps) m)) (submatrix IX (::) (:: m #f))] [else (fail)])] diff --git a/collects/math/private/matrix/matrix-subspace.rkt b/collects/math/private/matrix/matrix-subspace.rkt index f89c4c8f47..07662d9e10 100644 --- a/collects/math/private/matrix/matrix-subspace.rkt +++ b/collects/math/private/matrix/matrix-subspace.rkt @@ -7,7 +7,8 @@ "matrix-gauss-elim.rkt" "utils.rkt" "../array/array-indexing.rkt" - "../array/array-constructors.rkt") + "../array/array-constructors.rkt" + "../array/array-struct.rkt") (provide matrix-rank @@ -34,24 +35,31 @@ (cond [(= j0 j1) Bs] [else (cons (submatrix M (::) (:: j0 j1)) Bs)])) +(: matrix-col-space/ns (All (A) (case-> ((Matrix Real) -> (U #f (Matrix Real))) + ((Matrix Number) -> (U #f (Matrix Number)))))) +(define (matrix-col-space/ns M) + (define n (matrix-num-cols M)) + (define-values (_ wps) (matrix-gauss-elim M)) + (cond [(empty? wps) M] + [(= (length wps) n) #f] + [else + (define next-j (first wps)) + (define Bs (maybe-cons-submatrix M 0 next-j empty)) + (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) + (cond [(empty? wps) + (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] + [else + (define next-j (first wps)) + (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])) + (: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) ((Matrix Real) (-> A) -> (U A (Matrix Real))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) -(define matrix-col-space +(define matrix-col-space (case-lambda [(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))] [(M fail) - (define n (matrix-num-cols M)) - (define-values (_ wps) (matrix-gauss-elim M)) - (cond [(empty? wps) M] - [(= (length wps) n) (fail)] - [else - (define next-j (first wps)) - (define Bs (maybe-cons-submatrix M 0 next-j empty)) - (let loop ([#{j : Index} next-j] [wps (rest wps)] [Bs Bs]) - (cond [(empty? wps) - (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] - [else - (define next-j (first wps)) - (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])])) + (define S (parameterize ([array-strictness #f]) + (matrix-col-space/ns M))) + (if S (array-default-strict S) (fail))])) diff --git a/collects/math/private/matrix/typed-matrix-arithmetic.rkt b/collects/math/private/matrix/typed-matrix-arithmetic.rkt index f2810d559a..516d62ff54 100644 --- a/collects/math/private/matrix/typed-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/typed-matrix-arithmetic.rkt @@ -31,10 +31,11 @@ (define g0 (unsafe-array-proc arr0)) (define g1 (unsafe-array-proc arr1)) (define gs (map unsafe-array-proc arrs)) - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) - (map (λ: ([g : (Indexes -> T)]) (g js)) gs))))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) + (map (λ: ([g : (Indexes -> T)]) (g js)) gs)))))])) (: matrix=? ((Matrix Number) (Matrix Number) -> Boolean)) (define (matrix=? arr0 arr1) @@ -44,10 +45,11 @@ (= n0 n1) (let ([proc0 (unsafe-array-proc arr0)] [proc1 (unsafe-array-proc arr1)]) - (array-all-and (unsafe-build-array - ((inst vector Index) m0 n0) - (λ: ([js : Indexes]) - (= (proc0 js) (proc1 js)))))))) + (parameterize ([array-strictness #f]) + (array-all-and (unsafe-build-array + ((inst vector Index) m0 n0) + (λ: ([js : Indexes]) + (= (proc0 js) (proc1 js))))))))) (: matrix= (case-> ((Matrix Number) (Matrix Number) -> Boolean) ((Matrix Number) (Matrix Number) (Matrix Number) (Matrix Number) * -> Boolean))) @@ -62,28 +64,41 @@ [else (and (matrix=? arr1 (first arrs)) (loop (first arrs) (rest arrs)))])))])) + +(: matrix*/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(define (matrix*/ns a as) + (cond [(empty? as) a] + [else (matrix*/ns (inline-matrix-multiply a (first as)) (rest as))])) + (: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) -(define (matrix* a . as) - (let loop ([a a] [as as]) - (cond [(empty? as) a] - [else (loop (inline-matrix* a (first as)) (rest as))]))) +(define (matrix* a . as) (call/ns (λ () (matrix*/ns a as)))) + + +(: matrix+/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(define (matrix+/ns a as) + (cond [(empty? as) a] + [else (matrix+/ns (inline-matrix+ a (first as)) (rest as))])) (: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) -(define (matrix+ a . as) - (let loop ([a a] [as as]) - (cond [(empty? as) a] - [else (loop (inline-matrix+ a (first as)) (rest as))]))) +(define (matrix+ a . as) (call/ns (λ () (matrix+/ns a as)))) + + +(: matrix-/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(define (matrix-/ns a as) + (cond [(empty? as) a] + [else (matrix-/ns (inline-matrix- a (first as)) (rest as))])) (: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix- a . as) - (cond [(empty? as) (inline-matrix- a)] - [else - (let loop ([a a] [as as]) - (cond [(empty? as) a] - [else (loop (inline-matrix- a (first as)) (rest as))]))])) + (call/ns (λ () (cond [(empty? as) (inline-matrix- a)] + [else (matrix-/ns a as)])))) + (: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real)) ((Matrix Number) Number -> (Matrix Number)))) diff --git a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt index 8db143233a..326a25bfd3 100644 --- a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt @@ -1,6 +1,7 @@ #lang racket/base -(provide inline-matrix* +(provide inline-matrix-multiply + inline-matrix* inline-matrix+ inline-matrix- inline-matrix-scale @@ -41,12 +42,17 @@ (* (arr-proc js) (brr-proc js)))) 2))))) - (define-syntax (inline-matrix* stx) + (define-syntax (do-inline-matrix* stx) (syntax-case stx () [(_ arr) (syntax/loc stx arr)] [(_ arr brr crrs ...) - (syntax/loc stx (inline-matrix* (inline-matrix-multiply arr brr) crrs ...))])) + (syntax/loc stx (do-inline-matrix* (inline-matrix-multiply arr brr) crrs ...))])) + + (define-syntax-rule (inline-matrix* arr brrs ...) + (array-default-strict + (parameterize ([array-strictness #f]) + (do-inline-matrix* arr brrs ...)))) (define-syntax (inline-matrix-map stx) (syntax-case stx () @@ -55,7 +61,8 @@ (let*-values ([(arr) arr-expr] [(m n) (matrix-shape arr)] [(proc) (unsafe-array-proc arr)]) - (unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js))))))] + (array-default-strict + (unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js)))))))] [(_ f arr-expr brr-exprs ...) (with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))] [(procs ...) (generate-temporaries #'(brr-exprs ...))]) @@ -65,10 +72,11 @@ (let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)] [(proc) (unsafe-array-proc arr)] [(procs) (unsafe-array-proc brrs)] ...) - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) - (f (proc js) (procs js) ...)))))))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (f (proc js) (procs js) ...))))))))])) (define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...)) (define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...)) @@ -101,10 +109,11 @@ (define g0 (unsafe-array-proc arr0)) (define g1 (unsafe-array-proc arr1)) (define gs (map (inst unsafe-array-proc A) arrs)) - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) - (map (λ: ([g : (Indexes -> A)]) (g js)) gs))))])) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) + (map (λ: ([g : (Indexes -> A)]) (g js)) gs)))))])) ) ; module diff --git a/collects/math/private/matrix/utils.rkt b/collects/math/private/matrix/utils.rkt index 9f3b17bf33..21ca2fde64 100644 --- a/collects/math/private/matrix/utils.rkt +++ b/collects/math/private/matrix/utils.rkt @@ -1,6 +1,7 @@ #lang typed/racket/base -(require racket/string +(require racket/performance-hint + racket/string racket/fixnum "matrix-types.rkt" "../unsafe.rkt" @@ -113,3 +114,13 @@ ;; Make sure the element below the pivot is zero (unsafe-vector-set! row_l j (- x_lj x_lj)))) (loop (fx+ l 1))))) + +(begin-encourage-inline + + (: call/ns (All (A) ((-> (Matrix A)) -> (Matrix A)))) + (define (call/ns thnk) + (array-default-strict + (parameterize ([array-strictness #f]) + (thnk)))) + + ) ; begin-encourage-inline diff --git a/collects/math/tests/matrix-strictness-tests.rkt b/collects/math/tests/matrix-strictness-tests.rkt new file mode 100644 index 0000000000..5b2b61f264 --- /dev/null +++ b/collects/math/tests/matrix-strictness-tests.rkt @@ -0,0 +1,163 @@ +#lang typed/racket + +(require math/matrix + math/array + typed/rackunit) + +(: matrix-double ((Matrix Real) -> (Matrix Real))) +(define (matrix-double M) (matrix-scale M 2)) + +(define nonstrict-2x2-arr + (parameterize ([array-strictness #f]) + (build-matrix 2 2 (λ: ([i : Index] [j : Index]) (if (= i j) 1 0))))) + +(define strict-2x2-arr + (parameterize ([array-strictness #t]) + (build-matrix 2 2 (λ: ([i : Index] [j : Index]) (if (= i j) 1 0))))) + +(check-false (array-strict? nonstrict-2x2-arr)) +(check-true (array-strict? strict-2x2-arr)) + +(define (check-always) + (printf "(array-strictness) = ~v~n" (array-strictness)) + (check-true (array-strict? (matrix [[1 2] [3 4]]))) + (check-true (array-strict? (row-matrix [1 2 3 4]))) + (check-true (array-strict? (col-matrix [1 2 3 4]))) + (check-true (array-strict? (make-matrix 4 4 0))) + (check-true (array-strict? (identity-matrix 6))) + (check-true (array-strict? (diagonal-matrix '(1 2 3 4)))) + (check-true (array-strict? (list->matrix 2 2 '(1 2 3 4)))) + (check-true (array-strict? (vector->matrix 2 2 #(1 2 3 4)))) + (check-true (array-strict? (list*->matrix '((1 2) (3 4))))) + (check-true (array-strict? ((inst vector*->matrix Integer) #(#(1 2) #(3 4))))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]) + (check-true (array-strict? (matrix-row-echelon M))) + (let-values ([(L U) (matrix-lu M)]) + (check-true (array-strict? L)) + (check-true (array-strict? U)))) + ) + +(parameterize ([array-strictness #t]) + (check-always) + + (check-true (array-strict? (block-diagonal-matrix (list nonstrict-2x2-arr strict-2x2-arr)))) + (check-true (array-strict? (vandermonde-matrix '(1 2 3 4) 10))) + + (check-true (array-strict? (->col-matrix '(1 2 3 4)))) + (check-true (array-strict? (->col-matrix #(1 2 3 4)))) + (check-true (array-strict? (->col-matrix (array #[1 2 3 4])))) + (check-true (array-strict? (->col-matrix (array #[#[1 2 3 4]])))) + (check-true (array-strict? (->col-matrix (array #[#[1] #[2] #[3] #[4]])))) + + (check-true (array-strict? (->row-matrix '(1 2 3 4)))) + (check-true (array-strict? (->row-matrix #(1 2 3 4)))) + (check-true (array-strict? (->row-matrix (array #[1 2 3 4])))) + (check-true (array-strict? (->row-matrix (array #[#[1 2 3 4]])))) + (check-true (array-strict? (->row-matrix (array #[#[1] #[2] #[3] #[4]])))) + + (for*: ([M1 (list nonstrict-2x2-arr strict-2x2-arr)] + [M2 (list nonstrict-2x2-arr strict-2x2-arr)]) + (check-true (array-strict? (matrix* M1 M2))) + (check-true (array-strict? (matrix+ M1 M2))) + (check-true (array-strict? (matrix- M1 M2))) + (check-true (array-strict? (matrix-map * M1 M2))) + (check-true (array-strict? (matrix-sum (list M1 M2)))) + (check-true (array-strict? (matrix-augment (list M1 M2)))) + (check-true (array-strict? (matrix-stack (list M1 M2)))) + (check-true (array-strict? (matrix-solve M1 M2)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]) + (check-true (array-strict? (matrix-scale M -1))) + (check-true (array-strict? (matrix-expt M 0))) + (check-true (equal? (array-strict? (matrix-expt M 1)) (array-strict? M))) + (check-true (array-strict? (matrix-expt M 2))) + (check-true (array-strict? (matrix-expt M 3))) + (check-true (array-strict? (matrix-diagonal M))) + (check-true (andmap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-rows M))) + (check-true (andmap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-cols M))) + (check-true (array-strict? (matrix-map-rows matrix-double M))) + (check-true (array-strict? (matrix-map-cols matrix-double M))) + (check-true (array-strict? (matrix-conjugate M))) + (check-true (array-strict? (matrix-transpose M))) + (check-true (array-strict? (matrix-hermitian M))) + (check-true (array-strict? (matrix-normalize M))) + (check-true (array-strict? (matrix-normalize-rows M))) + (check-true (array-strict? (matrix-normalize-cols M))) + (check-true (array-strict? (matrix-inverse M))) + (check-true (array-strict? (matrix-gram-schmidt M))) + (let-values ([(Q R) (matrix-qr M)]) + (check-true (array-strict? Q)) + (check-true (array-strict? R)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)] + [i (list 0 1)]) + (check-true (array-strict? (matrix-row M i))) + (check-true (array-strict? (matrix-col M i)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)] + [spec (list '(0) 0)]) + (check-true (array-strict? (submatrix M (::) spec)))) + ) + +(parameterize ([array-strictness #f]) + (check-always) + + (check-false (array-strict? (block-diagonal-matrix (list nonstrict-2x2-arr strict-2x2-arr)))) + (check-false (array-strict? (vandermonde-matrix '(1 2 3 4) 10))) + + (check-true (array-strict? (->col-matrix '(1 2 3 4)))) + (check-true (array-strict? (->col-matrix #(1 2 3 4)))) + (check-false (array-strict? (->col-matrix (array #[1 2 3 4])))) + (check-false (array-strict? (->col-matrix (array #[#[1 2 3 4]])))) + (check-true (array-strict? (->col-matrix (array #[#[1] #[2] #[3] #[4]])))) + + (check-false (array-strict? (->row-matrix '(1 2 3 4)))) + (check-false (array-strict? (->row-matrix #(1 2 3 4)))) + (check-false (array-strict? (->row-matrix (array #[1 2 3 4])))) + (check-true (array-strict? (->row-matrix (array #[#[1 2 3 4]])))) + (check-false (array-strict? (->row-matrix (array #[#[1] #[2] #[3] #[4]])))) + + (for*: ([M1 (list nonstrict-2x2-arr strict-2x2-arr)] + [M2 (list nonstrict-2x2-arr strict-2x2-arr)]) + (check-false (array-strict? (matrix* M1 M2))) + (check-false (array-strict? (matrix+ M1 M2))) + (check-false (array-strict? (matrix- M1 M2))) + (check-false (array-strict? (matrix-map * M1 M2))) + (check-false (array-strict? (matrix-sum (list M1 M2)))) + (check-false (array-strict? (matrix-augment (list M1 M2)))) + (check-false (array-strict? (matrix-stack (list M1 M2)))) + (check-false (array-strict? (matrix-solve M1 M2)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]) + (check-false (array-strict? (matrix-scale M -1))) + (check-true (array-strict? (matrix-expt M 0))) + (check-false (array-strict? (matrix-expt (array-lazy M) 1))) + (check-false (array-strict? (matrix-expt M 2))) + (check-false (array-strict? (matrix-expt M 3))) + (check-false (array-strict? (matrix-diagonal M))) + (check-false (ormap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-rows M))) + (check-false (ormap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-cols M))) + (check-false (array-strict? (matrix-map-rows matrix-double M))) + (check-false (array-strict? (matrix-map-cols matrix-double M))) + (check-false (array-strict? (matrix-conjugate M))) + (check-false (array-strict? (matrix-transpose M))) + (check-false (array-strict? (matrix-hermitian M))) + (check-false (array-strict? (matrix-normalize M))) + (check-false (array-strict? (matrix-normalize-rows M))) + (check-false (array-strict? (matrix-normalize-cols M))) + (check-false (array-strict? (matrix-inverse M))) + (check-false (array-strict? (matrix-gram-schmidt M))) + (let-values ([(Q R) (matrix-qr M)]) + (check-false (array-strict? Q)) + (check-false (array-strict? R)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)] + [spec (list '(0) 0)]) + (check-false (array-strict? (submatrix M (::) spec)))) + + (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)] + [i (list 0 1)]) + (check-false (array-strict? (matrix-row M i))) + (check-false (array-strict? (matrix-col M i)))) + ) diff --git a/collects/math/tests/matrix-untyped-tests.rkt b/collects/math/tests/matrix-untyped-tests.rkt index 11d72324ef..2d0fcd86fd 100644 --- a/collects/math/tests/matrix-untyped-tests.rkt +++ b/collects/math/tests/matrix-untyped-tests.rkt @@ -49,3 +49,18 @@ (check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0))) (check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0))) (check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0))) + +;; =================================================================================================== +;; Arithmetic and mapping macros + +(check-equal? (matrix* (identity-matrix 4) (identity-matrix 4)) + (identity-matrix 4)) + +(check-equal? (matrix+ (identity-matrix 4) (identity-matrix 4)) + (matrix-scale (identity-matrix 4) 2)) + +(check-equal? (matrix- (identity-matrix 4) (identity-matrix 4)) + (make-matrix 4 4 0)) + +(check-equal? (matrix-map (λ (x) (* x -1)) (identity-matrix 4)) + (matrix-scale (identity-matrix 4) -1))