`math/matrix' fixes; please merge to 5.3.2.

* Fixed type of `matrix-expt'

* Made matrix functions respect `array-strictness' parameter (mostly
  wrapping functions with `parameterize' and return values with
  `array-default-strictness'; reindentation makes changes look larger)

* Added strictness tests
(cherry picked from commit f40ad2ca9d)
This commit is contained in:
Neil Toronto 2013-01-16 16:39:46 -07:00 committed by Ryan Culpepper
parent 554713f408
commit a1aa97c1fd
16 changed files with 499 additions and 206 deletions

View File

@ -6,9 +6,9 @@
"private/matrix/matrix-conversion.rkt" "private/matrix/matrix-conversion.rkt"
"private/matrix/matrix-syntax.rkt" "private/matrix/matrix-syntax.rkt"
"private/matrix/matrix-comprehension.rkt" "private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt" "private/matrix/matrix-types.rkt"
"private/matrix/matrix-2d.rkt" "private/matrix/matrix-2d.rkt"
;;"private/matrix/matrix-expt.rkt" ; all use require/untyped-contract
;;"private/matrix/matrix-gauss-elim.rkt" ; all use require/untyped-contract ;;"private/matrix/matrix-gauss-elim.rkt" ; all use require/untyped-contract
(except-in "private/matrix/matrix-solve.rkt" (except-in "private/matrix/matrix-solve.rkt"
matrix-determinant matrix-determinant
@ -36,6 +36,11 @@
;;"private/matrix/matrix-gram-schmidt.rkt" ; all use require/untyped-contract ;;"private/matrix/matrix-gram-schmidt.rkt" ; all use require/untyped-contract
) )
(require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt"))
"private/matrix/matrix-expt.rkt"
[matrix-expt ((Matrix Number) Integer -> (Matrix Number))])
(require/untyped-contract (require/untyped-contract
(begin (require "private/matrix/matrix-types.rkt" (begin (require "private/matrix/matrix-types.rkt"
"private/matrix/matrix-gauss-elim.rkt")) "private/matrix/matrix-gauss-elim.rkt"))
@ -148,9 +153,10 @@
"private/matrix/matrix-solve.rkt" "private/matrix/matrix-solve.rkt"
"private/matrix/matrix-operator-norm.rkt" "private/matrix/matrix-operator-norm.rkt"
"private/matrix/matrix-comprehension.rkt" "private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt" "private/matrix/matrix-types.rkt"
"private/matrix/matrix-2d.rkt") "private/matrix/matrix-2d.rkt")
;; matrix/matrix-expt.rkt
matrix-expt
;; matrix-gauss-elim.rkt ;; matrix-gauss-elim.rkt
matrix-gauss-elim matrix-gauss-elim
matrix-row-echelon matrix-row-echelon

View File

@ -81,13 +81,14 @@
(raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)] (raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)]
[else [else
(define proc (unsafe-array-proc a)) (define proc (unsafe-array-proc a))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) 1 n) ((inst vector Index) 1 n)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
(unsafe-vector-set! ij 0 i) (unsafe-vector-set! ij 0 i)
(define res (proc ij)) (define res (proc ij))
(unsafe-vector-set! ij 0 0) (unsafe-vector-set! ij 0 0)
res))])) res)))]))
(: matrix-col (All (A) (Matrix A) Integer -> (Matrix A))) (: matrix-col (All (A) (Matrix A) Integer -> (Matrix A)))
(define (matrix-col a j) (define (matrix-col a j)
@ -96,53 +97,61 @@
(raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)] (raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)]
[else [else
(define proc (unsafe-array-proc a)) (define proc (unsafe-array-proc a))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m 1) ((inst vector Index) m 1)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
(unsafe-vector-set! ij 1 j) (unsafe-vector-set! ij 1 j)
(define res (proc ij)) (define res (proc ij))
(unsafe-vector-set! ij 1 0) (unsafe-vector-set! ij 1 0)
res))])) res)))]))
(: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A)))) (: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-rows a) (define (matrix-rows a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)) (map (λ: ([a : (Matrix A)]) (array-default-strict a))
(parameterize ([array-strictness #f])
(array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0))))
(: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A)))) (: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-cols a) (define (matrix-cols a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)) (map (λ: ([a : (Matrix A)]) (array-default-strict a))
(parameterize ([array-strictness #f])
(array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1))))
(: matrix-diagonal (All (A) ((Matrix A) -> (Array A)))) (: matrix-diagonal (All (A) ((Matrix A) -> (Array A))))
(define (matrix-diagonal a) (define (matrix-diagonal a)
(define m (square-matrix-size a)) (define m (square-matrix-size a))
(define proc (unsafe-array-proc a)) (define proc (unsafe-array-proc a))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m) ((inst vector Index) m)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0)) (define i (unsafe-vector-ref js 0))
(proc ((inst vector Index) i i))))) (proc ((inst vector Index) i i))))))
(: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) (: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0)))))
(define (matrix-upper-triangle M) (define (matrix-upper-triangle M)
(define-values (m n) (matrix-shape M)) (define-values (m n) (matrix-shape M))
(define proc (unsafe-array-proc M)) (define proc (unsafe-array-proc M))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
(define i (unsafe-vector-ref ij 0)) (define i (unsafe-vector-ref ij 0))
(define j (unsafe-vector-ref ij 1)) (define j (unsafe-vector-ref ij 1))
(if (i . fx<= . j) (proc ij) 0)))) (if (i . fx<= . j) (proc ij) 0)))))
(: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) (: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0)))))
(define (matrix-lower-triangle M) (define (matrix-lower-triangle M)
(define-values (m n) (matrix-shape M)) (define-values (m n) (matrix-shape M))
(define proc (unsafe-array-proc M)) (define proc (unsafe-array-proc M))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
(define i (unsafe-vector-ref ij 0)) (define i (unsafe-vector-ref ij 0))
(define j (unsafe-vector-ref ij 1)) (define j (unsafe-vector-ref ij 1))
(if (i . fx>= . j) (proc ij) 0)))) (if (i . fx>= . j) (proc ij) 0)))))
;; =================================================================================================== ;; ===================================================================================================
;; Embiggenment (this is a perfectly cromulent word) ;; Embiggenment (this is a perfectly cromulent word)
@ -176,24 +185,28 @@
(: matrix-1norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-1norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-1norm a) (define (matrix-1norm a)
(array-all-sum (array-magnitude a))) (parameterize ([array-strictness #f])
(array-all-sum (array-magnitude a))))
(: matrix-2norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-2norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-2norm a) (define (matrix-2norm a)
(parameterize ([array-strictness #f])
(let ([a (array-strict (array-magnitude a))]) (let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow ;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a)) (define mx (array-all-max a))
(cond [(and (rational? mx) (positive? mx)) (cond [(and (rational? mx) (positive? mx))
(* mx (sqrt (array-all-sum (* mx (sqrt (array-all-sum
(inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))] (inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))]
[else mx]))) [else mx]))))
(: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real))
(define (matrix-inf-norm a) (define (matrix-inf-norm a)
(array-all-max (array-magnitude a))) (parameterize ([array-strictness #f])
(array-all-max (array-magnitude a))))
(: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real)) (: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real))
(define (matrix-p-norm a p) (define (matrix-p-norm a p)
(parameterize ([array-strictness #f])
(let ([a (array-strict (array-magnitude a))]) (let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow ;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a)) (define mx (array-all-max a))
@ -203,7 +216,7 @@
(inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a)) (inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a))
(/ p))) (/ p)))
(make-predicate Nonnegative-Real))] (make-predicate Nonnegative-Real))]
[else mx]))) [else mx]))))
(: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real) (: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real)
((Matrix Number) Real -> Nonnegative-Real))) ((Matrix Number) Real -> Nonnegative-Real)))
@ -224,21 +237,23 @@
(define matrix-dot (define matrix-dot
(case-lambda (case-lambda
[(a) [(a)
(parameterize ([array-strictness #f])
(assert (assert
(array-all-sum (array-all-sum
(inline-array-map (inline-array-map
(λ (x) (* x (conjugate x))) (λ (x) (* x (conjugate x)))
(ensure-matrix 'matrix-dot a))) (ensure-matrix 'matrix-dot a)))
(make-predicate Nonnegative-Real))] (make-predicate Nonnegative-Real)))]
[(a b) [(a b)
(define-values (m n) (matrix-shapes 'matrix-dot a b)) (define-values (m n) (matrix-shapes 'matrix-dot a b))
(define aproc (unsafe-array-proc a)) (define aproc (unsafe-array-proc a))
(define bproc (unsafe-array-proc b)) (define bproc (unsafe-array-proc b))
(parameterize ([array-strictness #f])
(array-all-sum (array-all-sum
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(* (aproc js) (conjugate (bproc js))))))])) (* (aproc js) (conjugate (bproc js)))))))]))
(: matrix-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) (: matrix-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real)
((Matrix Number) (Matrix Number) -> Number))) ((Matrix Number) (Matrix Number) -> Number)))
@ -283,12 +298,15 @@
(: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real)) (: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Number) -> (Matrix Number)))) ((Matrix Number) -> (Matrix Number))))
(define (matrix-hermitian a) (define (matrix-hermitian a)
(array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)) (array-default-strict
(parameterize ([array-strictness #f])
(array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1))))
(: matrix-trace (case-> ((Matrix Real) -> Real) (: matrix-trace (case-> ((Matrix Real) -> Real)
((Matrix Number) -> Number))) ((Matrix Number) -> Number)))
(define (matrix-trace a) (define (matrix-trace a)
(array-all-sum (matrix-diagonal a))) (parameterize ([array-strictness #f])
(array-all-sum (matrix-diagonal a))))
;; =================================================================================================== ;; ===================================================================================================
;; Row/column operations ;; Row/column operations
@ -382,11 +400,13 @@
((Matrix Number) Real -> Boolean))) ((Matrix Number) Real -> Boolean)))
(define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)]) (define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)])
(cond [(negative? eps) (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)] (cond [(negative? eps) (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)]
[else (pairwise-orthogonal? (matrix-rows M) eps)])) [else (parameterize ([array-strictness #f])
(pairwise-orthogonal? (matrix-rows M) eps))]))
(: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean) (: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean)
((Matrix Number) Real -> Boolean))) ((Matrix Number) Real -> Boolean)))
(define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)]) (define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)])
(cond [(negative? eps) (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)] (cond [(negative? eps) (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)]
[else (pairwise-orthogonal? (matrix-cols M) eps)])) [else (parameterize ([array-strictness #f])
(pairwise-orthogonal? (matrix-cols M) eps))]))

View File

@ -36,11 +36,12 @@
[(or (not (index? n)) (= n 0)) [(or (not (index? n)) (= n 0))
(raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
[else [else
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(proc (unsafe-vector-ref js 0) (proc (unsafe-vector-ref js 0)
(unsafe-vector-ref js 1))))])) (unsafe-vector-ref js 1)))))]))
;; =================================================================================================== ;; ===================================================================================================
;; Diagonal matrices ;; Diagonal matrices
@ -52,7 +53,7 @@
[else [else
(define vs (list->vector xs)) (define vs (list->vector xs))
(define m (vector-length vs)) (define m (vector-length vs))
(unsafe-build-array (unsafe-build-simple-array
((inst vector Index) m m) ((inst vector Index) m m)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0)) (define i (unsafe-vector-ref js 0))
@ -97,6 +98,7 @@
(vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) (vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
(define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as)) (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) res-m res-n) ((inst vector Index) res-m res-n)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
@ -114,7 +116,7 @@
(unsafe-vector-set! ij 1 j) (unsafe-vector-set! ij 1 j)
res] res]
[else [else
zero])))) zero])))))
(: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A)))) (: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A))))
(define (block-diagonal-matrix/zero as zero) (define (block-diagonal-matrix/zero as zero)

View File

@ -88,17 +88,19 @@
[(= num-ones dims) [(= num-ones dims)
(define: js : (Vectorof Index) (make-vector dims 0)) (define: js : (Vectorof Index) (make-vector dims 0))
(define proc (unsafe-array-proc arr)) (define proc (unsafe-array-proc arr))
(array-default-strict
(unsafe-build-array ((inst vector Index) 1 1) (unsafe-build-array ((inst vector Index) 1 1)
(λ: ([ij : Indexes]) (proc js)))] (λ: ([ij : Indexes]) (proc js))))]
[(= num-ones (- dims 1)) [(= num-ones (- dims 1))
(define-values (k m) (find-nontrivial-axis ds)) (define-values (k m) (find-nontrivial-axis ds))
(define js (make-thread-local-indexes dims)) (define js (make-thread-local-indexes dims))
(define proc (unsafe-array-proc arr)) (define proc (unsafe-array-proc arr))
(array-default-strict
(unsafe-build-array ((inst vector Index) m 1) (unsafe-build-array ((inst vector Index) m 1)
(λ: ([ij : Indexes]) (λ: ([ij : Indexes])
(let ([js (js)]) (let ([js (js)])
(unsafe-vector-set! js k (unsafe-vector-ref ij 0)) (unsafe-vector-set! js k (unsafe-vector-ref ij 0))
(proc js))))] (proc js)))))]
[else (fail)])) [else (fail)]))
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A)))) (: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Matrix A))))
@ -157,7 +159,8 @@
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A)))) (: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
(define (matrix->list* a) (define (matrix->list* a)
(cond [(matrix? a) (array->list (array->list-array a 1))] (cond [(matrix? a) (parameterize ([array-strictness #f])
(array->list (array->list-array a 1)))]
[else (raise-argument-error 'matrix->list* "matrix?" a)])) [else (raise-argument-error 'matrix->list* "matrix?" a)]))
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A))) (: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
@ -171,5 +174,6 @@
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A)))) (: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
(define (matrix->vector* a) (define (matrix->vector* a)
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))] (cond [(matrix? a) (parameterize ([array-strictness #f])
(array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector)))]
[else (raise-argument-error 'matrix->vector* "matrix?" a)])) [else (raise-argument-error 'matrix->vector* "matrix?" a)]))

View File

@ -2,19 +2,30 @@
(require "matrix-types.rkt" (require "matrix-types.rkt"
"matrix-constructors.rkt" "matrix-constructors.rkt"
"matrix-arithmetic.rkt") "matrix-arithmetic.rkt"
"utils.rkt")
(provide matrix-expt) (provide matrix-expt)
(: matrix-expt : (Matrix Number) Integer -> (Matrix Number)) (: matrix-expt/ns (case-> ((Matrix Real) Positive-Integer -> (Matrix Real))
((Matrix Number) Positive-Integer -> (Matrix Number))))
(define (matrix-expt/ns a n)
(define n/2 (quotient n 2))
(if (zero? n/2)
;; n = 1
a
(let ([m (* n/2 2)])
(if (= n m)
;; n is even
(let ([a^n/2 (matrix-expt/ns a n/2)])
(matrix* a^n/2 a^n/2))
;; m = n - 1
(matrix* a (matrix-expt/ns a m))))))
(: matrix-expt (case-> ((Matrix Real) Integer -> (Matrix Real))
((Matrix Number) Integer -> (Matrix Number))))
(define (matrix-expt a n) (define (matrix-expt a n)
(cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)] (cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)]
[(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)] [(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)]
[(zero? n) (identity-matrix (square-matrix-size a))] [(zero? n) (identity-matrix (square-matrix-size a))]
[else [else (call/ns (λ () (matrix-expt/ns a n)))]))
(let: loop : (Matrix Number) ([n : Positive-Integer n])
(cond [(= n 1) a]
[(= n 2) (matrix* a a)]
[(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))])
(matrix* a^n/2 a^n/2))]
[else (matrix* a (matrix-expt a (sub1 n)))]))]))

View File

@ -36,15 +36,11 @@
(vector-sub-proj! (unsafe-vector-ref rows i) row #f) (vector-sub-proj! (unsafe-vector-ref rows i) row #f)
(loop (fx+ i 1))))) (loop (fx+ i 1)))))
(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real)) (: matrix-gram-schmidt/ns (case-> ((Matrix Real) Any Integer -> (Array Real))
((Matrix Real) Any -> (Array Real))
((Matrix Real) Any Integer -> (Array Real))
((Matrix Number) -> (Array Number))
((Matrix Number) Any -> (Array Number))
((Matrix Number) Any Integer -> (Array Number)))) ((Matrix Number) Any Integer -> (Array Number))))
;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already ;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already
;; orthogonal ;; orthogonal
(define (matrix-gram-schmidt M [normalize? #f] [start 0]) (define (matrix-gram-schmidt/ns M normalize? start)
(define rows (matrix->vector* (matrix-transpose M))) (define rows (matrix->vector* (matrix-transpose M)))
(define m (vector-length rows)) (define m (vector-length rows))
(define i (find-nonzero-vector rows)) (define i (find-nonzero-vector rows))
@ -66,9 +62,18 @@
[else [else
(make-array (vector (matrix-num-rows M) 0) 0)])) (make-array (vector (matrix-num-rows M) 0) 0)]))
(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real)) (: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real))
((Matrix Real) Any -> (Array Real))
((Matrix Real) Any Integer -> (Array Real))
((Matrix Number) -> (Array Number))
((Matrix Number) Any -> (Array Number))
((Matrix Number) Any Integer -> (Array Number))))
(define (matrix-gram-schmidt M [normalize? #f] [start 0])
(call/ns (λ () (matrix-gram-schmidt/ns M normalize? start))))
(: matrix-basis-extension/ns (case-> ((Matrix Real) -> (Array Real))
((Matrix Number) -> (Array Number)))) ((Matrix Number) -> (Array Number))))
(define (matrix-basis-extension B) (define (matrix-basis-extension/ns B)
(define-values (m n) (matrix-shape B)) (define-values (m n) (matrix-shape B))
(cond [(n . < . m) (cond [(n . < . m)
(define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n)) (define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n))
@ -78,3 +83,8 @@
(make-array (vector m 0) 0)] (make-array (vector m 0) 0)]
[else [else
(raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)])) (raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)]))
(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real))
((Matrix Number) -> (Array Number))))
(define (matrix-basis-extension B)
(call/ns (λ () (matrix-basis-extension/ns B))))

View File

@ -7,7 +7,8 @@
"utils.rkt" "utils.rkt"
"../unsafe.rkt" "../unsafe.rkt"
"../vector/vector-mutate.rkt" "../vector/vector-mutate.rkt"
"../array/mutable-array.rkt") "../array/mutable-array.rkt"
"../array/array-struct.rkt")
(provide matrix-lu) (provide matrix-lu)
@ -24,8 +25,10 @@
[(M fail) [(M fail)
(define m (square-matrix-size M)) (define m (square-matrix-size M))
(define rows (matrix->vector* M)) (define rows (matrix->vector* M))
(define L
(parameterize ([array-strictness #f])
;; Construct L in a weird way to prove to TR that it has the right type ;; Construct L in a weird way to prove to TR that it has the right type
(define L (array->mutable-array (matrix-scale M (ann 0 Real)))) (array->mutable-array (matrix-scale M (ann 0 Real)))))
;; Going to fill in the lower triangle by banging values into `ys' ;; Going to fill in the lower triangle by banging values into `ys'
(define ys (mutable-array-data L)) (define ys (mutable-array-data L))
(let loop ([#{i : Nonnegative-Fixnum} 0]) (let loop ([#{i : Nonnegative-Fixnum} 0])

View File

@ -45,7 +45,8 @@ See "How to Measure Errors" in the LAPACK manual for more details:
(: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-1norm ;; When M is a column matrix, this is equivalent to matrix-1norm
(define (matrix-op-1norm M) (define (matrix-op-1norm M)
(assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?)) (parameterize ([array-strictness #f])
(assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?)))
(: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-2norm ;; When M is a column matrix, this is equivalent to matrix-2norm
@ -57,7 +58,8 @@ See "How to Measure Errors" in the LAPACK manual for more details:
(: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real)) (: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real))
;; When M is a column matrix, this is equivalent to matrix-inf-norm ;; When M is a column matrix, this is equivalent to matrix-inf-norm
(define (matrix-op-inf-norm M) (define (matrix-op-inf-norm M)
(assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?)) (parameterize ([array-strictness #f])
(assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?)))
(: matrix-basis-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) (: matrix-basis-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real)
((Matrix Number) (Matrix Number) -> Number))) ((Matrix Number) (Matrix Number) -> Number)))
@ -83,6 +85,7 @@ See "How to Measure Errors" in the LAPACK manual for more details:
((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real)
-> Nonnegative-Real))) -> Nonnegative-Real)))
(define (matrix-absolute-error M R [norm (matrix-error-norm)]) (define (matrix-absolute-error M R [norm (matrix-error-norm)])
(parameterize ([array-strictness #f])
(define-values (m n) (matrix-shapes 'matrix-absolute-error M R)) (define-values (m n) (matrix-shapes 'matrix-absolute-error M R))
(array-strict! M) (array-strict! M)
(array-strict! R) (array-strict! R)
@ -91,13 +94,14 @@ See "How to Measure Errors" in the LAPACK manual for more details:
(array-all-and (inline-array-map number-rational? R))) (array-all-and (inline-array-map number-rational? R)))
(norm (matrix- (inline-array-map inexact->exact M) (norm (matrix- (inline-array-map inexact->exact M)
(inline-array-map inexact->exact R)))] (inline-array-map inexact->exact R)))]
[else +inf.0])) [else +inf.0])))
(: matrix-relative-error (: matrix-relative-error
(case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real) (case-> ((Matrix Number) (Matrix Number) -> Nonnegative-Real)
((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) (Matrix Number) ((Matrix Number) -> Nonnegative-Real)
-> Nonnegative-Real))) -> Nonnegative-Real)))
(define (matrix-relative-error M R [norm (matrix-error-norm)]) (define (matrix-relative-error M R [norm (matrix-error-norm)])
(parameterize ([array-strictness #f])
(define-values (m n) (matrix-shapes 'matrix-relative-error M R)) (define-values (m n) (matrix-shapes 'matrix-relative-error M R))
(array-strict! M) (array-strict! M)
(array-strict! R) (array-strict! R)
@ -109,7 +113,7 @@ See "How to Measure Errors" in the LAPACK manual for more details:
(cond [(and (zero? num) (zero? den)) 0] (cond [(and (zero? num) (zero? den)) 0]
[(zero? den) +inf.0] [(zero? den) +inf.0]
[else (assert (/ num den) nonnegative?)])] [else (assert (/ num den) nonnegative?)])]
[else +inf.0])) [else +inf.0])))
;; =================================================================================================== ;; ===================================================================================================
;; Approximate predicates ;; Approximate predicates

View File

@ -5,7 +5,8 @@
"matrix-arithmetic.rkt" "matrix-arithmetic.rkt"
"matrix-constructors.rkt" "matrix-constructors.rkt"
"matrix-gram-schmidt.rkt" "matrix-gram-schmidt.rkt"
"../array/array-transform.rkt") "../array/array-transform.rkt"
"../array/array-struct.rkt")
(provide matrix-qr) (provide matrix-qr)
@ -23,11 +24,9 @@ produces matrices for which `matrix-orthogonal?' returns #t with eps <= 10*epsil
independently of the matrix size. independently of the matrix size.
|# |#
(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) (: matrix-qr/ns (case-> ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number)))))
(define (matrix-qr M [full? #t]) (define (matrix-qr/ns M full?)
(define B (matrix-gram-schmidt M #f)) (define B (matrix-gram-schmidt M #f))
(define Q (define Q
(matrix-gram-schmidt (matrix-gram-schmidt
@ -37,3 +36,13 @@ independently of the matrix size.
[else (matrix-col (identity-matrix (matrix-num-rows M)) 0)]) [else (matrix-col (identity-matrix (matrix-num-rows M)) 0)])
#t)) #t))
(values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M)))) (values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M))))
(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number)))))
(define (matrix-qr M [full? #t])
(define-values (Q R) (parameterize ([array-strictness #f])
(matrix-qr/ns M full?)))
(values (array-default-strict Q)
(array-default-strict R)))

View File

@ -11,7 +11,8 @@
"utils.rkt" "utils.rkt"
"../vector/vector-mutate.rkt" "../vector/vector-mutate.rkt"
"../array/array-indexing.rkt" "../array/array-indexing.rkt"
"../array/mutable-array.rkt") "../array/mutable-array.rkt"
"../array/array-struct.rkt")
(provide (provide
matrix-determinant matrix-determinant
@ -80,7 +81,8 @@
[(M fail) [(M fail)
(define m (square-matrix-size M)) (define m (square-matrix-size M))
(define I (identity-matrix m)) (define I (identity-matrix m))
(define-values (IM^-1 wps) (matrix-gauss-elim (matrix-augment (list M I)) #t #t)) (define-values (IM^-1 wps) (parameterize ([array-strictness #f])
(matrix-gauss-elim (matrix-augment (list M I)) #t #t)))
(cond [(and (not (empty? wps)) (= (first wps) m)) (cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IM^-1 (::) (:: m #f))] (submatrix IM^-1 (::) (:: m #f))]
[else (fail)])])) [else (fail)])]))
@ -100,7 +102,8 @@
(define m (square-matrix-size M)) (define m (square-matrix-size M))
(define-values (s t) (matrix-shape B)) (define-values (s t) (matrix-shape B))
(cond [(= m s) (cond [(= m s)
(define-values (IX wps) (matrix-gauss-elim (matrix-augment (list M B)) #t #t)) (define-values (IX wps) (parameterize ([array-strictness #f])
(matrix-gauss-elim (matrix-augment (list M B)) #t #t)))
(cond [(and (not (empty? wps)) (= (first wps) m)) (cond [(and (not (empty? wps)) (= (first wps) m))
(submatrix IX (::) (:: m #f))] (submatrix IX (::) (:: m #f))]
[else (fail)])] [else (fail)])]

View File

@ -7,7 +7,8 @@
"matrix-gauss-elim.rkt" "matrix-gauss-elim.rkt"
"utils.rkt" "utils.rkt"
"../array/array-indexing.rkt" "../array/array-indexing.rkt"
"../array/array-constructors.rkt") "../array/array-constructors.rkt"
"../array/array-struct.rkt")
(provide (provide
matrix-rank matrix-rank
@ -34,18 +35,13 @@
(cond [(= j0 j1) Bs] (cond [(= j0 j1) Bs]
[else (cons (submatrix M (::) (:: j0 j1)) Bs)])) [else (cons (submatrix M (::) (:: j0 j1)) Bs)]))
(: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) (: matrix-col-space/ns (All (A) (case-> ((Matrix Real) -> (U #f (Matrix Real)))
((Matrix Real) (-> A) -> (U A (Matrix Real))) ((Matrix Number) -> (U #f (Matrix Number))))))
((Matrix Number) -> (Matrix Number)) (define (matrix-col-space/ns M)
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-col-space
(case-lambda
[(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
[(M fail)
(define n (matrix-num-cols M)) (define n (matrix-num-cols M))
(define-values (_ wps) (matrix-gauss-elim M)) (define-values (_ wps) (matrix-gauss-elim M))
(cond [(empty? wps) M] (cond [(empty? wps) M]
[(= (length wps) n) (fail)] [(= (length wps) n) #f]
[else [else
(define next-j (first wps)) (define next-j (first wps))
(define Bs (maybe-cons-submatrix M 0 next-j empty)) (define Bs (maybe-cons-submatrix M 0 next-j empty))
@ -54,4 +50,16 @@
(matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))] (matrix-augment (reverse (maybe-cons-submatrix M (fx+ j 1) n Bs)))]
[else [else
(define next-j (first wps)) (define next-j (first wps))
(loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])])) (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))]))
(: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Real) (-> A) -> (U A (Matrix Real)))
((Matrix Number) -> (Matrix Number))
((Matrix Number) (-> A) -> (U A (Matrix Number))))))
(define matrix-col-space
(case-lambda
[(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))]
[(M fail)
(define S (parameterize ([array-strictness #f])
(matrix-col-space/ns M)))
(if S (array-default-strict S) (fail))]))

View File

@ -31,10 +31,11 @@
(define g0 (unsafe-array-proc arr0)) (define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1)) (define g1 (unsafe-array-proc arr1))
(define gs (map unsafe-array-proc arrs)) (define gs (map unsafe-array-proc arrs))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js) (λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> T)]) (g js)) gs))))])) (map (λ: ([g : (Indexes -> T)]) (g js)) gs)))))]))
(: matrix=? ((Matrix Number) (Matrix Number) -> Boolean)) (: matrix=? ((Matrix Number) (Matrix Number) -> Boolean))
(define (matrix=? arr0 arr1) (define (matrix=? arr0 arr1)
@ -44,10 +45,11 @@
(= n0 n1) (= n0 n1)
(let ([proc0 (unsafe-array-proc arr0)] (let ([proc0 (unsafe-array-proc arr0)]
[proc1 (unsafe-array-proc arr1)]) [proc1 (unsafe-array-proc arr1)])
(parameterize ([array-strictness #f])
(array-all-and (unsafe-build-array (array-all-and (unsafe-build-array
((inst vector Index) m0 n0) ((inst vector Index) m0 n0)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(= (proc0 js) (proc1 js)))))))) (= (proc0 js) (proc1 js)))))))))
(: matrix= (case-> ((Matrix Number) (Matrix Number) -> Boolean) (: matrix= (case-> ((Matrix Number) (Matrix Number) -> Boolean)
((Matrix Number) (Matrix Number) (Matrix Number) (Matrix Number) * -> Boolean))) ((Matrix Number) (Matrix Number) (Matrix Number) (Matrix Number) * -> Boolean)))
@ -62,28 +64,41 @@
[else (and (matrix=? arr1 (first arrs)) [else (and (matrix=? arr1 (first arrs))
(loop (first arrs) (rest arrs)))])))])) (loop (first arrs) (rest arrs)))])))]))
(: matrix*/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real))
((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))))
(define (matrix*/ns a as)
(cond [(empty? as) a]
[else (matrix*/ns (inline-matrix-multiply a (first as)) (rest as))]))
(: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) (: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number)))) ((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix* a . as) (define (matrix* a . as) (call/ns (λ () (matrix*/ns a as))))
(let loop ([a a] [as as])
(: matrix+/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real))
((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))))
(define (matrix+/ns a as)
(cond [(empty? as) a] (cond [(empty? as) a]
[else (loop (inline-matrix* a (first as)) (rest as))]))) [else (matrix+/ns (inline-matrix+ a (first as)) (rest as))]))
(: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) (: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number)))) ((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix+ a . as) (define (matrix+ a . as) (call/ns (λ () (matrix+/ns a as))))
(let loop ([a a] [as as])
(: matrix-/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real))
((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))))
(define (matrix-/ns a as)
(cond [(empty? as) a] (cond [(empty? as) a]
[else (loop (inline-matrix+ a (first as)) (rest as))]))) [else (matrix-/ns (inline-matrix- a (first as)) (rest as))]))
(: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) (: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real))
((Matrix Number) (Matrix Number) * -> (Matrix Number)))) ((Matrix Number) (Matrix Number) * -> (Matrix Number))))
(define (matrix- a . as) (define (matrix- a . as)
(cond [(empty? as) (inline-matrix- a)] (call/ns (λ () (cond [(empty? as) (inline-matrix- a)]
[else [else (matrix-/ns a as)]))))
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix- a (first as)) (rest as))]))]))
(: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real)) (: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real))
((Matrix Number) Number -> (Matrix Number)))) ((Matrix Number) Number -> (Matrix Number))))

View File

@ -1,6 +1,7 @@
#lang racket/base #lang racket/base
(provide inline-matrix* (provide inline-matrix-multiply
inline-matrix*
inline-matrix+ inline-matrix+
inline-matrix- inline-matrix-
inline-matrix-scale inline-matrix-scale
@ -41,12 +42,17 @@
(* (arr-proc js) (brr-proc js)))) (* (arr-proc js) (brr-proc js))))
2))))) 2)))))
(define-syntax (inline-matrix* stx) (define-syntax (do-inline-matrix* stx)
(syntax-case stx () (syntax-case stx ()
[(_ arr) [(_ arr)
(syntax/loc stx arr)] (syntax/loc stx arr)]
[(_ arr brr crrs ...) [(_ arr brr crrs ...)
(syntax/loc stx (inline-matrix* (inline-matrix-multiply arr brr) crrs ...))])) (syntax/loc stx (do-inline-matrix* (inline-matrix-multiply arr brr) crrs ...))]))
(define-syntax-rule (inline-matrix* arr brrs ...)
(array-default-strict
(parameterize ([array-strictness #f])
(do-inline-matrix* arr brrs ...))))
(define-syntax (inline-matrix-map stx) (define-syntax (inline-matrix-map stx)
(syntax-case stx () (syntax-case stx ()
@ -55,7 +61,8 @@
(let*-values ([(arr) arr-expr] (let*-values ([(arr) arr-expr]
[(m n) (matrix-shape arr)] [(m n) (matrix-shape arr)]
[(proc) (unsafe-array-proc arr)]) [(proc) (unsafe-array-proc arr)])
(unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js))))))] (array-default-strict
(unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js)))))))]
[(_ f arr-expr brr-exprs ...) [(_ f arr-expr brr-exprs ...)
(with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))] (with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))]
[(procs ...) (generate-temporaries #'(brr-exprs ...))]) [(procs ...) (generate-temporaries #'(brr-exprs ...))])
@ -65,10 +72,11 @@
(let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)] (let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)]
[(proc) (unsafe-array-proc arr)] [(proc) (unsafe-array-proc arr)]
[(procs) (unsafe-array-proc brrs)] ...) [(procs) (unsafe-array-proc brrs)] ...)
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([js : Indexes]) (λ: ([js : Indexes])
(f (proc js) (procs js) ...)))))))])) (f (proc js) (procs js) ...))))))))]))
(define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...)) (define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...))
(define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...)) (define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...))
@ -101,10 +109,11 @@
(define g0 (unsafe-array-proc arr0)) (define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1)) (define g1 (unsafe-array-proc arr1))
(define gs (map (inst unsafe-array-proc A) arrs)) (define gs (map (inst unsafe-array-proc A) arrs))
(array-default-strict
(unsafe-build-array (unsafe-build-array
((inst vector Index) m n) ((inst vector Index) m n)
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js) (λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> A)]) (g js)) gs))))])) (map (λ: ([g : (Indexes -> A)]) (g js)) gs)))))]))
) ; module ) ; module

View File

@ -1,6 +1,7 @@
#lang typed/racket/base #lang typed/racket/base
(require racket/string (require racket/performance-hint
racket/string
racket/fixnum racket/fixnum
"matrix-types.rkt" "matrix-types.rkt"
"../unsafe.rkt" "../unsafe.rkt"
@ -113,3 +114,13 @@
;; Make sure the element below the pivot is zero ;; Make sure the element below the pivot is zero
(unsafe-vector-set! row_l j (- x_lj x_lj)))) (unsafe-vector-set! row_l j (- x_lj x_lj))))
(loop (fx+ l 1))))) (loop (fx+ l 1)))))
(begin-encourage-inline
(: call/ns (All (A) ((-> (Matrix A)) -> (Matrix A))))
(define (call/ns thnk)
(array-default-strict
(parameterize ([array-strictness #f])
(thnk))))
) ; begin-encourage-inline

View File

@ -0,0 +1,163 @@
#lang typed/racket
(require math/matrix
math/array
typed/rackunit)
(: matrix-double ((Matrix Real) -> (Matrix Real)))
(define (matrix-double M) (matrix-scale M 2))
(define nonstrict-2x2-arr
(parameterize ([array-strictness #f])
(build-matrix 2 2 (λ: ([i : Index] [j : Index]) (if (= i j) 1 0)))))
(define strict-2x2-arr
(parameterize ([array-strictness #t])
(build-matrix 2 2 (λ: ([i : Index] [j : Index]) (if (= i j) 1 0)))))
(check-false (array-strict? nonstrict-2x2-arr))
(check-true (array-strict? strict-2x2-arr))
(define (check-always)
(printf "(array-strictness) = ~v~n" (array-strictness))
(check-true (array-strict? (matrix [[1 2] [3 4]])))
(check-true (array-strict? (row-matrix [1 2 3 4])))
(check-true (array-strict? (col-matrix [1 2 3 4])))
(check-true (array-strict? (make-matrix 4 4 0)))
(check-true (array-strict? (identity-matrix 6)))
(check-true (array-strict? (diagonal-matrix '(1 2 3 4))))
(check-true (array-strict? (list->matrix 2 2 '(1 2 3 4))))
(check-true (array-strict? (vector->matrix 2 2 #(1 2 3 4))))
(check-true (array-strict? (list*->matrix '((1 2) (3 4)))))
(check-true (array-strict? ((inst vector*->matrix Integer) #(#(1 2) #(3 4)))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (matrix-row-echelon M)))
(let-values ([(L U) (matrix-lu M)])
(check-true (array-strict? L))
(check-true (array-strict? U))))
)
(parameterize ([array-strictness #t])
(check-always)
(check-true (array-strict? (block-diagonal-matrix (list nonstrict-2x2-arr strict-2x2-arr))))
(check-true (array-strict? (vandermonde-matrix '(1 2 3 4) 10)))
(check-true (array-strict? (->col-matrix '(1 2 3 4))))
(check-true (array-strict? (->col-matrix #(1 2 3 4))))
(check-true (array-strict? (->col-matrix (array #[1 2 3 4]))))
(check-true (array-strict? (->col-matrix (array #[#[1 2 3 4]]))))
(check-true (array-strict? (->col-matrix (array #[#[1] #[2] #[3] #[4]]))))
(check-true (array-strict? (->row-matrix '(1 2 3 4))))
(check-true (array-strict? (->row-matrix #(1 2 3 4))))
(check-true (array-strict? (->row-matrix (array #[1 2 3 4]))))
(check-true (array-strict? (->row-matrix (array #[#[1 2 3 4]]))))
(check-true (array-strict? (->row-matrix (array #[#[1] #[2] #[3] #[4]]))))
(for*: ([M1 (list nonstrict-2x2-arr strict-2x2-arr)]
[M2 (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (matrix* M1 M2)))
(check-true (array-strict? (matrix+ M1 M2)))
(check-true (array-strict? (matrix- M1 M2)))
(check-true (array-strict? (matrix-map * M1 M2)))
(check-true (array-strict? (matrix-sum (list M1 M2))))
(check-true (array-strict? (matrix-augment (list M1 M2))))
(check-true (array-strict? (matrix-stack (list M1 M2))))
(check-true (array-strict? (matrix-solve M1 M2))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (matrix-scale M -1)))
(check-true (array-strict? (matrix-expt M 0)))
(check-true (equal? (array-strict? (matrix-expt M 1)) (array-strict? M)))
(check-true (array-strict? (matrix-expt M 2)))
(check-true (array-strict? (matrix-expt M 3)))
(check-true (array-strict? (matrix-diagonal M)))
(check-true (andmap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-rows M)))
(check-true (andmap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-cols M)))
(check-true (array-strict? (matrix-map-rows matrix-double M)))
(check-true (array-strict? (matrix-map-cols matrix-double M)))
(check-true (array-strict? (matrix-conjugate M)))
(check-true (array-strict? (matrix-transpose M)))
(check-true (array-strict? (matrix-hermitian M)))
(check-true (array-strict? (matrix-normalize M)))
(check-true (array-strict? (matrix-normalize-rows M)))
(check-true (array-strict? (matrix-normalize-cols M)))
(check-true (array-strict? (matrix-inverse M)))
(check-true (array-strict? (matrix-gram-schmidt M)))
(let-values ([(Q R) (matrix-qr M)])
(check-true (array-strict? Q))
(check-true (array-strict? R))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]
[i (list 0 1)])
(check-true (array-strict? (matrix-row M i)))
(check-true (array-strict? (matrix-col M i))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]
[spec (list '(0) 0)])
(check-true (array-strict? (submatrix M (::) spec))))
)
(parameterize ([array-strictness #f])
(check-always)
(check-false (array-strict? (block-diagonal-matrix (list nonstrict-2x2-arr strict-2x2-arr))))
(check-false (array-strict? (vandermonde-matrix '(1 2 3 4) 10)))
(check-true (array-strict? (->col-matrix '(1 2 3 4))))
(check-true (array-strict? (->col-matrix #(1 2 3 4))))
(check-false (array-strict? (->col-matrix (array #[1 2 3 4]))))
(check-false (array-strict? (->col-matrix (array #[#[1 2 3 4]]))))
(check-true (array-strict? (->col-matrix (array #[#[1] #[2] #[3] #[4]]))))
(check-false (array-strict? (->row-matrix '(1 2 3 4))))
(check-false (array-strict? (->row-matrix #(1 2 3 4))))
(check-false (array-strict? (->row-matrix (array #[1 2 3 4]))))
(check-true (array-strict? (->row-matrix (array #[#[1 2 3 4]]))))
(check-false (array-strict? (->row-matrix (array #[#[1] #[2] #[3] #[4]]))))
(for*: ([M1 (list nonstrict-2x2-arr strict-2x2-arr)]
[M2 (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (matrix* M1 M2)))
(check-false (array-strict? (matrix+ M1 M2)))
(check-false (array-strict? (matrix- M1 M2)))
(check-false (array-strict? (matrix-map * M1 M2)))
(check-false (array-strict? (matrix-sum (list M1 M2))))
(check-false (array-strict? (matrix-augment (list M1 M2))))
(check-false (array-strict? (matrix-stack (list M1 M2))))
(check-false (array-strict? (matrix-solve M1 M2))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (matrix-scale M -1)))
(check-true (array-strict? (matrix-expt M 0)))
(check-false (array-strict? (matrix-expt (array-lazy M) 1)))
(check-false (array-strict? (matrix-expt M 2)))
(check-false (array-strict? (matrix-expt M 3)))
(check-false (array-strict? (matrix-diagonal M)))
(check-false (ormap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-rows M)))
(check-false (ormap (λ: ([M : (Matrix Real)]) (array-strict? M)) (matrix-cols M)))
(check-false (array-strict? (matrix-map-rows matrix-double M)))
(check-false (array-strict? (matrix-map-cols matrix-double M)))
(check-false (array-strict? (matrix-conjugate M)))
(check-false (array-strict? (matrix-transpose M)))
(check-false (array-strict? (matrix-hermitian M)))
(check-false (array-strict? (matrix-normalize M)))
(check-false (array-strict? (matrix-normalize-rows M)))
(check-false (array-strict? (matrix-normalize-cols M)))
(check-false (array-strict? (matrix-inverse M)))
(check-false (array-strict? (matrix-gram-schmidt M)))
(let-values ([(Q R) (matrix-qr M)])
(check-false (array-strict? Q))
(check-false (array-strict? R))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]
[spec (list '(0) 0)])
(check-false (array-strict? (submatrix M (::) spec))))
(for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]
[i (list 0 1)])
(check-false (array-strict? (matrix-row M i)))
(check-false (array-strict? (matrix-col M i))))
)

View File

@ -49,3 +49,18 @@
(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0))) (check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0))) (check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0)))
(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0))) (check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))
;; ===================================================================================================
;; Arithmetic and mapping macros
(check-equal? (matrix* (identity-matrix 4) (identity-matrix 4))
(identity-matrix 4))
(check-equal? (matrix+ (identity-matrix 4) (identity-matrix 4))
(matrix-scale (identity-matrix 4) 2))
(check-equal? (matrix- (identity-matrix 4) (identity-matrix 4))
(make-matrix 4 4 0))
(check-equal? (matrix-map (λ (x) (* x -1)) (identity-matrix 4))
(matrix-scale (identity-matrix 4) -1))