From f42cc6f14aec8a46f9f5c3ff801cd6570a2d55b7 Mon Sep 17 00:00:00 2001 From: Neil Toronto Date: Mon, 21 Jan 2013 21:55:22 -0700 Subject: [PATCH] Fixed major performance issue with matrix arithmetic; please merge to 5.3.2 The fix consists of three parts: 1. Rewriting `inline-matrix*'. The material change here is that the expansion now contains only direct applications of `+' and `*'. TR's optimizer replaces them with `unsafe-fx+' and `unsafe-fx*', which keeps intermediate flonum values from being boxed. 2. Making the types of all functions that operate on (Matrix Number) values more precise. Now TR can prove that matrix operations preserve inexactness. For example, matrix-conjugate : (Matrix Flonum) -> (Matrix Flonum) and three other cases for Real, Float-Complex, and Number. 3. Changing the return types of some functions that used to return things like (Matrix (U A 0)). Now that we worry about preserving inexactness, we can't have `matrix-upper-triangle' always return a matrix that contains exact zeros. It now accepts an optional `zero' argument of type A. --- collects/math/matrix.rkt | 22 +++ collects/math/private/matrix/matrix-basic.rkt | 173 ++++++++++++------ .../private/matrix/matrix-constructors.rkt | 49 +++-- collects/math/private/matrix/matrix-expt.rkt | 22 ++- .../math/private/matrix/matrix-gauss-elim.rkt | 24 ++- .../private/matrix/matrix-gram-schmidt.rkt | 40 +++- collects/math/private/matrix/matrix-lu.rkt | 18 +- .../private/matrix/matrix-operator-norm.rkt | 23 ++- collects/math/private/matrix/matrix-qr.rkt | 23 ++- collects/math/private/matrix/matrix-solve.rkt | 38 ++-- .../math/private/matrix/matrix-subspace.rkt | 16 +- .../matrix/typed-matrix-arithmetic.rkt | 43 +++-- .../matrix/untyped-matrix-arithmetic.rkt | 92 +++++++--- collects/math/private/matrix/utils.rkt | 66 ++++++- .../math/private/vector/vector-mutate.rkt | 98 +++++++--- collects/math/scribblings/math-matrix.scrbl | 76 +++++--- .../math/tests/matrix-strictness-tests.rkt | 2 +- 17 files changed, 606 insertions(+), 219 deletions(-) diff --git a/collects/math/matrix.rkt b/collects/math/matrix.rkt index b57d64f3e1..f042208d78 100644 --- a/collects/math/matrix.rkt +++ b/collects/math/matrix.rkt @@ -17,6 +17,10 @@ (except-in "private/matrix/matrix-constructors.rkt" vandermonde-matrix) (except-in "private/matrix/matrix-basic.rkt" + matrix-1norm + matrix-2norm + matrix-inf-norm + matrix-norm matrix-dot matrix-cos-angle matrix-angle @@ -29,6 +33,9 @@ (except-in "private/matrix/matrix-subspace.rkt" matrix-col-space) (except-in "private/matrix/matrix-operator-norm.rkt" + matrix-op-1norm + matrix-op-2norm + matrix-op-inf-norm matrix-basis-cos-angle matrix-basis-angle) ;;"private/matrix/matrix-qr.rkt" ; all use require/untyped-contract @@ -77,6 +84,11 @@ (require/untyped-contract (begin (require "private/matrix/matrix-types.rkt")) "private/matrix/matrix-basic.rkt" + [matrix-1norm ((Matrix Number) -> Nonnegative-Real)] + [matrix-2norm ((Matrix Number) -> Nonnegative-Real)] + [matrix-inf-norm ((Matrix Number) -> Nonnegative-Real)] + [matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real) + ((Matrix Number) Real -> Nonnegative-Real))] [matrix-dot (case-> ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) (Matrix Number) -> Number))] @@ -113,6 +125,9 @@ (require/untyped-contract (begin (require "private/matrix/matrix-types.rkt")) "private/matrix/matrix-operator-norm.rkt" + [matrix-op-1norm ((Matrix Number) -> Nonnegative-Real)] + [matrix-op-2norm ((Matrix Number) -> Nonnegative-Real)] + [matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real)] [matrix-basis-cos-angle ((Matrix Number) (Matrix Number) -> Number)] [matrix-basis-angle @@ -167,6 +182,10 @@ ;; matrix-constructors.rkt vandermonde-matrix ;; matrix-basic.rkt + matrix-1norm + matrix-2norm + matrix-inf-norm + matrix-norm matrix-dot matrix-cos-angle matrix-angle @@ -179,6 +198,9 @@ ;; matrix-subspace.rkt matrix-col-space ;; matrix-operator-norm.rkt + matrix-op-1norm + matrix-op-2norm + matrix-op-inf-norm matrix-basis-cos-angle matrix-basis-angle ;; matrix-qr.rkt diff --git a/collects/math/private/matrix/matrix-basic.rkt b/collects/math/private/matrix/matrix-basic.rkt index 2556544c96..0bf2993886 100644 --- a/collects/math/private/matrix/matrix-basic.rkt +++ b/collects/math/private/matrix/matrix-basic.rkt @@ -130,29 +130,37 @@ (define i (unsafe-vector-ref js 0)) (proc ((inst vector Index) i i)))))) -(: matrix-upper-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) -(define (matrix-upper-triangle M) - (define-values (m n) (matrix-shape M)) - (define proc (unsafe-array-proc M)) - (array-default-strict - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (if (i . fx<= . j) (proc ij) 0))))) +(: matrix-upper-triangle (All (A) (case-> ((Matrix A) -> (Matrix (U A 0))) + ((Matrix A) A -> (Matrix A))))) +(define matrix-upper-triangle + (case-lambda + [(M) (matrix-upper-triangle M 0)] + [(M zero) + (define-values (m n) (matrix-shape M)) + (define proc (unsafe-array-proc M)) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx<= . j) (proc ij) zero))))])) -(: matrix-lower-triangle (All (A) ((Matrix A) -> (Matrix (U A 0))))) -(define (matrix-lower-triangle M) - (define-values (m n) (matrix-shape M)) - (define proc (unsafe-array-proc M)) - (array-default-strict - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (if (i . fx>= . j) (proc ij) 0))))) +(: matrix-lower-triangle (All (A) (case-> ((Matrix A) -> (Matrix (U A 0))) + ((Matrix A) A -> (Matrix A))))) +(define matrix-lower-triangle + (case-lambda + [(M) (matrix-lower-triangle M 0)] + [(M zero) + (define-values (m n) (matrix-shape M)) + (define proc (unsafe-array-proc M)) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (if (i . fx>= . j) (proc ij) zero))))])) ;; =================================================================================================== ;; Embiggenment (this is a perfectly cromulent word) @@ -184,42 +192,63 @@ ;; =================================================================================================== ;; Inner product space (entrywise norm) -(: matrix-1norm ((Matrix Number) -> Nonnegative-Real)) -(define (matrix-1norm a) - (parameterize ([array-strictness #f]) - (array-all-sum (array-magnitude a)))) +(: nonstupid-magnitude (case-> (Flonum -> Nonnegative-Flonum) + (Real -> Nonnegative-Real) + (Float-Complex -> Nonnegative-Flonum) + (Number -> Nonnegative-Real))) +(define (nonstupid-magnitude x) + (if (real? x) (abs x) (magnitude x))) -(: matrix-2norm ((Matrix Number) -> Nonnegative-Real)) -(define (matrix-2norm a) +(: matrix-1norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) +(define (matrix-1norm M) (parameterize ([array-strictness #f]) - (let ([a (array-strict (array-magnitude a))]) + (array-all-sum (inline-array-map nonstupid-magnitude M)))) + +(: matrix-2norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) +(define (matrix-2norm M) + (parameterize ([array-strictness #f]) + (let ([M (array-strict (inline-array-map nonstupid-magnitude M))]) ;; Compute this divided by the maximum to avoid underflow and overflow - (define mx (array-all-max a)) + (define mx (array-all-max M)) (cond [(and (rational? mx) (positive? mx)) - (* mx (sqrt (array-all-sum - (inline-array-map (λ: ([x : Nonnegative-Real]) (sqr (/ x mx))) a))))] + (* mx (sqrt (array-all-sum (inline-array-map (λ (x) (sqr (/ x mx))) M))))] [else mx])))) -(: matrix-inf-norm ((Matrix Number) -> Nonnegative-Real)) -(define (matrix-inf-norm a) +(: matrix-inf-norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) +(define (matrix-inf-norm M) (parameterize ([array-strictness #f]) - (array-all-max (array-magnitude a)))) + (array-all-max (inline-array-map nonstupid-magnitude M)))) -(: matrix-p-norm ((Matrix Number) Positive-Real -> Nonnegative-Real)) -(define (matrix-p-norm a p) +(: matrix-p-norm (case-> ((Matrix Flonum) Positive-Real -> Nonnegative-Flonum) + ((Matrix Real) Positive-Real -> Nonnegative-Real) + ((Matrix Float-Complex) Positive-Real -> Nonnegative-Flonum) + ((Matrix Number) Positive-Real -> Nonnegative-Real))) +(define (matrix-p-norm M p) (parameterize ([array-strictness #f]) - (let ([a (array-strict (array-magnitude a))]) + (let ([M (array-strict (inline-array-map nonstupid-magnitude M))]) ;; Compute this divided by the maximum to avoid underflow and overflow - (define mx (array-all-max a)) + (define mx (array-all-max M)) (cond [(and (rational? mx) (positive? mx)) - (assert - (* mx (expt (array-all-sum - (inline-array-map (λ: ([x : Nonnegative-Real]) (expt (/ x mx) p)) a)) - (/ p))) - (make-predicate Nonnegative-Real))] + (* mx (expt (array-all-sum (inline-array-map (λ (x) (expt (abs (/ x mx)) p)) M)) + (/ p)))] [else mx])))) -(: matrix-norm (case-> ((Matrix Number) -> Nonnegative-Real) +(: matrix-norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Flonum) Real -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Real) Real -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Float-Complex) Real -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) Real -> Nonnegative-Real))) ;; Computes the p norm of a matrix (define (matrix-norm a [p 2]) @@ -230,8 +259,12 @@ [(p . > . 1) (matrix-p-norm a p)] [else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)])) -(: matrix-dot (case-> ((Matrix Real) -> Nonnegative-Real) +(: matrix-dot (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Flonum) (Matrix Flonum) -> Flonum) + ((Matrix Real) -> Nonnegative-Real) ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex) ((Matrix Number) -> Nonnegative-Real) ((Matrix Number) (Matrix Number) -> Number))) ;; Computes the Frobenius inner product of a matrix with itself or of two matrices @@ -256,20 +289,30 @@ (λ: ([js : Indexes]) (* (aproc js) (conjugate (bproc js)))))))])) -(: matrix-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) +(: matrix-cos-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum) + ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex) ((Matrix Number) (Matrix Number) -> Number))) (define (matrix-cos-angle M N) (/ (matrix-dot M N) (* (matrix-2norm M) (matrix-2norm N)))) -(: matrix-angle (case-> ((Matrix Real) (Matrix Real) -> Real) +(: matrix-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum) + ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex) ((Matrix Number) (Matrix Number) -> Number))) (define (matrix-angle M N) (acos (matrix-cos-angle M N))) (: matrix-normalize - (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) Real -> (Matrix Flonum)) + ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) -> (Matrix Real)) ((Matrix Real) Real -> (Matrix Real)) ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) Real -> (Matrix Number)) ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) @@ -291,19 +334,25 @@ (define (matrix-transpose a) (array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1)) -(: matrix-conjugate (case-> ((Matrix Real) -> (Matrix Real)) +(: matrix-conjugate (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Real) -> (Matrix Real)) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) ((Matrix Number) -> (Matrix Number)))) (define (matrix-conjugate a) (array-conjugate (ensure-matrix 'matrix-conjugate a))) -(: matrix-hermitian (case-> ((Matrix Real) -> (Matrix Real)) +(: matrix-hermitian (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Real) -> (Matrix Real)) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) ((Matrix Number) -> (Matrix Number)))) (define (matrix-hermitian a) (array-default-strict (parameterize ([array-strictness #f]) (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)))) -(: matrix-trace (case-> ((Matrix Real) -> Real) +(: matrix-trace (case-> ((Matrix Flonum) -> Flonum) + ((Matrix Real) -> Real) + ((Matrix Float-Complex) -> Float-Complex) ((Matrix Number) -> Number))) (define (matrix-trace a) (cond [(square-matrix? a) @@ -349,15 +398,23 @@ [else (fail)])]))] [else (fail)])])) -(: make-matrix-normalize (Real -> (case-> ((Matrix Real) -> (U #f (Matrix Real))) +(: make-matrix-normalize (Real -> (case-> ((Matrix Flonum) -> (U #f (Matrix Flonum))) + ((Matrix Real) -> (U #f (Matrix Real))) + ((Matrix Float-Complex) -> (U #f (Matrix Float-Complex))) ((Matrix Number) -> (U #f (Matrix Number)))))) (define ((make-matrix-normalize p) M) (matrix-normalize M p (λ () #f))) (: matrix-normalize-rows - (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) Real -> (Matrix Flonum)) + ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) -> (Matrix Real)) ((Matrix Real) Real -> (Matrix Real)) ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) Real -> (Matrix Number)) ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) @@ -371,9 +428,15 @@ (matrix-map-rows (make-matrix-normalize p) M fail)])) (: matrix-normalize-cols - (All (A) (case-> ((Matrix Real) -> (Matrix Real)) + (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) Real -> (Matrix Flonum)) + ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) -> (Matrix Real)) ((Matrix Real) Real -> (Matrix Real)) ((Matrix Real) Real (-> A) -> (U A (Matrix Real))) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) Real -> (Matrix Number)) ((Matrix Number) Real (-> A) -> (U A (Matrix Number)))))) diff --git a/collects/math/private/matrix/matrix-constructors.rkt b/collects/math/private/matrix/matrix-constructors.rkt index a666230e86..2eac16b1b6 100644 --- a/collects/math/private/matrix/matrix-constructors.rkt +++ b/collects/math/private/matrix/matrix-constructors.rkt @@ -1,8 +1,10 @@ #lang typed/racket/base (require racket/fixnum + racket/flonum racket/list racket/vector + math/base "matrix-types.rkt" "../unsafe.rkt" "../array/array-struct.rkt" @@ -22,8 +24,14 @@ ;; =================================================================================================== ;; Basic constructors -(: identity-matrix (Integer -> (Matrix (U 0 1)))) -(define (identity-matrix m) (diagonal-array 2 m 1 0)) +(: identity-matrix (All (A) (case-> (Integer -> (Matrix (U 1 0))) + (Integer A -> (Matrix (U A 0))) + (Integer A A -> (Matrix A))))) +(define identity-matrix + (case-lambda + [(m) (diagonal-array 2 m 1 0)] + [(m one) (diagonal-array 2 m one 0)] + [(m one zero) (diagonal-array 2 m one zero)])) (: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) (define (make-matrix m n x) @@ -60,9 +68,12 @@ (cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)] [else zero])))])) -(: diagonal-matrix (All (A) ((Listof A) -> (Matrix (U A 0))))) -(define (diagonal-matrix xs) - (diagonal-matrix/zero xs 0)) +(: diagonal-matrix (All (A) (case-> ((Listof A) -> (Matrix (U A 0))) + ((Listof A) A -> (Matrix A))))) +(define diagonal-matrix + (case-lambda + [(xs) (diagonal-matrix/zero xs 0)] + [(xs zero) (diagonal-matrix/zero xs zero)])) ;; =================================================================================================== ;; Block diagonal matrices @@ -129,21 +140,29 @@ [else (block-diagonal-matrix/zero* as zero)]))) -(: block-diagonal-matrix (All (A) ((Listof (Matrix A)) -> (Matrix (U A 0))))) -(define (block-diagonal-matrix as) - (block-diagonal-matrix/zero as 0)) +(: block-diagonal-matrix (All (A) (case-> ((Listof (Matrix A)) -> (Matrix (U A 0))) + ((Listof (Matrix A)) A -> (Matrix A))))) +(define block-diagonal-matrix + (case-lambda + [(as) (block-diagonal-matrix/zero as 0)] + [(as zero) (block-diagonal-matrix/zero as zero)])) ;; =================================================================================================== ;; Special matrices -(: expt-hack (case-> (Real Integer -> Real) - (Number Integer -> Number))) -;; Stop using this when TR correctly derives expt : Real Integer -> Real -(define (expt-hack x n) - (cond [(real? x) (assert (expt x n) real?)] +(: sane-expt (case-> (Flonum Index -> Flonum) + (Real Index -> Real) + (Float-Complex Index -> Float-Complex) + (Number Index -> Number))) +(define (sane-expt x n) + (cond [(flonum? x) (flexpt x (real->double-flonum n))] + [(real? x) (real-part (expt x n))] ; remove `real-part' when expt : Real Index -> Real + [(float-complex? x) (number->float-complex (expt x n))] [else (expt x n)])) -(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Matrix Real)) +(: vandermonde-matrix (case-> ((Listof Flonum) Integer -> (Matrix Flonum)) + ((Listof Real) Integer -> (Matrix Real)) + ((Listof Float-Complex) Integer -> (Matrix Float-Complex)) ((Listof Number) Integer -> (Matrix Number)))) (define (vandermonde-matrix xs n) (cond [(empty? xs) @@ -151,4 +170,4 @@ [(or (not (index? n)) (zero? n)) (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)] [else - (array-axis-expand (list->array xs) 1 n expt-hack)])) + (array-axis-expand (list->array xs) 1 n sane-expt)])) diff --git a/collects/math/private/matrix/matrix-expt.rkt b/collects/math/private/matrix/matrix-expt.rkt index 42491c1ec9..bdff8e6af7 100644 --- a/collects/math/private/matrix/matrix-expt.rkt +++ b/collects/math/private/matrix/matrix-expt.rkt @@ -3,11 +3,16 @@ (require "matrix-types.rkt" "matrix-constructors.rkt" "matrix-arithmetic.rkt" - "utils.rkt") + "utils.rkt" + "../array/array-struct.rkt" + "../array/utils.rkt" + "../unsafe.rkt") (provide matrix-expt) -(: matrix-expt/ns (case-> ((Matrix Real) Positive-Integer -> (Matrix Real)) +(: matrix-expt/ns (case-> ((Matrix Flonum) Positive-Integer -> (Matrix Flonum)) + ((Matrix Real) Positive-Integer -> (Matrix Real)) + ((Matrix Float-Complex) Positive-Integer -> (Matrix Float-Complex)) ((Matrix Number) Positive-Integer -> (Matrix Number)))) (define (matrix-expt/ns a n) (define n/2 (quotient n 2)) @@ -22,10 +27,19 @@ ;; m = n - 1 (matrix* a (matrix-expt/ns a m)))))) -(: matrix-expt (case-> ((Matrix Real) Integer -> (Matrix Real)) +(: matrix-expt (case-> ((Matrix Flonum) Integer -> (Matrix Flonum)) + ((Matrix Real) Integer -> (Matrix Real)) + ((Matrix Float-Complex) Integer -> (Matrix Float-Complex)) ((Matrix Number) Integer -> (Matrix Number)))) (define (matrix-expt a n) (cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)] [(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)] - [(zero? n) (identity-matrix (square-matrix-size a))] + [(zero? n) + (define proc (unsafe-array-proc a)) + (array-default-strict + (unsafe-build-array + (array-shape a) + (λ: ([ij : Indexes]) + (define x (proc ij)) + (if (= (unsafe-vector-ref ij 0) (unsafe-vector-ref ij 1)) (one* x) (zero* x)))))] [else (call/ns (λ () (matrix-expt/ns a n)))])) diff --git a/collects/math/private/matrix/matrix-gauss-elim.rkt b/collects/math/private/matrix/matrix-gauss-elim.rkt index 19719ac7d6..f334f8647b 100644 --- a/collects/math/private/matrix/matrix-gauss-elim.rkt +++ b/collects/math/private/matrix/matrix-gauss-elim.rkt @@ -16,10 +16,18 @@ (define-type Pivoting (U 'first 'partial)) (: matrix-gauss-elim - (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index))) + (case-> ((Matrix Flonum) -> (Values (Matrix Flonum) (Listof Index))) + ((Matrix Flonum) Any -> (Values (Matrix Flonum) (Listof Index))) + ((Matrix Flonum) Any Any -> (Values (Matrix Flonum) (Listof Index))) + ((Matrix Flonum) Any Any Pivoting -> (Values (Matrix Flonum) (Listof Index))) + ((Matrix Real) -> (Values (Matrix Real) (Listof Index))) ((Matrix Real) Any -> (Values (Matrix Real) (Listof Index))) ((Matrix Real) Any Any -> (Values (Matrix Real) (Listof Index))) ((Matrix Real) Any Any Pivoting -> (Values (Matrix Real) (Listof Index))) + ((Matrix Float-Complex) -> (Values (Matrix Float-Complex) (Listof Index))) + ((Matrix Float-Complex) Any -> (Values (Matrix Float-Complex) (Listof Index))) + ((Matrix Float-Complex) Any Any -> (Values (Matrix Float-Complex) (Listof Index))) + ((Matrix Float-Complex) Any Any Pivoting -> (Values (Matrix Float-Complex) (Listof Index))) ((Matrix Number) -> (Values (Matrix Number) (Listof Index))) ((Matrix Number) Any -> (Values (Matrix Number) (Listof Index))) ((Matrix Number) Any Any -> (Values (Matrix Number) (Listof Index))) @@ -52,17 +60,25 @@ (vector-swap! rows i p) ;; Possibly unitize the new current row (let ([pivot (if unitize-pivot? - (begin (vector-scale! (unsafe-vector-ref rows i) (/ pivot)) - 1) + (begin (vector-scale! (unsafe-vector-ref rows i) (/ 1 pivot)) + (/ pivot pivot)) pivot)]) (elim-rows! rows m i j pivot (if jordan? 0 (fx+ i 1))) (loop (fx+ i 1) (fx+ j 1) without-pivot))])]))) (: matrix-row-echelon - (case-> ((Matrix Real) -> (Matrix Real)) + (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) Any -> (Matrix Flonum)) + ((Matrix Flonum) Any Any -> (Matrix Flonum)) + ((Matrix Flonum) Any Any Pivoting -> (Matrix Flonum)) + ((Matrix Real) -> (Matrix Real)) ((Matrix Real) Any -> (Matrix Real)) ((Matrix Real) Any Any -> (Matrix Real)) ((Matrix Real) Any Any Pivoting -> (Matrix Real)) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Any -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Any Any -> (Matrix Float-Complex)) + ((Matrix Float-Complex) Any Any Pivoting -> (Matrix Float-Complex)) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) Any -> (Matrix Number)) ((Matrix Number) Any Any -> (Matrix Number)) diff --git a/collects/math/private/matrix/matrix-gram-schmidt.rkt b/collects/math/private/matrix/matrix-gram-schmidt.rkt index 5905b73b83..f8f0886b8c 100644 --- a/collects/math/private/matrix/matrix-gram-schmidt.rkt +++ b/collects/math/private/matrix/matrix-gram-schmidt.rkt @@ -16,7 +16,9 @@ (provide matrix-gram-schmidt matrix-basis-extension) -(: find-nonzero-vector (case-> ((Vectorof (Vectorof Real)) -> (U #f Index)) +(: find-nonzero-vector (case-> ((Vectorof (Vectorof Flonum)) -> (U #f Index)) + ((Vectorof (Vectorof Real)) -> (U #f Index)) + ((Vectorof (Vectorof Float-Complex)) -> (U #f Index)) ((Vectorof (Vectorof Number)) -> (U #f Index)))) (define (find-nonzero-vector vss) (define n (vector-length vss)) @@ -28,7 +30,10 @@ [else #f]))])) (: subtract-projections! - (case-> ((Vectorof (Vectorof Real)) Nonnegative-Fixnum Index (Vectorof Real) -> Void) + (case-> ((Vectorof (Vectorof Flonum)) Nonnegative-Fixnum Index (Vectorof Flonum) -> Void) + ((Vectorof (Vectorof Real)) Nonnegative-Fixnum Index (Vectorof Real) -> Void) + ((Vectorof (Vectorof Float-Complex)) Nonnegative-Fixnum Index (Vectorof Float-Complex) + -> Void) ((Vectorof (Vectorof Number)) Nonnegative-Fixnum Index (Vectorof Number) -> Void))) (define (subtract-projections! rows i m row) (let loop ([#{i : Nonnegative-Fixnum} i]) @@ -36,7 +41,9 @@ (vector-sub-proj! (unsafe-vector-ref rows i) row #f) (loop (fx+ i 1))))) -(: matrix-gram-schmidt/ns (case-> ((Matrix Real) Any Integer -> (Array Real)) +(: matrix-gram-schmidt/ns (case-> ((Matrix Flonum) Any Integer -> (Array Flonum)) + ((Matrix Real) Any Integer -> (Array Real)) + ((Matrix Float-Complex) Any Integer -> (Array Float-Complex)) ((Matrix Number) Any Integer -> (Array Number)))) ;; Performs Gram-Schmidt orthogonalization on M, assuming the rows before `start' are already ;; orthogonal @@ -60,31 +67,46 @@ [else (matrix-transpose (vector*->matrix (list->vector (reverse bs))))]))] [else - (make-array (vector (matrix-num-rows M) 0) 0)])) + (make-array (vector (matrix-num-rows M) 0) + ;; Value won't be in the matrix, but this satisfies TR: + (zero* (unsafe-vector2d-ref rows 0 0)))])) -(: matrix-gram-schmidt (case-> ((Matrix Real) -> (Array Real)) +(: matrix-gram-schmidt (case-> ((Matrix Flonum) -> (Array Flonum)) + ((Matrix Flonum) Any -> (Array Flonum)) + ((Matrix Flonum) Any Integer -> (Array Flonum)) + ((Matrix Real) -> (Array Real)) ((Matrix Real) Any -> (Array Real)) ((Matrix Real) Any Integer -> (Array Real)) + ((Matrix Float-Complex) -> (Array Float-Complex)) + ((Matrix Float-Complex) Any -> (Array Float-Complex)) + ((Matrix Float-Complex) Any Integer -> (Array Float-Complex)) ((Matrix Number) -> (Array Number)) ((Matrix Number) Any -> (Array Number)) ((Matrix Number) Any Integer -> (Array Number)))) (define (matrix-gram-schmidt M [normalize? #f] [start 0]) (call/ns (λ () (matrix-gram-schmidt/ns M normalize? start)))) -(: matrix-basis-extension/ns (case-> ((Matrix Real) -> (Array Real)) +(: matrix-basis-extension/ns (case-> ((Matrix Flonum) -> (Array Flonum)) + ((Matrix Real) -> (Array Real)) + ((Matrix Float-Complex) -> (Array Float-Complex)) ((Matrix Number) -> (Array Number)))) (define (matrix-basis-extension/ns B) (define-values (m n) (matrix-shape B)) + (define x00 (matrix-ref B 0 0)) + (define zero (zero* x00)) + (define one (one* x00)) (cond [(n . < . m) - (define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m))) #f n)) + (define S (matrix-gram-schmidt (matrix-augment (list B (identity-matrix m one zero))) #f n)) (define R (submatrix S (::) (:: n #f))) (matrix-augment (take (sort/key (matrix-cols R) > matrix-norm) (- m n)))] [(n . = . m) - (make-array (vector m 0) 0)] + (make-array (vector m 0) zero)] [else (raise-argument-error 'matrix-extend-row-basis "matrix? with width < height" B)])) -(: matrix-basis-extension (case-> ((Matrix Real) -> (Array Real)) +(: matrix-basis-extension (case-> ((Matrix Flonum) -> (Array Flonum)) + ((Matrix Real) -> (Array Real)) + ((Matrix Float-Complex) -> (Array Float-Complex)) ((Matrix Number) -> (Array Number)))) (define (matrix-basis-extension B) (call/ns (λ () (matrix-basis-extension/ns B)))) diff --git a/collects/math/private/matrix/matrix-lu.rkt b/collects/math/private/matrix/matrix-lu.rkt index 50ad6a7893..ef90a61942 100644 --- a/collects/math/private/matrix/matrix-lu.rkt +++ b/collects/math/private/matrix/matrix-lu.rkt @@ -8,15 +8,22 @@ "../unsafe.rkt" "../vector/vector-mutate.rkt" "../array/mutable-array.rkt" - "../array/array-struct.rkt") + "../array/array-struct.rkt" + "../array/array-pointwise.rkt") (provide matrix-lu) ;; An LU factorization exists iff Gaussian elimination can be done without row swaps. (: matrix-lu - (All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) + (All (A) (case-> ((Matrix Flonum) -> (Values (Matrix Flonum) (Matrix Flonum))) + ((Matrix Flonum) (-> A) -> (Values (U A (Matrix Flonum)) (Matrix Flonum))) + ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) ((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real))) + ((Matrix Float-Complex) -> (Values (Matrix Float-Complex) + (Matrix Float-Complex))) + ((Matrix Float-Complex) (-> A) -> (Values (U A (Matrix Float-Complex)) + (Matrix Float-Complex))) ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number)))))) (define matrix-lu @@ -28,7 +35,7 @@ (define L (parameterize ([array-strictness #f]) ;; Construct L in a weird way to prove to TR that it has the right type - (array->mutable-array (matrix-scale M (ann 0 Real))))) + (array->mutable-array (inline-array-map zero* M)))) ;; Going to fill in the lower triangle by banging values into `ys' (define ys (mutable-array-data L)) (let loop ([#{i : Nonnegative-Fixnum} 0]) @@ -51,12 +58,13 @@ ;; Add row i, scaled (vector-scaled-add! (unsafe-vector-ref rows l) (unsafe-vector-ref rows i) - (- y_li))) + (* -1 y_li))) (l-loop (fx+ l 1))] [else (loop (fx+ i 1))]))])] [else ;; L's lower triangle has been filled; now fill the diagonal with 1s (for: ([i : Integer (in-range 0 m)]) - (vector-set! ys (+ (* i m) i) 1)) + (define j (+ (* i m) i)) + (vector-set! ys j (one* (vector-ref ys j)))) (values L (vector*->matrix rows))]))])) diff --git a/collects/math/private/matrix/matrix-operator-norm.rkt b/collects/math/private/matrix/matrix-operator-norm.rkt index 8a34eab489..e6a5c8aa6c 100644 --- a/collects/math/private/matrix/matrix-operator-norm.rkt +++ b/collects/math/private/matrix/matrix-operator-norm.rkt @@ -42,33 +42,46 @@ See "How to Measure Errors" in the LAPACK manual for more details: matrix-orthonormal? ) -(: matrix-op-1norm ((Matrix Number) -> Nonnegative-Real)) +(: matrix-op-1norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) ;; When M is a column matrix, this is equivalent to matrix-1norm (define (matrix-op-1norm M) (parameterize ([array-strictness #f]) (assert (apply max (map matrix-1norm (matrix-cols M))) nonnegative?))) -(: matrix-op-2norm ((Matrix Number) -> Nonnegative-Real)) +(: matrix-op-2norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) ;; When M is a column matrix, this is equivalent to matrix-2norm (define (matrix-op-2norm M) ;(matrix-max-singular-value M) ;(sqrt (matrix-max-eigenvalue M)) (error 'unimplemented)) -(: matrix-op-inf-norm ((Matrix Number) -> Nonnegative-Real)) +(: matrix-op-inf-norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum) + ((Matrix Real) -> Nonnegative-Real) + ((Matrix Float-Complex) -> Nonnegative-Flonum) + ((Matrix Number) -> Nonnegative-Real))) ;; When M is a column matrix, this is equivalent to matrix-inf-norm (define (matrix-op-inf-norm M) (parameterize ([array-strictness #f]) (assert (apply max (map matrix-1norm (matrix-rows M))) nonnegative?))) -(: matrix-basis-cos-angle (case-> ((Matrix Real) (Matrix Real) -> Real) +(: matrix-basis-cos-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum) + ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex) ((Matrix Number) (Matrix Number) -> Number))) ;; Returns the angle between the two subspaces spanned by the two given sets of column vectors (define (matrix-basis-cos-angle M R) ;(matrix-min-singular-value (matrix* (matrix-hermitian M) R)) (error 'unimplemented)) -(: matrix-basis-angle (case-> ((Matrix Real) (Matrix Real) -> Real) +(: matrix-basis-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum) + ((Matrix Real) (Matrix Real) -> Real) + ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex) ((Matrix Number) (Matrix Number) -> Number))) ;; Returns the angle between the two subspaces spanned by the two given sets of column vectors (define (matrix-basis-angle M R) diff --git a/collects/math/private/matrix/matrix-qr.rkt b/collects/math/private/matrix/matrix-qr.rkt index e711793d47..6c09e87423 100644 --- a/collects/math/private/matrix/matrix-qr.rkt +++ b/collects/math/private/matrix/matrix-qr.rkt @@ -5,6 +5,7 @@ "matrix-arithmetic.rkt" "matrix-constructors.rkt" "matrix-gram-schmidt.rkt" + "utils.rkt" "../array/array-transform.rkt" "../array/array-struct.rkt") @@ -24,21 +25,33 @@ produces matrices for which `matrix-orthogonal?' returns #t with eps <= 10*epsil independently of the matrix size. |# -(: matrix-qr/ns (case-> ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) +(: matrix-qr/ns (case-> ((Matrix Flonum) Any -> (Values (Matrix Flonum) (Matrix Flonum))) + ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Float-Complex) Any -> (Values (Matrix Float-Complex) + (Matrix Float-Complex))) ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) (define (matrix-qr/ns M full?) + (define x00 (matrix-ref M 0 0)) + (define zero (zero* x00)) + (define one (one* x00)) (define B (matrix-gram-schmidt M #f)) (define Q (matrix-gram-schmidt (cond [(or (square-matrix? B) (and (matrix? B) (not full?))) B] [(matrix? B) (array-append* (list B (matrix-basis-extension B)) 1)] - [full? (identity-matrix (matrix-num-rows M))] - [else (matrix-col (identity-matrix (matrix-num-rows M)) 0)]) + [full? (identity-matrix (matrix-num-rows M) one zero)] + [else (matrix-col (identity-matrix (matrix-num-rows M) one zero) 0)]) #t)) - (values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M)))) + (values Q (matrix-upper-triangle (matrix* (matrix-hermitian Q) M) zero))) -(: matrix-qr (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) +(: matrix-qr (case-> ((Matrix Flonum) -> (Values (Matrix Flonum) (Matrix Flonum))) + ((Matrix Flonum) Any -> (Values (Matrix Flonum) (Matrix Flonum))) + ((Matrix Real) -> (Values (Matrix Real) (Matrix Real))) ((Matrix Real) Any -> (Values (Matrix Real) (Matrix Real))) + ((Matrix Float-Complex) -> (Values (Matrix Float-Complex) + (Matrix Float-Complex))) + ((Matrix Float-Complex) Any -> (Values (Matrix Float-Complex) + (Matrix Float-Complex))) ((Matrix Number) -> (Values (Matrix Number) (Matrix Number))) ((Matrix Number) Any -> (Values (Matrix Number) (Matrix Number))))) (define (matrix-qr M [full? #t]) diff --git a/collects/math/private/matrix/matrix-solve.rkt b/collects/math/private/matrix/matrix-solve.rkt index e92fdaa120..913d254546 100644 --- a/collects/math/private/matrix/matrix-solve.rkt +++ b/collects/math/private/matrix/matrix-solve.rkt @@ -24,7 +24,9 @@ ;; =================================================================================================== ;; Determinant -(: matrix-determinant (case-> ((Matrix Real) -> Real) +(: matrix-determinant (case-> ((Matrix Flonum) -> Flonum) + ((Matrix Real) -> Real) + ((Matrix Float-Complex) -> Float-Complex) ((Matrix Number) -> Number))) (define (matrix-determinant M) (define m (square-matrix-size M)) @@ -41,20 +43,22 @@ [else (matrix-determinant/row-reduction M)])) -(: matrix-determinant/row-reduction (case-> ((Matrix Real) -> Real) +(: matrix-determinant/row-reduction (case-> ((Matrix Flonum) -> Flonum) + ((Matrix Real) -> Real) + ((Matrix Float-Complex) -> Float-Complex) ((Matrix Number) -> Number))) (define (matrix-determinant/row-reduction M) (define m (square-matrix-size M)) (define rows (matrix->vector* M)) - (let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : Real} 1]) + (let loop ([#{i : Nonnegative-Fixnum} 0] [#{sign : (U Positive-Fixnum Negative-Fixnum)} 1]) (cond [(i . fx< . m) (define-values (p pivot) (find-partial-pivot rows m i i)) (cond - [(zero? pivot) 0] ; no pivot means non-invertible matrix + [(zero? pivot) pivot] ; no pivot means non-invertible matrix [else (let ([sign (if (= i p) sign (begin (vector-swap! rows i p) ; swapping negates sign - (* -1 sign)))]) + (if (= sign 1) -1 1)))]) (elim-rows! rows m i i pivot (fx+ i 1)) ; adding scaled rows doesn't change it (loop (fx+ i 1) sign))])] [else @@ -72,8 +76,12 @@ (and (square-matrix? M) (not (zero? (matrix-determinant M))))) -(: matrix-inverse (All (A) (case-> ((Matrix Real) -> (Matrix Real)) +(: matrix-inverse (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) -> (Matrix Real)) ((Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) (-> A) -> (U A (Matrix Float-Complex))) ((Matrix Number) -> (Matrix Number)) ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) (define matrix-inverse @@ -81,7 +89,8 @@ [(M) (matrix-inverse M (λ () (raise-argument-error 'matrix-inverse "matrix-invertible?" M)))] [(M fail) (define m (square-matrix-size M)) - (define I (identity-matrix m)) + (define x00 (matrix-ref M 0 0)) + (define I (identity-matrix m (one* x00) (zero* x00))) (define-values (IM^-1 wps) (parameterize ([array-strictness #f]) (matrix-gauss-elim (matrix-augment (list M I)) #t #t))) (cond [(and (not (empty? wps)) (= (first wps) m)) @@ -91,11 +100,16 @@ ;; =================================================================================================== ;; Solving linear systems -(: matrix-solve (All (A) (case-> - ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real))) - ((Matrix Number) (Matrix Number) -> (Matrix Number)) - ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))) +(: matrix-solve + (All (A) (case-> + ((Matrix Flonum) (Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Flonum) (Matrix Flonum) (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) (Matrix Real) -> (Matrix Real)) + ((Matrix Real) (Matrix Real) (-> A) -> (U A (Matrix Real))) + ((Matrix Float-Complex) (Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Float-Complex) (Matrix Float-Complex) (-> A) -> (U A (Matrix Float-Complex))) + ((Matrix Number) (Matrix Number) -> (Matrix Number)) + ((Matrix Number) (Matrix Number) (-> A) -> (U A (Matrix Number)))))) (define matrix-solve (case-lambda [(M B) (matrix-solve M B (λ () (raise-argument-error 'matrix-solve "matrix-invertible?" 0 M B)))] diff --git a/collects/math/private/matrix/matrix-subspace.rkt b/collects/math/private/matrix/matrix-subspace.rkt index 07662d9e10..8202935b88 100644 --- a/collects/math/private/matrix/matrix-subspace.rkt +++ b/collects/math/private/matrix/matrix-subspace.rkt @@ -35,7 +35,9 @@ (cond [(= j0 j1) Bs] [else (cons (submatrix M (::) (:: j0 j1)) Bs)])) -(: matrix-col-space/ns (All (A) (case-> ((Matrix Real) -> (U #f (Matrix Real))) +(: matrix-col-space/ns (All (A) (case-> ((Matrix Flonum) -> (U #f (Matrix Flonum))) + ((Matrix Real) -> (U #f (Matrix Real))) + ((Matrix Float-Complex) -> (U #f (Matrix Float-Complex))) ((Matrix Number) -> (U #f (Matrix Number)))))) (define (matrix-col-space/ns M) (define n (matrix-num-cols M)) @@ -52,13 +54,19 @@ (define next-j (first wps)) (loop next-j (rest wps) (maybe-cons-submatrix M (fx+ j 1) next-j Bs))]))])) -(: matrix-col-space (All (A) (case-> ((Matrix Real) -> (Matrix Real)) +(: matrix-col-space (All (A) (case-> ((Matrix Flonum) -> (Array Flonum)) + ((Matrix Flonum) (-> A) -> (U A (Matrix Flonum))) + ((Matrix Real) -> (Array Real)) ((Matrix Real) (-> A) -> (U A (Matrix Real))) - ((Matrix Number) -> (Matrix Number)) + ((Matrix Float-Complex) -> (Array Float-Complex)) + ((Matrix Float-Complex) (-> A) -> (U A (Matrix Float-Complex))) + ((Matrix Number) -> (Array Number)) ((Matrix Number) (-> A) -> (U A (Matrix Number)))))) (define matrix-col-space (case-lambda - [(M) (matrix-col-space M (λ () (make-array (vector 0 (matrix-num-cols M)) 0)))] + [(M) (matrix-col-space M (λ () + (define x00 (matrix-ref M 0 0)) + (make-array (vector 0 (matrix-num-cols M)) (zero* x00))))] [(M fail) (define S (parameterize ([array-strictness #f]) (matrix-col-space/ns M))) diff --git a/collects/math/private/matrix/typed-matrix-arithmetic.rkt b/collects/math/private/matrix/typed-matrix-arithmetic.rkt index 516d62ff54..bedcb1b6fb 100644 --- a/collects/math/private/matrix/typed-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/typed-matrix-arithmetic.rkt @@ -65,46 +65,65 @@ (loop (first arrs) (rest arrs)))])))])) -(: matrix*/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) - ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(: matrix*/ns + (case-> ((Matrix Flonum) (Listof (Matrix Flonum)) -> (Matrix Flonum)) + ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Float-Complex) (Listof (Matrix Float-Complex)) -> (Matrix Float-Complex)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) (define (matrix*/ns a as) (cond [(empty? as) a] [else (matrix*/ns (inline-matrix-multiply a (first as)) (rest as))])) -(: matrix* (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) +(: matrix* (case-> ((Matrix Flonum) (Matrix Flonum) * -> (Matrix Flonum)) + ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Float-Complex) (Matrix Float-Complex) * -> (Matrix Float-Complex)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix* a . as) (call/ns (λ () (matrix*/ns a as)))) -(: matrix+/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) - ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(: matrix+/ns + (case-> ((Matrix Flonum) (Listof (Matrix Flonum)) -> (Matrix Flonum)) + ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Float-Complex) (Listof (Matrix Float-Complex)) -> (Matrix Float-Complex)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) (define (matrix+/ns a as) (cond [(empty? as) a] [else (matrix+/ns (inline-matrix+ a (first as)) (rest as))])) -(: matrix+ (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) +(: matrix+ (case-> ((Matrix Flonum) (Matrix Flonum) * -> (Matrix Flonum)) + ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Float-Complex) (Matrix Float-Complex) * -> (Matrix Float-Complex)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix+ a . as) (call/ns (λ () (matrix+/ns a as)))) -(: matrix-/ns (case-> ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) - ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) +(: matrix-/ns + (case-> ((Matrix Flonum) (Listof (Matrix Flonum)) -> (Matrix Flonum)) + ((Matrix Real) (Listof (Matrix Real)) -> (Matrix Real)) + ((Matrix Float-Complex) (Listof (Matrix Float-Complex)) -> (Matrix Float-Complex)) + ((Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)))) (define (matrix-/ns a as) (cond [(empty? as) a] [else (matrix-/ns (inline-matrix- a (first as)) (rest as))])) -(: matrix- (case-> ((Matrix Real) (Matrix Real) * -> (Matrix Real)) +(: matrix- (case-> ((Matrix Flonum) (Matrix Flonum) * -> (Matrix Flonum)) + ((Matrix Real) (Matrix Real) * -> (Matrix Real)) + ((Matrix Float-Complex) (Matrix Float-Complex) * -> (Matrix Float-Complex)) ((Matrix Number) (Matrix Number) * -> (Matrix Number)))) (define (matrix- a . as) - (call/ns (λ () (cond [(empty? as) (inline-matrix- a)] + (call/ns (λ () (cond [(empty? as) (inline-matrix-scale a -1)] [else (matrix-/ns a as)])))) -(: matrix-scale (case-> ((Matrix Real) Real -> (Matrix Real)) +(: matrix-scale (case-> ((Matrix Flonum) Flonum -> (Matrix Flonum)) + ((Matrix Real) Real -> (Matrix Real)) + ((Matrix Float-Complex) Float-Complex -> (Matrix Float-Complex)) ((Matrix Number) Number -> (Matrix Number)))) (define (matrix-scale a x) (inline-matrix-scale a x)) -(: matrix-sum (case-> ((Listof (Matrix Real)) -> (Matrix Real)) +(: matrix-sum (case-> ((Listof (Matrix Flonum)) -> (Matrix Flonum)) + ((Listof (Matrix Real)) -> (Matrix Real)) + ((Listof (Matrix Float-Complex)) -> (Matrix Float-Complex)) ((Listof (Matrix Number)) -> (Matrix Number)))) (define (matrix-sum lst) (cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)] diff --git a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt index 326a25bfd3..370b2a4a38 100644 --- a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt @@ -8,39 +8,67 @@ inline-matrix-map matrix-map) -(module syntax-defs racket/base - (require (for-syntax racket/base) - (only-in typed/racket/base λ: : inst Index) +(module typed-multiply-defs typed/racket/base + (require racket/fixnum "matrix-types.rkt" "utils.rkt" + "../unsafe.rkt" "../array/array-struct.rkt" - "../array/array-fold.rkt" "../array/array-transform.rkt" + "../array/mutable-array.rkt" "../array/utils.rkt") (provide (all-defined-out)) - ;(: matrix-multiply ((Matrix Number) (Matrix Number) -> (Matrix Number))) + (: matrix-multiply-data (All (A) ((Matrix A) (Matrix A) -> (Values Index Index Index + (Vectorof A) (Vectorof A) + (-> (Boxof A)))))) + (define (matrix-multiply-data arr brr) + (let-values ([(m p n) (matrix-multiply-shape arr brr)]) + (define arr-data (mutable-array-data (array->mutable-array arr))) + (define brr-data (mutable-array-data (array->mutable-array (parameterize ([array-strictness #f]) + (array-axis-swap brr 0 1))))) + (define bx (make-thread-local-box (unsafe-vector-ref arr-data 0))) + (values m p n arr-data brr-data bx))) + + (: make-matrix-multiply (All (A) (Index Index Index (Index Index -> A) -> (Matrix A)))) + (define (make-matrix-multiply m p n sum-loop) + (array-default-strict + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([ij : Indexes]) + (sum-loop (assert (fx* (unsafe-vector-ref ij 0) p) index?) + (assert (fx* (unsafe-vector-ref ij 1) p) index?)))))) + ) ; module + +(module untyped-multiply-defs racket/base + (require (for-syntax racket/base) + racket/fixnum + racket/unsafe/ops + (only-in typed/racket/base λ: : Index let: Nonnegative-Fixnum) + (submod ".." typed-multiply-defs) + "matrix-types.rkt" + "utils.rkt" + "../array/array-struct.rkt") + + (provide (all-defined-out)) + ;; This is a macro so the result can have as precise a type as possible - (define-syntax-rule (inline-matrix-multiply arr-expr brr-expr) - (let ([arr arr-expr] - [brr brr-expr]) - (let-values ([(m p n) (matrix-multiply-shape arr brr)] - ;; Make arr strict because its elements are reffed multiple times - [(_) (array-strict! arr)]) - (let (;; Extend arr in the center dimension - [arr-proc (unsafe-array-proc (array-axis-insert arr 1 n))] - ;; Transpose brr and extend in the leftmost dimension - [brr-proc (unsafe-array-proc - (array-axis-insert (array-strict (array-axis-swap brr 0 1)) 0 m))]) - ;; The *transpose* of brr is traversed in row-major order when this result is traversed - ;; in row-major order (which is why the transpose is strict, not brr) - (array-axis-sum - (unsafe-build-array - ((inst vector Index) m n p) - (λ: ([js : Indexes]) - (* (arr-proc js) (brr-proc js)))) - 2))))) + (define-syntax-rule (inline-matrix-multiply arr brr) + (let-values ([(m p n arr-data brr-data bx) (matrix-multiply-data arr brr)]) + (make-matrix-multiply + m p n + (λ: ([i : Index] [j : Index]) + (let ([bx (bx)] + [v (* (unsafe-vector-ref arr-data i) + (unsafe-vector-ref brr-data j))]) + (let: loop ([k : Nonnegative-Fixnum 1] [v v]) + (cond [(k . fx< . p) + (loop (fx+ k 1) + (+ v (* (unsafe-vector-ref arr-data (fx+ i k)) + (unsafe-vector-ref brr-data (fx+ j k)))))] + [else (set-box! bx v)])) + (unbox bx)))))) (define-syntax (do-inline-matrix* stx) (syntax-case stx () @@ -54,6 +82,19 @@ (parameterize ([array-strictness #f]) (do-inline-matrix* arr brrs ...)))) + ) ; module + +(module syntax-defs racket/base + (require (for-syntax racket/base) + (only-in typed/racket/base λ: : inst Index) + (submod ".." typed-multiply-defs) + "matrix-types.rkt" + "utils.rkt" + "../array/array-struct.rkt" + "../array/utils.rkt") + + (provide (all-defined-out)) + (define-syntax (inline-matrix-map stx) (syntax-case stx () [(_ f arr-expr) @@ -117,5 +158,6 @@ ) ; module -(require 'syntax-defs +(require 'untyped-multiply-defs + 'syntax-defs 'untyped-defs) diff --git a/collects/math/private/matrix/utils.rkt b/collects/math/private/matrix/utils.rkt index 21ca2fde64..032d291c65 100644 --- a/collects/math/private/matrix/utils.rkt +++ b/collects/math/private/matrix/utils.rkt @@ -3,6 +3,7 @@ (require racket/performance-hint racket/string racket/fixnum + math/base "matrix-types.rkt" "../unsafe.rkt" "../array/array-struct.rkt" @@ -69,7 +70,9 @@ (rational? (imag-part z)))]))) (: find-partial-pivot - (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) + (case-> ((Vectorof (Vectorof Flonum)) Index Index Index -> (Values Index Flonum)) + ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) + ((Vectorof (Vectorof Float-Complex)) Index Index Index -> (Values Index Float-Complex)) ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number)))) ;; Find the element with maximum magnitude in a column (define (find-partial-pivot rows m i j) @@ -85,7 +88,9 @@ [else (values p pivot)]))) (: find-first-pivot - (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) + (case-> ((Vectorof (Vectorof Flonum)) Index Index Index -> (Values Index Flonum)) + ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real)) + ((Vectorof (Vectorof Float-Complex)) Index Index Index -> (Values Index Float-Complex)) ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number)))) ;; Find the first nonzero element in a column (define (find-first-pivot rows m i j) @@ -100,7 +105,10 @@ (values i pivot)])))) (: elim-rows! - (case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void) + (case-> ((Vectorof (Vectorof Flonum)) Index Index Index Flonum Nonnegative-Fixnum -> Void) + ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void) + ((Vectorof (Vectorof Float-Complex)) Index Index Index Float-Complex Nonnegative-Fixnum + -> Void) ((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void))) (define (elim-rows! rows m i j pivot start) (define row_i (unsafe-vector-ref rows i)) @@ -109,8 +117,8 @@ (unless (l . fx= . i) (define row_l (unsafe-vector-ref rows l)) (define x_lj (unsafe-vector-ref row_l j)) - (unless (zero? x_lj) - (vector-scaled-add! row_l row_i (- (/ x_lj pivot)) j) + (unless (= x_lj 0) + (vector-scaled-add! row_l row_i (* -1 (/ x_lj pivot)) j) ;; Make sure the element below the pivot is zero (unsafe-vector-set! row_l j (- x_lj x_lj)))) (loop (fx+ l 1))))) @@ -124,3 +132,51 @@ (thnk)))) ) ; begin-encourage-inline + +(: make-thread-local-box (All (A) (A -> (-> (Boxof A))))) +(define (make-thread-local-box contents) + (let: ([val : (Thread-Cellof (U #f (Boxof A))) (make-thread-cell #f)]) + (λ () (or (thread-cell-ref val) + (let: ([v : (Boxof A) (box contents)]) + (thread-cell-set! val v) + v))))) + +(: one (case-> (Flonum -> Nonnegative-Flonum) + (Real -> (U 1 Nonnegative-Flonum)) + (Float-Complex -> Nonnegative-Flonum) + (Number -> (U 1 Nonnegative-Flonum)))) +(define (one x) + (cond [(flonum? x) 1.0] + [(real? x) 1] + [(float-complex? x) 1.0] + [else 1])) + +(: zero (case-> (Flonum -> Flonum-Positive-Zero) + (Real -> (U 0 Flonum-Positive-Zero)) + (Float-Complex -> Flonum-Positive-Zero) + (Number -> (U 0 Flonum-Positive-Zero)))) +(define (zero x) + (cond [(flonum? x) 0.0] + [(real? x) 0] + [(float-complex? x) 0.0] + [else 0])) + +(: one* (case-> (Flonum -> Nonnegative-Flonum) + (Real -> (U 1 Nonnegative-Flonum)) + (Float-Complex -> Float-Complex) + (Number -> (U 1 Nonnegative-Flonum Float-Complex)))) +(define (one* x) + (cond [(flonum? x) 1.0] + [(real? x) 1] + [(float-complex? x) 1.0+0.0i] + [else 1])) + +(: zero* (case-> (Flonum -> Flonum-Positive-Zero) + (Real -> (U 0 Flonum-Positive-Zero)) + (Float-Complex -> Float-Complex) + (Number -> (U 0 Flonum-Positive-Zero Float-Complex)))) +(define (zero* x) + (cond [(flonum? x) 0.0] + [(real? x) 0] + [(float-complex? x) 0.0+0.0i] + [else 0])) diff --git a/collects/math/private/vector/vector-mutate.rkt b/collects/math/private/vector/vector-mutate.rkt index 8331d0c359..59cb2d8b14 100644 --- a/collects/math/private/vector/vector-mutate.rkt +++ b/collects/math/private/vector/vector-mutate.rkt @@ -1,7 +1,7 @@ #lang typed/racket/base (require racket/fixnum - racket/math + math/base math/private/unsafe) (provide vector-swap! @@ -14,9 +14,12 @@ vector-zero! vector-zero?) -(: mag^2 (Number -> Nonnegative-Real)) +(: mag^2 (case-> (Flonum -> Nonnegative-Flonum) + (Float-Complex -> Nonnegative-Flonum) + (Number -> Nonnegative-Real))) (define (mag^2 x) - (max 0 (real-part (* x (conjugate x))))) + (cond [(real? x) (sqr x)] + [else (abs (real-part (* x (conjugate x))))])) (: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void))) (define (vector-swap! vs i0 i1) @@ -35,7 +38,9 @@ (loop (fx+ i 1))) (void))))) -(: vector-scale! (case-> ((Vectorof Real) Real -> Void) +(: vector-scale! (case-> ((Vectorof Flonum) Flonum -> Void) + ((Vectorof Real) Real -> Void) + ((Vectorof Float-Complex) Float-Complex -> Void) ((Vectorof Number) Number -> Void))) (define (vector-scale! vs v) (vector-generic-scale! vs v *)) @@ -52,25 +57,38 @@ (loop (fx+ i 1))) (void))))) -(: vector-scaled-add! (case-> ((Vectorof Real) (Vectorof Real) Real -> Void) - ((Vectorof Real) (Vectorof Real) Real Index -> Void) - ((Vectorof Number) (Vectorof Number) Number -> Void) - ((Vectorof Number) (Vectorof Number) Number Index -> Void))) +(: vector-scaled-add! + (case-> ((Vectorof Flonum) (Vectorof Flonum) Flonum -> Void) + ((Vectorof Flonum) (Vectorof Flonum) Flonum Index -> Void) + ((Vectorof Real) (Vectorof Real) Real -> Void) + ((Vectorof Real) (Vectorof Real) Real Index -> Void) + ((Vectorof Float-Complex) (Vectorof Float-Complex) Float-Complex -> Void) + ((Vectorof Float-Complex) (Vectorof Float-Complex) Float-Complex Index -> Void) + ((Vectorof Number) (Vectorof Number) Number -> Void) + ((Vectorof Number) (Vectorof Number) Number Index -> Void))) (define (vector-scaled-add! vs0 vs1 s [start 0]) (vector-generic-scaled-add! vs0 vs1 s start + *)) -(: vector-mag^2 (case-> ((Vectorof Real) -> Nonnegative-Real) +(: vector-mag^2 (case-> ((Vectorof Flonum) -> Nonnegative-Flonum) + ((Vectorof Real) -> Nonnegative-Real) + ((Vectorof Float-Complex) -> Nonnegative-Flonum) ((Vectorof Number) -> Nonnegative-Real))) (define (vector-mag^2 vs) (define n (vector-length vs)) - (let loop ([#{i : Nonnegative-Fixnum} 0] [#{s : Nonnegative-Real} 0]) - (if (i . fx>= . n) s (loop (fx+ i 1) (+ s (mag^2 (unsafe-vector-ref vs i))))))) + (cond [(fx= n 0) (raise-argument-error 'vector-mag^2 "nonempty Vector" vs)] + [else + (define s (mag^2 (unsafe-vector-ref vs 0))) + (let: loop ([i : Nonnegative-Fixnum 1] [s s]) + (cond [(i . fx< . n) (loop (fx+ i 1) (+ s (mag^2 (unsafe-vector-ref vs i))))] + [else (abs s)]))])) -(: vector-dot (case-> ((Vectorof Real) (Vectorof Real) -> Real) +(: vector-dot (case-> ((Vectorof Flonum) (Vectorof Flonum) -> Flonum) + ((Vectorof Real) (Vectorof Real) -> Real) + ((Vectorof Float-Complex) (Vectorof Float-Complex) -> Float-Complex) ((Vectorof Number) (Vectorof Number) -> Number))) (define (vector-dot vs0 vs1) (define n (min (vector-length vs0) (vector-length vs1))) - (cond [(= n 0) 0] + (cond [(fx= n 0) (raise-argument-error 'vector-dot "nonempty Vector" 0 vs0 vs1)] [else (define v0 (unsafe-vector-ref vs0 0)) (define v1 (unsafe-vector-ref vs1 0)) @@ -81,7 +99,9 @@ (loop (fx+ i 1) (+ s (* v0 (conjugate v1))))] [else s]))])) -(: vector-normalize! (case-> ((Vectorof Real) -> Nonnegative-Real) +(: vector-normalize! (case-> ((Vectorof Flonum) -> Nonnegative-Flonum) + ((Vectorof Real) -> Nonnegative-Real) + ((Vectorof Float-Complex) -> Nonnegative-Flonum) ((Vectorof Number) -> Nonnegative-Real))) (define (vector-normalize! vs) (define n (vector-length vs)) @@ -93,31 +113,51 @@ (loop (fx+ i 1))))) s) -(: vector-sub-proj! (case-> ((Vectorof Real) (Vectorof Real) Any -> Nonnegative-Real) - ((Vectorof Number) (Vectorof Number) Any -> Nonnegative-Real))) +(: one (case-> (Flonum -> Nonnegative-Flonum) + (Real -> Nonnegative-Real) + (Float-Complex -> Nonnegative-Flonum) + (Number -> Nonnegative-Real))) +(define (one x) + (cond [(flonum? x) 1.0] + [(real? x) 1] + [(float-complex? x) 1.0] + [else 1])) + +(: vector-sub-proj! + (case-> ((Vectorof Flonum) (Vectorof Flonum) Any -> Nonnegative-Flonum) + ((Vectorof Real) (Vectorof Real) Any -> Nonnegative-Real) + ((Vectorof Float-Complex) (Vectorof Float-Complex) Any -> Nonnegative-Flonum) + ((Vectorof Number) (Vectorof Number) Any -> Nonnegative-Real))) (define (vector-sub-proj! vs0 vs1 unit?) (define n (min (vector-length vs0) (vector-length vs1))) - (define t (if unit? 1 (vector-mag^2 vs1))) - (unless (and (zero? t) (exact? t)) - (define s (/ (vector-dot vs0 vs1) t)) - (let loop ([#{i : Nonnegative-Fixnum} 0]) - (when (i . fx< . n) - (define v0 (unsafe-vector-ref vs0 i)) - (define v1 (unsafe-vector-ref vs1 i)) - (unsafe-vector-set! vs0 i (- v0 (* v1 s))) - (loop (fx+ i 1))))) - t) + (cond [(fx= n 0) (raise-argument-error 'vector-sub-proj! "nonempty Vector" 0 vs0 vs1)] + [else + (define t (if unit? (one (unsafe-vector-ref vs0 0)) (vector-mag^2 vs1))) + (unless (and (zero? t) (exact? t)) + (define s (/ (vector-dot vs0 vs1) t)) + (let loop ([#{i : Nonnegative-Fixnum} 0]) + (when (i . fx< . n) + (define v0 (unsafe-vector-ref vs0 i)) + (define v1 (unsafe-vector-ref vs1 i)) + (unsafe-vector-set! vs0 i (- v0 (* v1 s))) + (loop (fx+ i 1))))) + t])) -(: vector-zero! (case-> ((Vectorof Real) -> Void) +(: vector-zero! (case-> ((Vectorof Flonum) -> Void) + ((Vectorof Real) -> Void) + ((Vectorof Float-Complex) -> Void) ((Vectorof Number) -> Void))) (define (vector-zero! vs) (define n (vector-length vs)) (let loop ([#{i : Nonnegative-Fixnum} 0]) (when (i . fx< . n) - (unsafe-vector-set! vs i 0) + (define x (unsafe-vector-ref vs i)) + (unsafe-vector-set! vs i (- x x)) (loop (fx+ i 1))))) -(: vector-zero? (case-> ((Vectorof Real) -> Boolean) +(: vector-zero? (case-> ((Vectorof Flonum) -> Boolean) + ((Vectorof Real) -> Boolean) + ((Vectorof Float-Complex) -> Boolean) ((Vectorof Number) -> Boolean))) (define (vector-zero? vs) (define n (vector-length vs)) diff --git a/collects/math/scribblings/math-matrix.scrbl b/collects/math/scribblings/math-matrix.scrbl index f891d4b8a1..bd9ee42582 100644 --- a/collects/math/scribblings/math-matrix.scrbl +++ b/collects/math/scribblings/math-matrix.scrbl @@ -45,12 +45,13 @@ Like all of @racketmodname[math], @racketmodname[math/matrix] is a work in progr Most of the basic algorithms are implemented, but some are still in planning. Possibly the most useful unimplemented algorithms are @itemlist[@item{LUP decomposition (currently, LU decomposition is implemented, in @racket[matrix-lu])} - @item{@racket[matrix-solve] for upper-triangular matrices} + @item{@racket[matrix-solve] for triangular matrices} @item{Singular value decomposition (SVD)} @item{Eigendecomposition} @item{Decomposition-based solvers} - @item{Pseudoinverse, least-squares fitting}] -All of these are planned for the next Racket release. + @item{Pseudoinverse and least-squares solving}] +All of these are planned for the next Racket release, as well as fast flonum-specific matrix +operations and LAPACK integration. @local-table-of-contents[] @@ -66,8 +67,7 @@ From the point of view of the functions in @racketmodname[math/matrix], a @defte Technically, a matrix's entries may be any type, and some fully polymorphic matrix functions such as @racket[matrix-row] and @racket[matrix-map] operate on any kind of matrix. -Other functions, such as @racket[matrix+], require their input matrices to contain either -@racket[Real] or @racket[Number] values. +Other functions, such as @racket[matrix+], require their matrix arguments to contain numeric values. @subsection[#:tag "matrix:function-types"]{Function Types} @@ -77,17 +77,22 @@ that a return value is a matrix. Most functions that implement matrix algorithms are documented as accepting @racket[(Matrix Number)] values. This includes @racket[(Matrix Real)], which is a subtype. Most of these functions have a more -precise type than is documented. For example, @racket[matrix-conjugate] actually has the type -@racketblock[(case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Number) -> (Matrix Number)))] -even though it is documented as having the less precise type -@racket[((Matrix Number) -> (Matrix Number))]. +precise type than is documented. For example, @racket[matrix-conjugate] has the type +@racketblock[(case-> ((Matrix Flonum) -> (Matrix Flonum)) + ((Matrix Real) -> (Matrix Real)) + ((Matrix Float-Complex) -> (Matrix Float-Complex)) + ((Matrix Number) -> (Matrix Number)))] +but is documented as having the type @racket[((Matrix Number) -> (Matrix Number))]. Precise function types allow Typed Racket to prove more facts about @racketmodname[math/matrix] -client programs. In particular, it is usually easy for it to prove that matrix expressions on real +client programs. In particular, it is usually easy for it to prove that operations on real matrices return real matrices: @interaction[#:eval typed-eval (matrix-conjugate (matrix [[1 2 3] [4 5 6]]))] +and that operations on inexact matrices return inexact matrices: +@interaction[#:eval typed-eval + (matrix-conjugate (matrix [[1.0+2.0i 2.0+3.0i 3.0+4.0i] + [4.0+5.0i 5.0+6.0i 6.0+7.0i]]))] @subsection[#:tag "matrix:failure"]{Failure Arguments} @@ -102,7 +107,7 @@ For example, the (simplified) type of @racket[matrix-inverse] is Thus, if a failure thunk is given, the call site is required to check for return values of type @racket[F] explicitly. -The default failure thunk, which raises an error, has type @racket[(-> Nothing)]. +Default failure thunks usually raise an error, and have the type @racket[(-> Nothing)]. For such failure thunks, @racket[(U F (Matrix Number))] is equivalent to @racket[(Matrix Number)], because @racket[Nothing] is part of every type. (In Racket, any expression may raise an error.) Thus, in this case, no explicit test for values of type @racket[F] is necessary (though of course they @@ -110,8 +115,8 @@ may be caught using @racket[with-handlers] or similar). @subsection[#:tag "matrix:broadcasting"]{Broadcasting} -Pointwise matrix operations do not @tech{broadcast} their arguments when given matrices with -different sizes: +Unlike array operations, pointwise matrix operations @bold{do not} @tech{broadcast} their arguments +when given matrices with different axis lengths: @interaction[#:eval typed-eval (matrix+ (identity-matrix 2) (matrix [[10]]))] If you need broadcasting, use array operations: @@ -217,8 +222,12 @@ Like @racket[matrix], but returns a @tech{column matrix}. (col-matrix [])] } -@defproc[(identity-matrix [n Integer]) (Matrix (U 0 1))]{ -Returns an @racket[n]×@racket[n] identity matrix; @racket[n] must be positive. +@defproc[(identity-matrix [n Integer] [one A 1] [zero A 0]) (Matrix A)]{ +Returns an @racket[n]×@racket[n] identity matrix, which has the value @racket[one] on the diagonal +and @racket[zero] everywhere else. The height/width @racket[n] must be positive. +@examples[#:eval typed-eval + (identity-matrix 3) + (identity-matrix 4 1.0+0.0i 0.0+0.0i)] } @defproc[(make-matrix [m Integer] [n Integer] [x A]) (Matrix A)]{ @@ -233,22 +242,28 @@ both @racket[m] and @racket[n] must be positive. Analogous to @racket[build-array] (and defined in terms of it). } -@defproc[(diagonal-matrix [xs (Listof A)]) (Matrix (U A 0))]{ -Returns a matrix with @racket[xs] along the diagonal and @racket[0] everywhere else. +@defproc[(diagonal-matrix [xs (Listof A)] [zero A 0]) (Matrix A)]{ +Returns a matrix with @racket[xs] along the diagonal and @racket[zero] everywhere else. The length of @racket[xs] must be positive. +@examples[#:eval typed-eval + (diagonal-matrix '(1 2 3 4 5 6)) + (diagonal-matrix '(1.0 2.0 3.0 4.0 5.0) 0.0)] } @define[block-diagonal-url]{http://en.wikipedia.org/wiki/Block_matrix#Block_diagonal_matrices} @margin-note*{@hyperlink[block-diagonal-url]{Wikipedia: Block-diagonal matrices}} -@defproc[(block-diagonal-matrix [Xs (Listof (Matrix A))]) (Matrix (U A 0))]{ -Returns a matrix with matrices @racket[Xs] along the diagonal and @racket[0] everywhere else. +@defproc[(block-diagonal-matrix [Xs (Listof (Matrix A))] [zero A 0]) (Matrix A)]{ +Returns a matrix with matrices @racket[Xs] along the diagonal and @racket[zero] everywhere else. The length of @racket[Xs] must be positive. @examples[#:eval typed-eval (block-diagonal-matrix (list (matrix [[6 7] [8 9]]) (diagonal-matrix '(7 5 7)) (col-matrix [1 2 3]) - (row-matrix [4 5 6])))] + (row-matrix [4 5 6]))) + (block-diagonal-matrix (list (make-matrix 2 2 2.0+3.0i) + (make-matrix 2 2 5.0+7.0i)) + 0.0+0.0i)] } @define[vandermonde-url]{http://en.wikipedia.org/wiki/Vandermonde_matrix} @@ -257,7 +272,8 @@ The length of @racket[Xs] must be positive. @defproc[(vandermonde-matrix [xs (Listof Number)] [n Integer]) (Matrix Number)]{ Returns an @racket[m]×@racket[n] Vandermonde matrix, where @racket[m = (length xs)]. @examples[#:eval typed-eval - (vandermonde-matrix '(1 2 3 4) 5)] + (vandermonde-matrix '(1 2 3 4) 5) + (vandermonde-matrix '(5.2 3.4 2.0) 3)] Using a Vandermonde matrix to find a Lagrange polynomial (the polynomial of least degree that passes through a given set of points): @interaction[#:eval untyped-eval @@ -384,7 +400,8 @@ Computes @racket[(matrix* M ...)] with @racket[n] arguments, but more efficientl @examples[#:eval untyped-eval ; The 100th (and 101th) Fibonacci number: (matrix* (matrix-expt (matrix [[1 1] [1 0]]) 100) - (col-matrix [0 1]))] + (col-matrix [0 1])) + (->col-matrix (list (fibonacci 100) (fibonacci 99)))] } @defproc[(matrix-scale [M (Matrix Number)] [z Number]) (Matrix Number)]{ @@ -459,18 +476,19 @@ Returns array of the entries on the diagonal of @racket[M]. (matrix ([1 2 3] [4 5 6] [7 8 9])))] } -@deftogether[(@defproc[(matrix-upper-triangle [M (Matrix A)]) (Matrix (U A 0))] - @defproc[(matrix-lower-triangle [M (Matrix A)]) (Matrix (U A 0))])]{ +@deftogether[(@defproc[(matrix-upper-triangle [M (Matrix A)] [zero A 0]) (Matrix A)] + @defproc[(matrix-lower-triangle [M (Matrix A)] [zero A 0]) (Matrix A)])]{ The function @racket[matrix-upper-triangle] returns an upper -triangular matrix (entries below the diagonal are zero) with +triangular matrix (entries below the diagonal have the value @racket[zero]) with entries from the given matrix. Likewise the function @racket[matrix-lower-triangle] returns a lower triangular matrix. -@examples[#:eval untyped-eval +@examples[#:eval typed-eval (define M (array+ (array 1) (axis-index-array #(5 7) 1))) M (matrix-upper-triangle M) - (matrix-lower-triangle M)] + (matrix-lower-triangle M) + (matrix-lower-triangle (array->flarray M) 0.0)] } @deftogether[(@defproc[(matrix-rows [M (Matrix A)]) (Listof (Matrix A))] @@ -1010,7 +1028,7 @@ The norm used by @racket[matrix-relative-error] and @racket[matrix-absolute-erro The default value is @racket[matrix-op-inf-norm]. Besides being a true norm, @racket[norm] should also be @deftech{submultiplicative}: -@racketblock[(norm (matrix* M0 M1)) <= (* (norm A) (norm B))] +@racketblock[(norm (matrix* M0 M1)) <= (* (norm M0) (norm M1))] This additional triangle-like inequality makes it possible to prove error bounds for formulas that involve matrix multiplication. diff --git a/collects/math/tests/matrix-strictness-tests.rkt b/collects/math/tests/matrix-strictness-tests.rkt index 697197d48b..abc7c48130 100644 --- a/collects/math/tests/matrix-strictness-tests.rkt +++ b/collects/math/tests/matrix-strictness-tests.rkt @@ -131,7 +131,7 @@ (for*: ([M (list nonstrict-2x2-arr strict-2x2-arr)]) (check-false (array-strict? (matrix-scale M -1))) - (check-true (array-strict? (matrix-expt M 0))) + (check-false (array-strict? (matrix-expt M 0))) (check-false (array-strict? (matrix-expt (array-lazy M) 1))) (check-false (array-strict? (matrix-expt M 2))) (check-false (array-strict? (matrix-expt M 3)))