diff --git a/collects/math/array.rkt b/collects/math/array.rkt index e1159c4c04..ef07ac9211 100644 --- a/collects/math/array.rkt +++ b/collects/math/array.rkt @@ -9,6 +9,7 @@ "private/array/array-convert.rkt" "private/array/array-fold.rkt" "private/array/array-special-folds.rkt" + "private/array/array-unfold.rkt" "private/array/array-print.rkt" "private/array/array-fft.rkt" "private/array/array-syntax.rkt" @@ -36,6 +37,7 @@ "private/array/array-convert.rkt" "private/array/array-fold.rkt" "private/array/array-special-folds.rkt" + "private/array/array-unfold.rkt" "private/array/array-print.rkt" "private/array/array-syntax.rkt" "private/array/array-fft.rkt" diff --git a/collects/math/main.rkt b/collects/math/main.rkt index f038ae03dd..5577fe694f 100644 --- a/collects/math/main.rkt +++ b/collects/math/main.rkt @@ -9,7 +9,7 @@ "number-theory.rkt" "vector.rkt" "array.rkt" - ;"matrix.rkt" + "matrix.rkt" "utils.rkt") (provide (all-from-out @@ -22,5 +22,5 @@ "number-theory.rkt" "vector.rkt" "array.rkt" - ;"matrix.rkt" + "matrix.rkt" "utils.rkt")) diff --git a/collects/math/matrix.rkt b/collects/math/matrix.rkt index 8dc3bcc599..3dc1c82ca1 100644 --- a/collects/math/matrix.rkt +++ b/collects/math/matrix.rkt @@ -1,21 +1,22 @@ #lang typed/racket/base -(require "private/matrix/matrix-pointwise.rkt" - "private/matrix/matrix-multiply.rkt" +(require "private/matrix/matrix-arithmetic.rkt" "private/matrix/matrix-constructors.rkt" + "private/matrix/matrix-basic.rkt" "private/matrix/matrix-operations.rkt" + "private/matrix/matrix-comprehension.rkt" + "private/matrix/matrix-sequences.rkt" "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" "private/matrix/utils.rkt") -(provide (all-from-out "private/matrix/matrix-pointwise.rkt" - "private/matrix/matrix-multiply.rkt" - "private/matrix/matrix-constructors.rkt" - "private/matrix/matrix-operations.rkt" - "private/matrix/matrix-expt.rkt" - "private/matrix/matrix-types.rkt") - ;; From "utils.rkt" - array-matrix? - ;; would also like matrix? : (Any -> Boolean : (Array Any)), but we can't have one until we - ;; can define array? : (Any -> Boolean : (Array Any)), and there's been trouble with that - ) +(provide (all-from-out + "private/matrix/matrix-arithmetic.rkt" + "private/matrix/matrix-constructors.rkt" + "private/matrix/matrix-basic.rkt" + "private/matrix/matrix-operations.rkt" + "private/matrix/matrix-comprehension.rkt" + "private/matrix/matrix-sequences.rkt" + "private/matrix/matrix-expt.rkt" + "private/matrix/matrix-types.rkt") + matrix?) diff --git a/collects/math/private/array/array-unfold.rkt b/collects/math/private/array/array-unfold.rkt new file mode 100644 index 0000000000..71640e4f70 --- /dev/null +++ b/collects/math/private/array/array-unfold.rkt @@ -0,0 +1,51 @@ +#lang typed/racket/base + +(require racket/fixnum + "array-struct.rkt" + "array-pointwise.rkt" + "array-fold.rkt" + "utils.rkt" + "../unsafe.rkt") + +(provide unsafe-array-axis-expand + array-axis-expand + list-array->array) + +(: check-array-axis (All (A) (Symbol (Array A) Integer -> Index))) +(define (check-array-axis name arr k) + (define dims (array-dims arr)) + (cond [(fx= dims 0) (raise-argument-error name "Array with at least one axis" 0 arr k)] + [(or (k . < . 0) (k . > . dims)) + (raise-argument-error name (format "Index <= ~a" dims) 1 arr k)] + [else k])) + +(: unsafe-array-axis-expand (All (A B) ((Array A) Index Index (A Index -> B) -> (Array B)))) +(define (unsafe-array-axis-expand arr k dk f) + (define ds (array-shape arr)) + (define new-ds (unsafe-vector-insert ds k dk)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array + new-ds (λ: ([js : Indexes]) + (define jk (unsafe-vector-ref js k)) + (f (proc (unsafe-vector-remove js k)) jk)))) + +(: array-axis-expand (All (A B) ((Array A) Integer Integer (A Index -> B) -> (Array B)))) +(define (array-axis-expand arr k dk f) + (let ([k (check-array-axis 'array-axis-expand arr k)]) + (cond [(not (index? dk)) (raise-argument-error 'array-axis-expand "Index" 2 arr k dk f)] + [else (unsafe-array-axis-expand arr k dk f)]))) + +;; =================================================================================================== +;; Specific unfolds/expansions + +(: list-array->array (All (A) (case-> ((Array (Listof A)) -> (Array A)) + ((Array (Listof A)) Integer -> (Array A))))) +(define (list-array->array arr [k 0]) + (define dims (array-dims arr)) + (cond [(and (k . >= . 0) (k . <= . dims)) + (let ([arr (array-strict (array-map (inst list->vector A) arr))]) + ;(define dks (remove-duplicates (array->list (array-map vector-length arr)))) + (define dk (array-all-min (array-map vector-length arr))) + (unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A)))] + [else + (raise-argument-error 'list-array->array (format "Index <= ~a" dims) 1 arr k)])) diff --git a/collects/math/private/array/mutable-array.rkt b/collects/math/private/array/mutable-array.rkt index 2bca6df4e3..aa6715745a 100644 --- a/collects/math/private/array/mutable-array.rkt +++ b/collects/math/private/array/mutable-array.rkt @@ -23,8 +23,7 @@ mutable-array-copy mutable-array ;; Conversion - array->mutable-array - flat-vector->matrix) + array->mutable-array) (define-syntax (mutable-array stx) (syntax-parse stx #:literals (:) diff --git a/collects/math/private/array/typed-array-fold.rkt b/collects/math/private/array/typed-array-fold.rkt index d7ad28f90a..365612c554 100644 --- a/collects/math/private/array/typed-array-fold.rkt +++ b/collects/math/private/array/typed-array-fold.rkt @@ -21,23 +21,6 @@ (raise-argument-error name (format "Index < ~a" dims) 1 arr k)] [else k])) -(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B)))) -(define (array-axis-reduce arr k f) - (let ([k (check-array-axis 'array-axis-reduce arr k)]) - (define ds (array-shape arr)) - (define dk (unsafe-vector-ref ds k)) - (define new-ds (unsafe-vector-remove ds k)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array - new-ds (λ: ([js : Indexes]) - (define old-js (unsafe-vector-insert js k 0)) - (f dk (λ: ([jk : Integer]) - (cond [(or (jk . < . 0) (jk . >= . dk)) - (raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)] - [else - (unsafe-vector-set! old-js k jk) - (proc (vector-copy-all old-js))]))))))) - (: unsafe-array-axis-reduce (All (A B) ((Array A) Index (Index (Index -> A) -> B) -> (Array B)))) (begin-encourage-inline (define (unsafe-array-axis-reduce arr k f) @@ -52,6 +35,19 @@ (unsafe-vector-set! old-js k jk) (proc old-js))))))) +(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B)))) +(define (array-axis-reduce arr k f) + (let ([k (check-array-axis 'array-axis-reduce arr k)]) + (unsafe-array-axis-reduce + arr k + (λ: ([dk : Index] [proc : (Index -> A)]) + (: safe-proc (Integer -> A)) + (define (safe-proc jk) + (cond [(or (jk . < . 0) (jk . >= . dk)) + (raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)] + [else (proc jk)])) + (f dk safe-proc))))) + (: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B)))) (define (array-axis-fold/init arr k f init) (let ([k (check-array-axis 'array-axis-fold arr k)]) diff --git a/collects/math/private/array/typed-mutable-array.rkt b/collects/math/private/array/typed-mutable-array.rkt index 6e6c99ee97..4be8d74dd5 100644 --- a/collects/math/private/array/typed-mutable-array.rkt +++ b/collects/math/private/array/typed-mutable-array.rkt @@ -41,10 +41,6 @@ (define (mutable-array-copy arr) (unsafe-vector->array (array-shape arr) (vector-copy-all (mutable-array-data arr)))) -(: flat-vector->matrix : (All (A) (Index Index (Vectorof A) -> (Array A)))) -(define (flat-vector->matrix m n v) - (vector->array (vector m n) v)) - ;; =================================================================================================== ;; Conversions diff --git a/collects/math/private/matrix/matrix-2d.rkt b/collects/math/private/matrix/matrix-2d.rkt index e592e7bad4..9efdde259e 100644 --- a/collects/math/private/matrix/matrix-2d.rkt +++ b/collects/math/private/matrix/matrix-2d.rkt @@ -1,5 +1,8 @@ #lang typed/racket/base -(require math/matrix) + +(require math/array + "matrix-types.rkt" + "matrix-constructors.rkt") (provide matrix-2d-rotation matrix-2d-scaling @@ -11,34 +14,30 @@ ; Transformations from: ; http://en.wikipedia.org/wiki/Transformation_matrix -(: matrix-2d-rotation : Real -> (Matrix Number)) +(: matrix-2d-rotation : Real -> (Matrix Real)) ; matrix representing rotation θ radians counter clockwise (define (matrix-2d-rotation θ) (define cosθ (cos θ)) (define sinθ (sin θ)) - (matrix/dim 2 2 - cosθ (- sinθ) - sinθ cosθ)) + (matrix [[cosθ (- sinθ)] + [sinθ cosθ]])) -(: matrix-2d-scaling : Real Real -> (Matrix Number)) +(: matrix-2d-scaling : Real Real -> (Matrix Real)) (define (matrix-2d-scaling sx sy) - (matrix/dim 2 2 - sx 0 - 0 sy)) + (matrix [[sx 0] + [0 sy]])) -(: matrix-2d-shear-x : Real -> (Matrix Number)) +(: matrix-2d-shear-x : Real -> (Matrix Real)) (define (matrix-2d-shear-x k) - (matrix/dim 2 2 - 1 k - 0 1)) + (matrix [[1 k] + [0 1]])) -(: matrix-2d-shear-y : Real -> (Matrix Number)) +(: matrix-2d-shear-y : Real -> (Matrix Real)) (define (matrix-2d-shear-y k) - (matrix/dim 2 2 - 1 0 - k 1)) + (matrix [[1 0] + [k 1]])) -(: matrix-2d-reflection : Real Real -> (Matrix Number)) +(: matrix-2d-reflection : Real Real -> (Matrix Real)) (define (matrix-2d-reflection a b) ; reflection about the line through (0,0) and (a,b) (define a2 (* a a)) @@ -46,11 +45,10 @@ (define 2ab (* 2 a b)) (define norm2 (+ a2 b2)) (define 2ab/norm2 (/ 2ab norm2)) - (matrix/dim 2 2 - (/ (- a2 b2) norm2) 2ab/norm2 - 2ab/norm2 (/ (- b2 a2) norm2))) + (matrix [[(/ (- a2 b2) norm2) 2ab/norm2] + [2ab/norm2 (/ (- b2 a2) norm2)]])) -(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Number)) +(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Real)) ; orthogonal projection onto the line through (0,0) and (a,b) (define (matrix-2d-orthogonal-projection a b) (define a2 (* a a)) @@ -58,6 +56,5 @@ (define ab (* a b)) (define norm2 (+ a2 b2)) (define ab/norm2 (/ ab norm2)) - (matrix/dim 2 2 - (/ a2 norm2) ab/norm2 - ab/norm2 (/ b2 norm2))) + (matrix [[(/ a2 norm2) ab/norm2] + [ab/norm2 (/ b2 norm2)]])) diff --git a/collects/math/private/matrix/matrix-arithmetic.rkt b/collects/math/private/matrix/matrix-arithmetic.rkt new file mode 100644 index 0000000000..e0d5e85499 --- /dev/null +++ b/collects/math/private/matrix/matrix-arithmetic.rkt @@ -0,0 +1,39 @@ +#lang racket/base + +(require (for-syntax racket/base) + typed/untyped-utils + (prefix-in typed: "typed-matrix-arithmetic.rkt") + (rename-in "untyped-matrix-arithmetic.rkt" + [matrix-map untyped:matrix-map])) + +(define-typed/untyped-identifier matrix-map + typed:matrix-map untyped:matrix-map) + +(define-syntax (define/inline-macro stx) + (syntax-case stx () + [(_ name pat inline-fun typed:fun) + (syntax/loc stx + (define-syntax (name inner-stx) + (syntax-case inner-stx () + [(_ . pat) (syntax/loc inner-stx (inline-fun . pat))] + [(_ . es) (syntax/loc inner-stx (typed:fun . es))] + [_ (syntax/loc inner-stx typed:fun)])))])) + +(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*) +(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+) +(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-) +(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale) + +(provide + ;; Equality + (rename-out [typed:matrix= matrix=]) + ;; Mapping + inline-matrix-map + matrix-map + ;; Multiplication + matrix* + ;; Pointwise operators + matrix+ + matrix- + matrix-scale + (rename-out [typed:matrix-sum matrix-sum])) diff --git a/collects/math/private/matrix/matrix-basic.rkt b/collects/math/private/matrix/matrix-basic.rkt new file mode 100644 index 0000000000..c5f4b04e43 --- /dev/null +++ b/collects/math/private/matrix/matrix-basic.rkt @@ -0,0 +1,206 @@ +#lang typed/racket + +(require racket/list + racket/fixnum + math/array + math/flonum + "matrix-types.rkt" + "utils.rkt" + "../unsafe.rkt") + +(provide + ;; Extraction + matrix-ref + matrix-diagonal + submatrix + matrix-row + matrix-col + matrix-rows + matrix-cols + ;; Predicates + zero-matrix? + ;; Embiggenment + matrix-augment + matrix-stack + ;; Norm and inner product + matrix-norm + matrix-dot + ;; Simple operators + matrix-transpose + matrix-conjugate + matrix-hermitian + matrix-trace) + +;; =================================================================================================== +;; Extraction + +(: matrix-ref (All (A) (Array A) Integer Integer -> A)) +(define (matrix-ref a i j) + (define-values (m n) (matrix-shape a)) + (cond [(or (i . < . 0) (i . >= . m)) + (raise-argument-error 'matrix-ref (format "Index < ~a" m) 1 a i j)] + [(or (j . < . 0) (j . >= . n)) + (raise-argument-error 'matrix-ref (format "Index < ~a" n) 2 a i j)] + [else + (unsafe-array-ref a ((inst vector Index) i j))])) + +(: matrix-diagonal (All (A) ((Array A) -> (Array A)))) +(define (matrix-diagonal a) + (define m (square-matrix-size a)) + (define proc (unsafe-array-proc a)) + (unsafe-build-array + ((inst vector Index) m) + (λ: ([js : Indexes]) + (define i (unsafe-vector-ref js 0)) + (proc ((inst vector Index) i i))))) + +(: submatrix (All (A) (Matrix A) Slice-Spec Slice-Spec -> (Matrix A))) +(define (submatrix a row-range col-range) + (array-slice-ref (ensure-matrix 'submatrix a) (list row-range col-range))) + +(: matrix-row (All (A) (Matrix A) Integer -> (Matrix A))) +(define (matrix-row a i) + (define-values (m n) (matrix-shape a)) + (cond [(or (i . < . 0) (i . >= . m)) + (raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)] + [else + (define proc (unsafe-array-proc a)) + (unsafe-build-array + ((inst vector Index) 1 n) + (λ: ([ij : Indexes]) + (unsafe-vector-set! ij 0 i) + (define res (proc ij)) + (unsafe-vector-set! ij 0 0) + res))])) + +(: matrix-col (All (A) (Matrix A) Index -> (Matrix A))) +(define (matrix-col a j) + (define-values (m n) (matrix-shape a)) + (cond [(or (j . < . 0) (j . >= . n)) + (raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)] + [else + (define proc (unsafe-array-proc a)) + (unsafe-build-array + ((inst vector Index) m 1) + (λ: ([ij : Indexes]) + (unsafe-vector-set! ij 1 j) + (define res (proc ij)) + (unsafe-vector-set! ij 1 0) + res))])) + +(: matrix-rows (All (A) (Array A) -> (Listof (Array A)))) +(define (matrix-rows a) + (array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)) + +(: matrix-cols (All (A) (Array A) -> (Listof (Array A)))) +(define (matrix-cols a) + (array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)) + +;; =================================================================================================== +;; Predicates + +(: zero-matrix? ((Array Number) -> Boolean)) +(define (zero-matrix? a) + (array-all-and (array-map zero? a))) + +;; =================================================================================================== +;; Embiggenment (this is a perfectly cromulent word) + +(: matrix-augment (All (A) (Listof (Array A)) -> (Array A))) +(define (matrix-augment as) + (cond [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)] + [else + (define m (matrix-num-rows (first as))) + (cond [(andmap (λ: ([a : (Array A)]) (= m (matrix-num-rows a))) (rest as)) + (array-append* as 1)] + [else + (error 'matrix-augment + "matrices must have the same number of rows; given ~a" + (format-matrices/error as))])])) + +(: matrix-stack (All (A) (Listof (Array A)) -> (Array A))) +(define (matrix-stack as) + (cond [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)] + [else + (define n (matrix-num-cols (first as))) + (cond [(andmap (λ: ([a : (Array A)]) (= n (matrix-num-cols a))) (rest as)) + (array-append* as 0)] + [else + (error 'matrix-stack + "matrices must have the same number of columns; given ~a" + (format-matrices/error as))])])) + +;; =================================================================================================== +;; Matrix norms and Frobenius inner product + +(: maximum-norm ((Array Number) -> Real)) +(define (maximum-norm a) + (array-all-max (array-magnitude a))) + +(: taxicab-norm ((Array Number) -> Real)) +(define (taxicab-norm a) + (array-all-sum (array-magnitude a))) + +(: frobenius-norm ((Array Number) -> Real)) +(define (frobenius-norm a) + (let ([a (array-strict (array-magnitude a))]) + ;; Compute this divided by the maximum to avoid underflow and overflow + (define mx (array-all-max a)) + (cond [(and (rational? mx) (positive? mx)) + (assert + (* mx (sqrt (array-all-sum (inline-array-map (λ: ([x : Number]) (sqr (/ x mx))) a)))) + real?)] + [else mx]))) + +(: p-norm ((Array Number) Positive-Real -> Real)) +(define (p-norm a p) + (let ([a (array-strict (array-magnitude a))]) + ;; Compute this divided by the maximum to avoid underflow and overflow + (define mx (array-all-max a)) + (cond [(and (rational? mx) (positive? mx)) + (assert + (* mx (expt (array-all-sum (inline-array-map (λ: ([x : Real]) (expt (/ x mx) p)) a)) + (/ p))) + real?)] + [else mx]))) + +(: matrix-norm (case-> ((Array Number) -> Real) + ((Array Number) Real -> Real))) +;; Computes the p norm of a matrix +(define (matrix-norm a [p 2]) + (cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)] + [(p . = . 2) (frobenius-norm a)] + [(p . = . +inf.0) (maximum-norm a)] + [(p . = . 1) (taxicab-norm a)] + [(p . > . 1) (p-norm a p)] + [else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)])) + +(: matrix-dot (case-> ((Array Real) (Array Real) -> Real) + ((Array Number) (Array Number) -> Number))) +;; Computes the Frobenius inner product of two matrices +(define (matrix-dot a b) + (cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)] + [(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)] + [else (array-all-sum (array* a (array-conjugate b)))])) + +;; =================================================================================================== +;; Operators + +(: matrix-transpose (All (A) (Array A) -> (Array A))) +(define (matrix-transpose a) + (array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1)) + +(: matrix-conjugate (case-> ((Array Real) -> (Array Real)) + ((Array Number) -> (Array Number)))) +(define (matrix-conjugate a) + (array-conjugate (ensure-matrix 'matrix-conjugate a))) + +(: matrix-hermitian (case-> ((Array Real) -> (Array Real)) + ((Array Number) -> (Array Number)))) +(define (matrix-hermitian a) + (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)) + +(: matrix-trace (case-> ((Array Real) -> Real) + ((Array Number) -> Number))) +(define (matrix-trace a) + (array-all-sum (matrix-diagonal a))) diff --git a/collects/math/private/matrix/matrix-column.rkt b/collects/math/private/matrix/matrix-column.rkt new file mode 100644 index 0000000000..ab864dfed7 --- /dev/null +++ b/collects/math/private/matrix/matrix-column.rkt @@ -0,0 +1,126 @@ +#lang typed/racket/base + +(require math/array + math/base + "matrix-types.rkt" + "matrix-constructors.rkt" + "matrix-arithmetic.rkt" + "../unsafe.rkt") + +(provide unit-column + column-height + unsafe-column->vector + column-scale + column+ + column-dot + column-norm + column-project + column-project/unit + column-normalize) + +(: unit-column : Integer Integer -> (Result-Column Number)) +(define (unit-column m i) + (cond + [(and (index? m) (index? i)) + (define v (make-vector m 0)) + (if (< i m) + (vector-set! v i 1) + (error 'unit-vector "dimension must be largest")) + (vector->matrix m 1 v)] + [else + (error 'unit-vector "expected two indices")])) + +(: column-height : (Column Number) -> Index) +(define (column-height v) + (if (vector? v) + (vector-length v) + (matrix-num-rows v))) + +(: unsafe-column->vector : (Column Number) -> (Vectorof Number)) +(define (unsafe-column->vector v) + (if (vector? v) v + (let () + (define-values (m n) (matrix-shape v)) + (if (= n 1) + (mutable-array-data (array->mutable-array v)) + (error 'unsafe-column->vector + "expected a column (vector or mx1 matrix), got ~a" v))))) + +(: column-scale : (Column Number) Number -> (Result-Column Number)) +(define (column-scale a s) + (if (vector? a) + (let*: ([n (vector-length a)] + [v : (Vectorof Number) (make-vector n 0)]) + (for: ([i (in-range 0 n)] + [x : Number (in-vector a)]) + (vector-set! v i (* s x))) + (->col-matrix v)) + (matrix-scale a s))) + +(: column+ : (Column Number) (Column Number) -> (Result-Column Number)) +(define (column+ v w) + (cond [(and (vector? v) (vector? w)) + (let ([n (vector-length v)] + [m (vector-length w)]) + (unless (= m n) + (error 'column+ + "expected two column vectors of the same length, got ~a and ~a" v w)) + (define: v+w : (Vectorof Number) (make-vector n 0)) + (for: ([i (in-range 0 n)] + [x : Number (in-vector v)] + [y : Number (in-vector w)]) + (vector-set! v+w i (+ x y))) + (->col-matrix v+w))] + [else + (unless (= (column-height v) (column-height w)) + (error 'column+ + "expected two column vectors of the same length, got ~a and ~a" v w)) + (array+ (->col-matrix v) (->col-matrix w))])) + +(: column-dot : (Column Number) (Column Number) -> Number) +(define (column-dot c d) + (define v (unsafe-column->vector c)) + (define w (unsafe-column->vector d)) + (define m (column-height v)) + (define s (column-height w)) + (cond + [(not (= m s)) (error 'column-dot + "expected two mx1 matrices with same number of rows, got ~a and ~a" + c d)] + [else + (for/sum: : Number ([i (in-range 0 m)]) + (assert i index?) + ; Note: If d is a vector of reals, + ; then the conjugate is a no-op + (* (unsafe-vector-ref v i) + (conjugate (unsafe-vector-ref w i))))])) + +(: column-norm : (Column Number) -> Real) +(define (column-norm v) + (define norm (sqrt (column-dot v v))) + (assert norm real?)) + +(: column-project : (Column Number) (Column Number) -> (Result-Column Number)) +; (column-project v w) +; Return the projection og vector v on vector w. +(define (column-project v w) + (let ([w.w (column-dot w w)]) + (if (zero? w.w) + (error 'column-project "projection on the zero vector not defined") + (matrix-scale (->col-matrix w) (/ (column-dot v w) w.w))))) + +(: column-project/unit : (Column Number) (Column Number) -> (Result-Column Number)) +; (column-project-on-unit v w) +; Return the projection og vector v on a unit vector w. +(define (column-project/unit v w) + (matrix-scale (->col-matrix w) (column-dot v w))) + +(: column-normalize : (Column Number) -> (Result-Column Number)) +; (column-vector-normalize v) +; Return unit vector with same direction as v. +; If v is the zero vector, the zero vector is returned. +(define (column-normalize w) + (let ([norm (column-norm w)] + [w (->col-matrix w)]) + (cond [(zero? norm) w] + [else (matrix-scale w (/ norm))]))) diff --git a/collects/math/private/matrix/matrix-comprehension.rkt b/collects/math/private/matrix/matrix-comprehension.rkt new file mode 100644 index 0000000000..f9ac60b731 --- /dev/null +++ b/collects/math/private/matrix/matrix-comprehension.rkt @@ -0,0 +1,140 @@ +#lang racket + +(require math/array + typed/racket/base + "matrix-types.rkt" + "matrix-constructors.rkt") + +(provide for/matrix + for*/matrix + for/matrix: + for*/matrix:) + +;;; COMPREHENSIONS + +; (for/matrix m n (clause ...) . defs+exprs) +; Return an m x n matrix with elements from the last expr. +; The first n values produced becomes the first row. +; The next n values becomes the second row and so on. +; The bindings in clauses run in parallel. +(define-syntax (for/matrix stx) + (syntax-case stx () + [(_ m-expr n-expr (clause ...) . defs+exprs) + (syntax/loc stx + (let ([m m-expr] [n n-expr]) + (define flat-vector + (for/vector #:length (* m n) + (clause ...) . defs+exprs)) + (vector->matrix m n flat-vector)))])) + +; (for*/matrix m n (clause ...) . defs+exprs) +; Return an m x n matrix with elements from the last expr. +; The first n values produced becomes the first row. +; The next n values becomes the second row and so on. +; The bindings in clauses run nested. +; (for*/matrix m n #:column (clause ...) . defs+exprs) +; Return an m x n matrix with elements from the last expr. +; The first m values produced becomes the first column. +; The next m values becomes the second column and so on. +; The bindings in clauses run nested. +(define-syntax (for*/matrix stx) + (syntax-case stx () + [(_ m-expr n-expr #:column (clause ...) . defs+exprs) + (syntax/loc stx + (let* ([m m-expr] + [n n-expr] + [v (make-vector (* m n) 0)] + [w (for*/vector #:length (* m n) (clause ...) . defs+exprs)]) + (for* ([i (in-range m)] [j (in-range n)]) + (vector-set! v (+ (* i n) j) + (vector-ref w (+ (* j m) i)))) + (vector->matrix m n v)))] + [(_ m-expr n-expr (clause ...) . defs+exprs) + (syntax/loc stx + (let ([m m-expr] [n n-expr]) + (vector->matrix + m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))])) + + +(define-syntax (for/column: stx) + (syntax-case stx () + [(_ : type m-expr (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: m : Index m-expr) + (define: flat-vector : (Vectorof Number) (make-vector m 0)) + (for: ([i (in-range m)] for:-clause ...) + (define x (let () . defs+exprs)) + (vector-set! flat-vector i x)) + (vector->col-matrix flat-vector)))])) + +(define-syntax (for/matrix: stx) + (syntax-case stx () + [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: m : Index m-expr) + (define: n : Index n-expr) + (define: m*n : Index (assert (* m n) index?)) + (define: v : (Vectorof Number) (make-vector m*n 0)) + (define: k : Index 0) + (for: ([i (in-range m*n)] for:-clause ...) + (define x (let () . defs+exprs)) + (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) + (set! k (assert (+ k 1) index?))) + (vector->matrix m n v)))] + [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: m : Index m-expr) + (define: n : Index n-expr) + (define: m*n : Index (assert (* m n) index?)) + (define: v : (Vectorof Number) (make-vector m*n 0)) + (for: ([i (in-range m*n)] for:-clause ...) + (define x (let () . defs+exprs)) + (vector-set! v i x)) + (vector->matrix m n v)))])) + +(define-syntax (for*/matrix: stx) + (syntax-case stx () + [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: m : Index m-expr) + (define: n : Index n-expr) + (define: m*n : Index (assert (* m n) index?)) + (define: v : (Vectorof Number) (make-vector m*n 0)) + (define: k : Index 0) + (for*: (for:-clause ...) + (define x (let () . defs+exprs)) + (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) + (set! k (assert (+ k 1) index?))) + (vector->matrix m n v)))] + [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: m : Index m-expr) + (define: n : Index n-expr) + (define: m*n : Index (assert (* m n) index?)) + (define: v : (Vectorof Number) (make-vector m*n 0)) + (define: i : Index 0) + (for*: (for:-clause ...) + (define x (let () . defs+exprs)) + (vector-set! v i x) + (set! i (assert (+ i 1) index?))) + (vector->matrix m n v)))])) +#; +(module* test #f + (require rackunit) + ; "matrix-sequences.rkt" + ; These work in racket not in typed racket + (check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j))) + '[[0 1 2] [1 2 3]]) + (check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j))) + '[[0 2 2] [1 1 3]]) + (check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i)) + '[[0 2] [1 3]]) + (check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i)) + '[[0 1] [2 3]]) + (check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j))) + '[[6 8 10] [12 14 16]])) diff --git a/collects/math/private/matrix/matrix-constructors.rkt b/collects/math/private/matrix/matrix-constructors.rkt index 34afaacc3c..c4fc391b2f 100644 --- a/collects/math/private/matrix/matrix-constructors.rkt +++ b/collects/math/private/matrix/matrix-constructors.rkt @@ -1,62 +1,376 @@ -#lang typed/racket +#lang racket/base -(require math/array - "../unsafe.rkt" - "matrix-types.rkt") +(provide + ;; Constructors + identity-matrix + make-matrix + build-matrix + diagonal-matrix/zero + diagonal-matrix + block-diagonal-matrix/zero + block-diagonal-matrix + vandermonde-matrix + ;; Basic conversion + list->matrix + matrix->list + vector->matrix + matrix->vector + ->row-matrix + ->col-matrix + ;; Nested conversion + list*->matrix + matrix->list* + vector*->matrix + matrix->vector* + ;; Syntax + matrix + row-matrix + col-matrix) -(provide identity-matrix flidentity-matrix - matrix->list list->matrix fllist->matrix - matrix->vector vector->matrix flvector->matrix - flat-vector->matrix - make-matrix - matrix-row - matrix-column - submatrix) +(module typed-defs typed/racket/base + (require racket/fixnum + racket/list + racket/vector + math/array + "../array/utils.rkt" + "matrix-types.rkt" + "utils.rkt" + "../unsafe.rkt") + + (provide (all-defined-out)) + + ;; ================================================================================================= + ;; Constructors + + (: identity-matrix (Integer -> (Matrix (U 0 1)))) + (define (identity-matrix m) (diagonal-array 2 m 1 0)) + + (: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) + (define (make-matrix m n x) + (make-array (vector m n) x)) + + (: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A)))) + (define (build-matrix m n proc) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] + [else + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (proc (unsafe-vector-ref js 0) + (unsafe-vector-ref js 1))))])) + + (: diagonal-matrix/zero (All (A) (Array A) A -> (Array A))) + (define (diagonal-matrix/zero a zero) + (define ds (array-shape a)) + (cond [(= 1 (vector-length ds)) + (define m (unsafe-vector-ref ds 0)) + (define proc (unsafe-array-proc a)) + (unsafe-build-array + ((inst vector Index) m m) + (λ: ([js : Indexes]) + (define i (unsafe-vector-ref js 0)) + (cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))] + [else zero])))] + [else + (raise-argument-error 'diagonal-matrix "Array with one dimension" a)])) + + (: diagonal-matrix (case-> ((Array Real) -> (Array Real)) + ((Array Number) -> (Array Number)))) + (define (diagonal-matrix a) + (diagonal-matrix/zero a 0)) + + (: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A))) + (define (block-diagonal-matrix/zero* as zero) + (define num (vector-length as)) + (define-values (ms ns) + (let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty] + [ns : (Listof Index) empty] + ) ([a (in-vector as)]) + (define-values (m n) (matrix-shape a)) + (values (cons m ms) (cons n ns)))]) + (values (reverse ms) (reverse ns)))) + (define res-m (assert (apply + ms) index?)) + (define res-n (assert (apply + ns) index?)) + (define vs ((inst make-vector Index) res-m 0)) + (define hs ((inst make-vector Index) res-n 0)) + (define is ((inst make-vector Index) res-m 0)) + (define js ((inst make-vector Index) res-n 0)) + (define-values (_res-i _res-j) + (for/fold: ([res-i : Nonnegative-Fixnum 0] + [res-j : Nonnegative-Fixnum 0] + ) ([m (in-list ms)] + [n (in-list ns)] + [k : Nonnegative-Fixnum (in-range num)]) + (let ([k (assert k index?)]) + (for: ([i : Nonnegative-Fixnum (in-range m)]) + (vector-set! vs (unsafe-fx+ res-i i) k) + (vector-set! is (unsafe-fx+ res-i i) (assert i index?))) + (for: ([j : Nonnegative-Fixnum (in-range n)]) + (vector-set! hs (unsafe-fx+ res-j j) k) + (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) + (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) + (define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as)) + (unsafe-build-array + ((inst vector Index) res-m res-n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (define v (unsafe-vector-ref vs i)) + (cond [(fx= v (vector-ref hs j)) + (define proc (unsafe-vector-ref procs v)) + (define iv (unsafe-vector-ref is i)) + (define jv (unsafe-vector-ref js j)) + (unsafe-vector-set! ij 0 iv) + (unsafe-vector-set! ij 1 jv) + (define res (proc ij)) + (unsafe-vector-set! ij 0 i) + (unsafe-vector-set! ij 1 j) + res] + [else + zero])))) + + (: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A))) + (define (block-diagonal-matrix/zero as zero) + (let ([as (list->vector as)]) + (define num (vector-length as)) + (cond [(= num 0) + (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)] + [(= num 1) + (unsafe-vector-ref as 0)] + [else + (block-diagonal-matrix/zero* as zero)]))) + + (: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real)) + ((Listof (Array Number)) -> (Array Number)))) + (define (block-diagonal-matrix as) + (block-diagonal-matrix/zero as 0)) + + (: expt-hack (case-> (Real Integer -> Real) + (Number Integer -> Number))) + ;; Stop using this when TR correctly derives expt : Real Integer -> Real + (define (expt-hack x n) + (cond [(real? x) (assert (expt x n) real?)] + [else (expt x n)])) + + (: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real)) + ((Listof Number) Integer -> (Array Number)))) + (define (vandermonde-matrix xs n) + (cond [(empty? xs) + (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)] + [(or (not (index? n)) (zero? n)) + (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)] + [else + (array-axis-expand (list->array xs) 1 n expt-hack)])) + + ;; ================================================================================================= + ;; Flat conversion + + (: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A)))) + (define (list->matrix m n xs) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)] + [else (list->array (vector m n) xs)])) + + (: matrix->list (All (A) ((Array A) -> (Listof A)))) + (define (matrix->list a) + (array->list (ensure-matrix 'matrix->list a))) + + (: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A)))) + (define (vector->matrix m n v) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)] + [else (vector->array (vector m n) v)])) + + (: matrix->vector (All (A) ((Array A) -> (Vectorof A)))) + (define (matrix->vector a) + (array->vector (ensure-matrix 'matrix->vector a))) + + (: list->row-matrix (All (A) ((Listof A) -> (Array A)))) + (define (list->row-matrix xs) + (cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)] + [else (list->array ((inst vector Index) 1 (length xs)) xs)])) + + (: list->col-matrix (All (A) ((Listof A) -> (Array A)))) + (define (list->col-matrix xs) + (cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)] + [else (list->array ((inst vector Index) (length xs) 1) xs)])) + + (: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) + (define (vector->row-matrix xs) + (define n (vector-length xs)) + (cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)] + [else (vector->array ((inst vector Index) 1 n) xs)])) + + (: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) + (define (vector->col-matrix xs) + (define n (vector-length xs)) + (cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)] + [else (vector->array ((inst vector Index) n 1) xs)])) + + (: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index))) + (define (find-nontrivial-axis ds) + (define dims (vector-length ds)) + (let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0]) + (cond [(k . < . dims) (define dk (unsafe-vector-ref ds k)) + (if (dk . > . 1) (values k dk) (loop (fx+ k 1)))] + [else (values 0 0)]))) + + (: array->row-matrix (All (A) ((Array A) -> (Array A)))) + (define (array->row-matrix arr) + (define (fail) + (raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr)) + (define ds (array-shape arr)) + (define dims (vector-length ds)) + (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) + (cond [(zero? (array-size arr)) (fail)] + [(row-matrix? arr) arr] + [(= num-ones dims) + (define: js : (Vectorof Index) (make-vector dims 0)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 1) + (λ: ([ij : Indexes]) (proc js)))] + [(= num-ones (- dims 1)) + (define-values (k n) (find-nontrivial-axis ds)) + (define js (make-thread-local-indexes dims)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 n) + (λ: ([ij : Indexes]) + (let ([js (js)]) + (unsafe-vector-set! js k (unsafe-vector-ref ij 1)) + (proc js))))] + [else (fail)])) + + (: array->col-matrix (All (A) ((Array A) -> (Array A)))) + (define (array->col-matrix arr) + (define (fail) + (raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr)) + (define ds (array-shape arr)) + (define dims (vector-length ds)) + (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) + (cond [(zero? (array-size arr)) (fail)] + [(col-matrix? arr) arr] + [(= num-ones dims) + (define: js : (Vectorof Index) (make-vector dims 0)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 1) + (λ: ([ij : Indexes]) (proc js)))] + [(= num-ones (- dims 1)) + (define-values (k m) (find-nontrivial-axis ds)) + (define js (make-thread-local-indexes dims)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) m 1) + (λ: ([ij : Indexes]) + (let ([js (js)]) + (unsafe-vector-set! js k (unsafe-vector-ref ij 0)) + (proc js))))] + [else (fail)])) + + (: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) + (define (->row-matrix xs) + (cond [(list? xs) (list->row-matrix xs)] + [(array? xs) (array->row-matrix xs)] + [else (vector->row-matrix xs)])) + + (: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) + (define (->col-matrix xs) + (cond [(list? xs) (list->col-matrix xs)] + [(array? xs) (array->col-matrix xs)] + [else (vector->col-matrix xs)])) + + ;; ================================================================================================= + ;; Nested conversion + + (: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index))) + (define (list*-shape xss fail) + (define m (length xss)) + (cond [(m . > . 0) + (define n (length (first xss))) + (cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss))) + (values m n)] + [else (fail)])] + [else (fail)])) + + (: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing) + -> (Values Positive-Index Positive-Index))) + (define (vector*-shape xss fail) + (define m (vector-length xss)) + (cond [(m . > . 0) + (define ns ((inst vector-map Index (Vectorof A)) vector-length xss)) + (define n (vector-length (unsafe-vector-ref xss 0))) + (cond [(and (n . > . 0) + (let: loop : Boolean ([i : Nonnegative-Fixnum 1]) + (cond [(i . fx< . m) + (if (= n (vector-length (unsafe-vector-ref xss i))) + (loop (fx+ i 1)) + #f)] + [else #t]))) + (values m n)] + [else (fail)])] + [else (fail)])) + + (: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A))) + (define (list*->matrix xss) + (define (fail) + (raise-argument-error 'list*->matrix + "nested lists with rectangular shape and at least one matrix element" + xss)) + (define-values (m n) (list*-shape xss fail)) + (list->array ((inst vector Index) m n) (apply append xss))) + + (: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A)))) + (define (matrix->list* a) + (cond [(matrix? a) (array->list (array->list-array a 1))] + [else (raise-argument-error 'matrix->list* "matrix?" a)])) + + (: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A))) + (define (vector*->matrix xss) + (define (fail) + (raise-argument-error 'vector*->matrix + "nested vectors with rectangular shape and at least one matrix element" + xss)) + (define-values (m n) (vector*-shape xss fail)) + (vector->matrix m n (apply vector-append (vector->list xss)))) + + (: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A)))) + (define (matrix->vector* a) + (cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))] + [else (raise-argument-error 'matrix->vector* "matrix?" a)])) + ) ; module -(: identity-matrix (Integer -> (Matrix Real))) -(define (identity-matrix m) (diagonal-array 2 m 1 0)) +(require (for-syntax racket/base + syntax/parse) + (only-in typed/racket/base :) + math/array + (submod "." typed-defs)) -(: flidentity-matrix (Integer -> (Matrix Float))) -(define (flidentity-matrix m) (diagonal-array 2 m 1.0 0.0)) +(define-syntax (matrix stx) + (syntax-parse stx #:literals (:) + [(_ [[x0 xs0 ...] [x xs ...] ...]) + (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))] + [(_ [[x0 xs0 ...] [x xs ...] ...] : T) + (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))] + [(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T))) + (raise-syntax-error 'matrix "given empty row" stx #'r)] + [(_ (~and [] c) (~optional (~seq : T))) + (raise-syntax-error 'matrix "given empty matrix" stx #'c)])) -(: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) -(define (make-matrix m n x) - (make-array (vector m n) x)) +(define-syntax (row-matrix stx) + (syntax-parse stx #:literals (:) + [(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))] + [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))] + [(_ (~and [] r) (~optional (~seq : T))) + (raise-syntax-error 'row-matrix "given empty row" stx #'r)])) -(: list->matrix : (Listof* Number) -> (Matrix Number)) -(define (list->matrix rows) - (list*->array rows number?)) - -(: fllist->matrix : (Listof* Flonum) -> (Matrix Flonum)) -(define (fllist->matrix rows) - (list*->array rows flonum? )) - -(: matrix->list : (All (A) (Matrix A) -> (Listof* A))) -(define (matrix->list a) - (array->list* a)) - -(: vector->matrix : (Vectorof* Number) -> (Matrix Number)) -(define (vector->matrix rows) - (vector*->array rows number? )) - -(: flvector->matrix : (Vectorof* Flonum) -> (Matrix Flonum)) -(define (flvector->matrix rows) - (vector*->array rows flonum? )) - -(: matrix->vector : (All (A) (Matrix A) -> (Vectorof* A))) -(define (matrix->vector a) - (array->vector* a)) - -(: submatrix : (Matrix Number) (Sequenceof Index) (Sequenceof Index) -> (Matrix Number)) -(define (submatrix a row-range col-range) - (array-slice-ref a (list row-range col-range))) - -(: matrix-row : (Matrix Number) Index -> (Matrix Number)) -(define (matrix-row a i) - (define-values (m n) (matrix-dimensions a)) - (array-slice-ref a (list (in-range i (add1 i)) (in-range n)))) - -(: matrix-column : (Matrix Number) Index -> (Matrix Number)) -(define (matrix-column a j) - (define-values (m n)(matrix-dimensions a)) - (array-slice-ref a (list (in-range m) (in-range j (add1 j))))) +(define-syntax (col-matrix stx) + (syntax-parse stx #:literals (:) + [(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))] + [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))] + [(_ (~and [] c) (~optional (~seq : T))) + (raise-syntax-error 'row-matrix "given empty column" stx #'c)])) diff --git a/collects/math/private/matrix/matrix-expt.rkt b/collects/math/private/matrix/matrix-expt.rkt index 46386df1b5..b31391fad8 100644 --- a/collects/math/private/matrix/matrix-expt.rkt +++ b/collects/math/private/matrix/matrix-expt.rkt @@ -1,22 +1,21 @@ #lang typed/racket -(require "../../array.rkt" +(require math/array + "matrix-types.rkt" "matrix-constructors.rkt" - "matrix-multiply.rkt" - "matrix-types.rkt") + "matrix-arithmetic.rkt") (provide matrix-expt) (: matrix-expt : (Matrix Number) Integer -> (Matrix Number)) (define (matrix-expt a n) - (unless (array-matrix? a) - (raise-type-error 'matrix-expt "(Matrix Number)" a)) - (unless (square-matrix? a) - (error 'matrix-expt "Square matrix expected, got ~a" a)) - (cond - [(= n 0) (identity-matrix (square-matrix-size a))] - [(= n 1) a] - [(= n 2) (matrix* a a)] - [(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))]) - (matrix* a^n/2 a^n/2))] - [else (matrix* a (matrix-expt a (sub1 n)))])) + (cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)] + [(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)] + [(zero? n) (identity-matrix (square-matrix-size a))] + [else + (let: loop : (Matrix Number) ([n : Positive-Integer n]) + (cond [(= n 1) a] + [(= n 2) (matrix* a a)] + [(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))]) + (matrix* a^n/2 a^n/2))] + [else (matrix* a (matrix-expt a (sub1 n)))]))])) diff --git a/collects/math/private/matrix/matrix-multiply.rkt b/collects/math/private/matrix/matrix-multiply.rkt deleted file mode 100644 index 51947e15db..0000000000 --- a/collects/math/private/matrix/matrix-multiply.rkt +++ /dev/null @@ -1,60 +0,0 @@ -#lang typed/racket - -(require "../unsafe.rkt" - "../../array.rkt" - "matrix-types.rkt") - -(provide matrix*) - -;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an -;; explanation. - -#;(: make-matrix-multiply (All (A) (Symbol - ((Array A) Integer -> (Array A)) - ((Array A) (Array A) -> (Array A)) - -> ((Array A) (Array A) -> (Array A))))) -(define-syntax-rule (make-matrix-multiply name array-axis-sum array*) - (λ (arr brr) - (unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr)) - (unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr)) - (match-define (vector ad0 ad1) (array-shape arr)) - (match-define (vector bd0 bd1) (array-shape brr)) - (unless (= ad1 bd0) - (error name - "1st argument column size and 2nd argument row size are not equal; given ~e and ~e" - arr brr)) - ;; Get strict versions of both because each element in both is evaluated multiple times - (let ([arr (array->mutable-array arr)] - [brr (array->mutable-array brr)]) - ;; This next part could be done with array-permute, but it's much slower that way - (define avs (mutable-array-data arr)) - (define bvs (mutable-array-data brr)) - ;; Extend arr in the center dimension - (define: ds-ext : (Vectorof Index) (vector ad0 bd1 ad1)) - (define arr-ext - (unsafe-build-array - ds-ext (λ: ([js : (Vectorof Index)]) - (define j0 (unsafe-vector-ref js 0)) - (define j1 (unsafe-vector-ref js 2)) - ;(unsafe-array-ref* arr j0 j1) [twice as slow] - (unsafe-vector-ref avs (unsafe-fx+ j1 (unsafe-fx* j0 ad1)))))) - ;; Transpose brr and extend in the leftmost dimension - ;; Note that ds-ext = (vector ad0 bd1 bd0) because bd0 = ad1 - (define brr-ext - (unsafe-build-array - ds-ext (λ: ([js : (Vectorof Index)]) - (define j0 (unsafe-vector-ref js 2)) - (define j1 (unsafe-vector-ref js 1)) - ;(unsafe-array-ref* brr j0 j1) [twice as slow] - (unsafe-vector-ref bvs (unsafe-fx+ j1 (unsafe-fx* j0 bd1)))))) - (array-axis-sum (array* arr-ext brr-ext) 2)))) - -;; --------------------------------------------------------------------------------------------------- - -(: matrix* (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Number) (Matrix Number) -> (Matrix Number)))) -(define matrix* (make-matrix-multiply 'matrix* array-axis-sum array*)) - -;(: matrix-fl* ((Array Float) (Array Float) -> (Array Float))) -;(define matrix-fl* (make-matrix-multiply 'matrix-fl* array-axis-flsum array-fl*)) - diff --git a/collects/math/private/matrix/matrix-operations.rkt b/collects/math/private/matrix/matrix-operations.rkt index 5cff470d8d..f620b9bcef 100644 --- a/collects/math/private/matrix/matrix-operations.rkt +++ b/collects/math/private/matrix/matrix-operations.rkt @@ -1,14 +1,16 @@ #lang typed/racket/base -(require math/array +(require racket/list + math/array (only-in typed/racket conjugate) "../unsafe.rkt" "matrix-types.rkt" "matrix-constructors.rkt" - "matrix-pointwise.rkt" + "matrix-arithmetic.rkt" + "matrix-basic.rkt" + "matrix-column.rkt" (for-syntax racket)) - ; TODO: ; 1. compute null space from QR factorization ; (better numerical stability than from Gauss elimnation) @@ -25,21 +27,6 @@ ; But TR has problems with #:when so what is the proper expansion ? (provide - ; basic - matrix-ref - matrix-scale - matrix-row-vector? - matrix-column-vector? - matrix/dim ; construct - matrix-augment ; horizontally - matrix-stack ; vertically - matrix-block-diagonal - ; norms - matrix-norm - ; operators - matrix-transpose - matrix-conjugate - matrix-hermitian matrix-inverse ; row and column matrix-scale-row @@ -56,7 +43,6 @@ matrix-rank matrix-nullity matrix-determinant - matrix-trace ; spaces ;matrix-column+null-space ; solvers @@ -64,17 +50,6 @@ matrix-solve-many ; spaces matrix-column-space - ; column vectors - column ; construct - unit-column - result-column ; convert to lazy - column-dimension - column-dot - column-norm - column-projection - column-normalize - scale-column - column+ ; projection projection-on-orthogonal-basis projection-on-orthonormal-basis @@ -84,70 +59,8 @@ ; factorization matrix-lu matrix-qr - ; comprehensions - for/matrix: - for*/matrix: - for/matrix-sum: - ; sequences - in-row - in-column - ; special matrices - vandermonde-matrix ) -;;; -;;; Basic -;;; - -(: matrix-ref : (Matrix Number) Integer Integer -> Number) -(define (matrix-ref M i j) - ((inst array-ref Number) M (vector i j))) - -(: matrix-scale : Number (Matrix Number) -> (Matrix Number)) -(define (matrix-scale s a) - (array-scale a s)) - -(: matrix-row-vector? : (Matrix Number) -> Boolean) -(define (matrix-row-vector? a) - (= (matrix-row-dimension a) 1)) - -(: matrix-column-vector? : (Matrix Number) -> Boolean) -(define (matrix-column-vector? a) - (= (matrix-column-dimension a) 1)) - - -;;; -;;; Norms -;;; - -(: matrix-norm : (Matrix Number) -> Real) -(define (matrix-norm a) - (define n - (sqrt - (array-ref - (array-axis-sum - (array-axis-sum - (matrix.sqr (matrix.magnitude a)) 0) 0) - '#()))) - (assert n real?)) - -;;; -;;; Operators -;;; - -(: matrix-transpose : (Matrix Number) -> (Matrix Number)) -(define (matrix-transpose a) - (array-axis-swap a 0 1)) - -(: matrix-conjugate : (Matrix Number) -> (Matrix Number)) -(define (matrix-conjugate a) - (array-conjugate a)) - -(: matrix-hermitian : (Matrix Number) -> (Matrix Number)) -(define (matrix-hermitian a) - (matrix-transpose - (array-conjugate a))) - ;;; ;;; Row and column ;;; @@ -296,7 +209,7 @@ ((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer))) ((Matrix Number) -> (Values (Matrix Number) (Listof Integer))))) (define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t]) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (: loop : (Integer Integer (Matrix Number) Integer (Listof Integer) -> (Values (Matrix Number) (Listof Integer)))) (define (loop i j ; i from 0 to m @@ -362,20 +275,20 @@ ; TODO: Use QR or SVD instead for inexact matrices ; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd ; rank = dimension of column space = dimension of row space - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (define-values (_ cols-without-pivot) (matrix-gauss-eliminate M)) (- n (length cols-without-pivot))) (: matrix-nullity : (Matrix Number) -> Integer) (define (matrix-nullity M) ; nullity = dimension of null space - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (define-values (_ cols-without-pivot) (matrix-gauss-eliminate M)) (length cols-without-pivot)) (: matrix-determinant : (Matrix Number) -> Number) (define (matrix-determinant M) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (cond [(= m 1) (matrix-ref M 0 0)] [(= m 2) (let ([a (matrix-ref M 0 0)] @@ -406,18 +319,12 @@ (set! product (* product (matrix-ref M i i)))) product))])) -(: matrix-trace : (Matrix Number) -> Number) -(define (matrix-trace M) - (define-values (m n) (matrix-dimensions M)) - (for/sum: : Number ([i (in-range 0 m)]) - (matrix-ref M i i))) - (: matrix-column-space : (Matrix Number) -> (Listof (Matrix Number))) ; Returns ; 1) a list of column vectors spanning the column space ; 2) a list of column vectors spanning the null space (define (matrix-column-space M) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (: M1 (Matrix Number)) (: cols-without-pivot (Listof Integer)) (define-values (M1 cols-without-pivot) (matrix-gauss-eliminate M #t)) @@ -426,7 +333,7 @@ (for/list: ([i : Index n] #:when (not (member i cols-without-pivot))) - (matrix-column M1 i))) + (matrix-col M1 i))) column-space) (: matrix-row-echelon-form : @@ -442,7 +349,7 @@ ((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer))) ((Matrix Number) -> (Values (Matrix Number) (Listof Integer))))) (define (matrix-gauss-jordan-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t]) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (: loop : (Integer Integer (Matrix Number) Integer (Listof Integer) -> (Values (Matrix Number) (Listof Integer)))) (define (loop i j ; i from 0 to m @@ -512,23 +419,11 @@ (let-values ([(M wp) (matrix-gauss-jordan-eliminate M unitize-pivot-row?)]) M)) -(: matrix-augment : (Matrix Number) (Matrix Number) * -> (Matrix Number)) -(define (matrix-augment a . as) - (array-append* (cons a as) 1)) - -(: matrix-stack : (Matrix Number) * -> (Matrix Number)) -(define (matrix-stack . as) - (if (null? as) - (error 'matrix-stack - "expected non-empty list of matrices") - (array-append* as 0))) - - (: matrix-inverse : (Matrix Number) -> (Matrix Number)) (define (matrix-inverse M) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (unless (= m n) (error 'matrix-inverse "matrix not square")) - (let ([MI (matrix-augment M (identity-matrix m))]) + (let ([MI (matrix-augment (list M (identity-matrix m)))]) (define 2m (* 2 m)) (if (index? 2m) (submatrix (matrix-reduced-row-echelon-form MI #t) @@ -539,8 +434,8 @@ ; Return a column-vector x such that Mx = b. ; If no such vector exists return #f. (define (matrix-solve M b) - (define-values (m n) (matrix-dimensions M)) - (define-values (s t) (matrix-dimensions b)) + (define-values (m n) (matrix-shape M)) + (define-values (s t) (matrix-shape b)) (define m+1 (+ m 1)) (cond [(not (= t 1)) (error 'matrix-solve "expected column vector (i.e. r x 1 - matrix), got: ~a " b)] @@ -548,18 +443,14 @@ [(index? m+1) (submatrix (matrix-reduced-row-echelon-form - (matrix-augment M b) #t) + (matrix-augment (list M b)) #t) (in-range 0 m) (in-range m m+1))] [else (error 'matrix-solve "internatl error")])) (: matrix-solve-many : (Matrix Number) (Listof (Matrix Number)) -> (Matrix Number)) (define (matrix-solve-many M bs) - ; TODO: Rewrite matrix-augment* to use array-append when it is ready - (: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number)) - (define (matrix-augment* vs) - (foldl matrix-augment (car vs) (cdr vs))) - (define-values (m n) (matrix-dimensions M)) - (define-values (s t) (matrix-dimensions (car bs))) + (define-values (m n) (matrix-shape M)) + (define-values (s t) (matrix-shape (car bs))) (define k (length bs)) (define m+1 (+ m 1)) (define m+k (+ m k)) @@ -567,8 +458,8 @@ [(not (= t 1)) (error 'matrix-solve-many "expected column vector (i.e. r x 1 - matrix), got: ~a " (car bs))] [(not (= m s)) (error 'matrix-solve-many "expected column vectors with same number of rows as the matrix")] [(and (index? m+1) (index? m+k)) - (define bs-as-matrix (matrix-augment* bs)) - (define MB (matrix-augment M bs-as-matrix)) + (define bs-as-matrix (matrix-augment bs)) + (define MB (matrix-augment (list M bs-as-matrix))) (define reduced-MB (matrix-reduced-row-echelon-form MB #t)) (submatrix reduced-MB (in-range 0 m+k) @@ -584,7 +475,7 @@ (: matrix-lu : (Matrix Number) -> (U False (List (Matrix Number) (Matrix Number)))) (define (matrix-lu M) - (define-values (m _) (matrix-dimensions M)) + (define-values (m _) (matrix-shape M)) (define: ms : (Listof Number) '()) (define V (let/ec: return : (U False (Matrix Number)) @@ -639,111 +530,6 @@ (list L V)))) -(: column-dimension : (Column Number) -> Index) -(define (column-dimension v) - (if (vector? v) - (unsafe-vector-length v) - (matrix-row-dimension v))) - - -(: unsafe-column->vector : (Column Number) -> (Vectorof Number)) -(define (unsafe-column->vector v) - (if (vector? v) v - (let () - (define-values (m n) (matrix-dimensions v)) - (if (= n 1) - (mutable-array-data (array->mutable-array v)) - (error 'unsafe-column->vector - "expected a column (vector or mx1 matrix), got ~a" v))))) - -(: vector->column : (Vectorof Number) -> (Result-Column Number)) -(define (vector->column v) - (define m (vector-length v)) - (flat-vector->matrix m 1 v)) - -(: column : Number * -> (Result-Column Number)) -(define (column . xs) - (vector->column - (list->vector xs))) - -(: result-column : (Column Number) -> (Result-Column Number)) -(define (result-column c) - (if (vector? c) - (vector->column c) - c)) - -(: scale-column : Number (Column Number) -> (Result-Column Number)) -(define (scale-column s a) - (if (vector? a) - (let*: ([n (vector-length a)] - [v : (Vectorof Number) (make-vector n 0)]) - (for: ([i (in-range 0 n)] - [x : Number (in-vector a)]) - (vector-set! v i (* s x))) - (vector->column v)) - (matrix-scale s a))) - -(: column+ : (Column Number) (Column Number) -> (Result-Column Number)) -(define (column+ v w) - (cond [(and (vector? v) (vector? w)) - (let ([n (vector-length v)] - [m (vector-length w)]) - (unless (= m n) - (error 'column+ - "expected two column vectors of the same length, got ~a and ~a" v w)) - (define: v+w : (Vectorof Number) (make-vector n 0)) - (for: ([i (in-range 0 n)] - [x : Number (in-vector v)] - [y : Number (in-vector w)]) - (vector-set! v+w i (+ x y))) - (result-column v+w))] - [else - (unless (= (column-dimension v) (column-dimension w)) - (error 'column+ - "expected two column vectors of the same length, got ~a and ~a" v w)) - (matrix+ (result-column v) (result-column w))])) - - -(: column-dot : (Column Number) (Column Number) -> Number) -(define (column-dot c d) - (define v (unsafe-column->vector c)) - (define w (unsafe-column->vector d)) - (define m (column-dimension v)) - (define s (column-dimension w)) - (cond - [(not (= m s)) (error 'column-dot - "expected two mx1 matrices with same number of rows, got ~a and ~a" - c d)] - [else - (for/sum: : Number ([i (in-range 0 m)]) - (assert i index?) - ; Note: If d is a vector of reals, - ; then the conjugate is a no-op - (* (unsafe-vector-ref v i) - (conjugate (unsafe-vector-ref w i))))])) - -(: column-norm : (Column Number) -> Real) -(define (column-norm v) - (define norm (sqrt (column-dot v v))) - (assert norm real?)) - - -(: column-projection : (Column Number) (Column Number) -> (Result-Column Number)) -; (column-projection v w) -; Return the projection og vector v on vector w. -(define (column-projection v w) - (let ([w.w (column-dot w w)]) - (if (zero? w.w) - (error 'column-projection "projection on the zero vector not defined") - (matrix-scale (/ (column-dot v w) w.w) (result-column w))))) - -(: column-projection-on-unit : (Column Number) (Column Number) -> (Result-Column Number)) -; (column-projection-on-unit v w) -; Return the projection og vector v on a unit vector w. -(define (column-projection-on-unit v w) - (matrix-scale (column-dot v w) (result-column w))) - - (: projection-on-orthogonal-basis : (Column Number) (Listof (Column Number)) -> (Result-Column Number)) ; (projection-on-orthogonal-basis v bs) @@ -754,21 +540,10 @@ (if (null? bs) (error 'projection-on-orthogonal-basis "received empty list of basis vectors") - (for/matrix-sum: : Number ([b (in-list bs)]) - (column-projection v (result-column b))))) + (matrix-sum (map (λ: ([b : (Column Number)]) + (column-project v (->col-matrix b))) + bs)))) -; #;(for/matrix-sum ([b bs]) -; (matrix-scale (column-dot v b) b)) -; (define: sum : (U False (Result-Column Number)) #f) -; (for ([b1 (in-list bs)]) -; (define: b : (Result-Column Number) (result-column b1)) -; (cond [(not sum) (set! sum (column-projection v b))] -; [else (set! sum (matrix+ (assert sum) (column-projection v b)))])) -; (cond [sum (assert sum)] -; [else (error 'projection-on-orthogonal-basis -; "received empty list of basis vectors")]) - - ; (projection-on-orthonormal-basis v bs) ; Project the vector v on the orthonormal basis vectors in bs. ; The basis bs must be either the column vectors of a matrix @@ -776,35 +551,29 @@ (: projection-on-orthonormal-basis : (Column Number) (Listof (Column Number)) -> (Result-Column Number)) (define (projection-on-orthonormal-basis v bs) - #;(for/matrix-sum ([b bs]) (matrix-scale (column-dot v b) b)) + #;(for/matrix-sum ([b bs]) (matrix-scale b (column-dot v b))) (define: sum : (U False (Result-Column Number)) #f) (for ([b1 (in-list bs)]) - (define: b : (Result-Column Number) (result-column b1)) - (cond [(not sum) (set! sum (column-projection-on-unit v b))] - [else (set! sum (matrix+ (assert sum) (column-projection-on-unit v b)))])) + (define: b : (Result-Column Number) (->col-matrix b1)) + (cond [(not sum) (set! sum (column-project/unit v b))] + [else (set! sum (array+ (assert sum) (column-project/unit v b)))])) (cond [sum (assert sum)] - [else (error 'projection-on-orthogonal-basis + [else (error 'projection-on-orthonormal-basis "received empty list of basis vectors")])) - -(: zero-column-vector? : (Matrix Number) -> Boolean) -(define (zero-column-vector? v) - (define-values (m n) (matrix-dimensions v)) - (for/and: ([i (in-range 0 m)]) - (zero? (matrix-ref v i 0)))) - (: gram-schmidt-orthogonal : (Listof (Column Number)) -> (Listof (Result-Column Number))) ; (gram-schmidt-orthogonal ws) ; Given a list ws of column vectors, produce ; an orthogonal basis for the span of the ; vectors in ws. (define (gram-schmidt-orthogonal ws1) - (define ws (map result-column ws1)) + (define ws (map (λ: ([w : (Column Number)]) (->col-matrix w)) ws1)) (cond [(null? ws) '()] [(null? (cdr ws)) (list (car ws))] [else - (: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number)) -> (Listof (Result-Column Number))) + (: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number)) + -> (Listof (Result-Column Number))) (define (loop vs ws) (cond [(null? ws) vs] [else @@ -812,84 +581,32 @@ (let ([w-proj (projection-on-orthogonal-basis w vs)]) ; Note: We project onto vs (not on the original ws) ; in order to get numerical stability. - (let ([w-minus-proj (matrix- w w-proj)]) - (if (zero-column-vector? w-minus-proj) + (let ([w-minus-proj (array-strict (array- w w-proj))]) + (if (zero-matrix? w-minus-proj) (loop vs (cdr ws)) ; w in span{vs} => omit it - (loop (cons (matrix- w w-proj) vs) (cdr ws)))))])) + (loop (cons w-minus-proj vs) (cdr ws)))))])) (reverse (loop (list (car ws)) (cdr ws)))])) - - -(: column-normalize : (Column Number) -> (Result-Column Number)) -; (column-vector-normalize v) -; Return unit vector with same direction as v. -; If v is the zero vector, the zero vector is returned. -(define (column-normalize w) - (let ([norm (column-norm w)] - [w (result-column w)]) - (cond [(zero? norm) w] - [else (matrix-scale (/ norm) w)]))) - (: gram-schmidt-orthonormal : (Listof (Column Number)) -> (Listof (Result-Column Number))) ; (gram-schmidt-orthonormal ws) ; Given a list ws of column vectors, produce ; an orthonormal basis for the span of the ; vectors in ws. (define (gram-schmidt-orthonormal ws) - (map column-normalize - (gram-schmidt-orthogonal ws))) + (map column-normalize (gram-schmidt-orthogonal ws))) (: projection-on-subspace : (Column Number) (Listof (Column Number)) -> (Result-Column Number)) ; (projection-on-subspace v ws) ; Returns the projection of v on span{w_i}, w_i in ws. (define (projection-on-subspace v ws) - (projection-on-orthogonal-basis - v (gram-schmidt-orthogonal ws))) - -(: unit-column : Integer Integer -> (Result-Column Number)) -(define (unit-column m i) - (cond - [(and (index? m) (index? i)) - (define v (make-vector m 0)) - (if (< i m) - (vector-set! v i 1) - (error 'unit-vector "dimension must be largest")) - (flat-vector->matrix m 1 v)] - [else - (error 'unit-vector "expected two indices")])) - - -(: take : (All (A) ((Listof A) Index -> (Listof A)))) -(define (take xs n) - (if (= n 0) - '() - (let ([n-1 (- n 1)]) - (if (index? n-1) - (cons (car xs) (take (cdr xs) n-1)) - (error 'take "can not take more elements than the length of the list"))))) - -; (list 'take (equal? (take (list 0 1 2 3 4) 2) '(0 1))) - -(: matrix->columns : (Matrix Number) -> (Listof (Matrix Number))) -(define (matrix->columns M) - (define-values (m n) (matrix-dimensions M)) - (for/list: : (Listof (Matrix Number)) - ([j (in-range 0 n)]) - (matrix-column M (assert j index?)))) - -(: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number)) -(define (matrix-augment* Ms) - (define MM (car Ms)) - (for: ([M (in-list (cdr Ms))]) - (set! MM (matrix-augment MM M))) - MM) + (projection-on-orthogonal-basis v (gram-schmidt-orthogonal ws))) (: extend-span-to-basis : (Listof (Matrix Number)) Integer -> (Listof (Matrix Number))) ; Extend the basis in vs to with rdimensional basis (define (extend-span-to-basis vs r) - (define-values (m n) (matrix-dimensions (car vs))) + (define-values (m n) (matrix-shape (car vs))) (: loop : (Listof (Matrix Number)) (Listof (Matrix Number)) Integer -> (Listof (Matrix Number))) (define (loop vs ws i) (if (>= i m) @@ -897,15 +614,15 @@ (let () (define ei (unit-column m i)) (define pi (projection-on-subspace ei vs)) - (if (matrix-all= ei pi) + (if (matrix= ei pi) (loop vs ws (+ i 1)) - (let ([w (matrix- ei pi)]) + (let ([w (array- ei pi)]) (loop (cons w vs) (cons w ws) (+ i 1))))))) (: norm> : (Matrix Number) (Matrix Number) -> Boolean) (define (norm> v w) (> (column-norm v) (column-norm w))) (if (index? r) - ((inst take (Matrix Number)) (sort (loop vs '() 0) norm>) r) + (take (sort (loop vs '() 0) norm>) r) (error 'extend-span-to-basis "expected index as second argument, got ~a" r))) (: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number))) @@ -915,13 +632,13 @@ ; 2) columns of Q is are orthonormal ; 3) R is upper-triangular ; Note: columnspace(A)=columnspace(Q) ! - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (let* ([basis-for-column-space - (gram-schmidt-orthonormal (matrix->columns M))] + (gram-schmidt-orthonormal (matrix-cols M))] [extension (extend-span-to-basis basis-for-column-space (- n (length basis-for-column-space)))] - [Q (matrix-augment* + [Q (matrix-augment (append basis-for-column-space (map column-normalize extension)))] @@ -938,305 +655,5 @@ (set! sum (+ sum (* (matrix-ref Q k i) (matrix-ref M k j))))) (vector-set! v (+ (* i n) j) sum)))) - (flat-vector->matrix n n v))]) + (vector->matrix n n v))]) (values Q R))) - -(: matrix/dim : Integer Integer Number * -> (Matrix Number)) -; construct a mxn matrix with elements from the values xs -; the length of xs must be m*n -(define (matrix/dim m n . xs) - (cond [(and (index? m) (index? n)) - (flat-vector->matrix m n (list->vector xs))] - [else (error 'matrix/dim "expected two indices as dimensions, got ~a and ~a" m n)])) - -(: matrix-block-diagonal : (Listof (Matrix Number)) -> (Matrix Number)) -(define (matrix-block-diagonal as) - (define sum-m 0) - (define sum-n 0) - (define: ms : (Listof Index) '()) - (define: ns : (Listof Index) '()) - (for: ([a (in-list as)]) - (define-values (m n) (matrix-dimensions a)) - (set! sum-m (+ sum-m m)) - (set! sum-n (+ sum-n n)) - (set! ms (cons m ms)) - (set! ns (cons n ns))) - (set! ms (reverse ms)) - (set! ns (reverse ns)) - (: loop : (Listof (Matrix Number)) (Listof Index) (Listof Index) - (Listof (Matrix Number)) Integer -> (Matrix Number)) - (define (loop as ms ns rows left) - (cond [(null? as) (apply matrix-stack (reverse rows))] - [else - (define m (car ms)) - (define n (car ns)) - (define a (car as)) - (define row - (matrix-augment - ((inst make-matrix Number) m left 0) a (make-matrix m (- sum-n n left) 0))) - (loop (cdr as) (cdr ms) (cdr ns) (cons row rows) (+ left n))])) - (loop as ms ns '() 0)) - -(define-syntax (for/column: stx) - (syntax-case stx () - [(_ : type m-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: flat-vector : (Vectorof Number) (make-vector m 0)) - (for: ([i (in-range m)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! flat-vector i x)) - (vector->column flat-vector)))])) - -(define-syntax (for/matrix: stx) - (syntax-case stx () - [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: k : Index 0) - (for: ([i (in-range m*n)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) - (set! k (assert (+ k 1) index?))) - (flat-vector->matrix m n v)))] - [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (for: ([i (in-range m*n)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v i x)) - (flat-vector->matrix m n v)))])) - -(define-syntax (for*/matrix: stx) - (syntax-case stx () - [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: k : Index 0) - (for*: (for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) - (set! k (assert (+ k 1) index?))) - (flat-vector->matrix m n v)))] - [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: i : Index 0) - (for*: (for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v i x) - (set! i (assert (+ i 1) index?))) - (flat-vector->matrix m n v)))])) - - - -(define-syntax (for/matrix-sum: stx) - (syntax-case stx () - [(_ : type (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: sum : (U False (Matrix Number)) #f) - (for: (for:-clause ...) - (define a (let () . defs+exprs)) - (set! sum (if sum (matrix+ (assert sum) a) a))) - (assert sum)))])) - -;;; -;;; SEQUENCES -;;; - -(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number)) -(define (in-row/proc M r) - (define-values (m n) (matrix-dimensions M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ: ([j : Index]) (matrix-ref M r j)) - ; next-pos - (λ: ([j : Index]) (assert (+ j 1) index?)) - ; initial-pos - 0 - ; continue-with-pos? - (λ: ([j : Index ]) (< j n)) - #f #f)))) - -(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number)) -(define (in-column/proc M s) - (define-values (m n) (matrix-dimensions M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ: ([i : Index]) (matrix-ref M i s)) - ; next-pos - (λ: ([i : Index]) (assert (+ i 1) index?)) - ; initial-pos - 0 - ; continue-with-pos? - (λ: ([i : Index]) (< i m)) - #f #f)))) - -; (in-row M i] -; Returns a sequence of all elements of row i, -; that is xi0, xi1, xi2, ... -(define-sequence-syntax in-row - (λ () #'in-row/proc) - (λ (stx) - (syntax-case stx () - [[(x) (_ M-expr r-expr)] - #'((x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (array-matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r)) - (unless (and (integer? r) (and (<= 0 r ) (< r n))) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr r-expr)] - #'((i x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (array-matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-row "expected (in-row )" #'clause #'clause)]))) - -; (in-column M j] -; Returns a sequence of all elements of column j, -; that is x0j, x1j, x2j, ... - -(define-sequence-syntax in-column - (λ () #'in-column/proc) - (λ (stx) - (syntax-case stx () - ; M-expr evaluates to column - [[(x) (_ M-expr)] - #'((x) - (:do-in - ([(M n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) - (values M1 rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (unless (array-matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - ([j 0]) - (< j n) - ([(x) (vector-ref d j)]) - #true - #true - [(+ j 1)]))] - ; M-expr evaluats to matrix, s-expr to the column index - [[(x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (array-matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-row "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-col "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (array-matrix? M) - (raise-type-error 'in-column "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-column "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-column "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-column "expected (in-column )" #'clause #'clause)]))) - -(: vandermonde-matrix : (Listof Number) Integer -> (Matrix Number)) -(define (vandermonde-matrix xs n) - ; construct matrix M with M(i,j)=α_i^j ; where i and j begin from 0 ... - ; Inefficient version: - (cond - [(not (index? n)) - (error 'vandermonde-matrix "expected Index as second argument, got ~a" n)] - [else (define: m : Index (length xs)) - (define: αs : (Vectorof Number) (list->vector xs)) - (define: α^j : (Vectorof Number) (make-vector n 1)) - (for*/matrix: : Number m n #:column - ([j (in-range 0 n)] - [i (in-range 0 m)]) - (define αi^j (vector-ref α^j i)) - (define αi (vector-ref αs i )) - (vector-set! α^j i (* αi^j αi)) - αi^j)])) - diff --git a/collects/math/private/matrix/matrix-pointwise.rkt b/collects/math/private/matrix/matrix-pointwise.rkt deleted file mode 100644 index d224bbe87c..0000000000 --- a/collects/math/private/matrix/matrix-pointwise.rkt +++ /dev/null @@ -1,57 +0,0 @@ -#lang typed/racket - -(require "../unsafe.rkt" - "../../array.rkt" - "matrix-types.rkt") - -(provide matrix+ matrix- - matrix.sqr matrix.magnitude) - -;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an -;; explanation. - -#;(: make-matrix-pointwise1 (All (A) (Symbol - ((Array A) -> (Array A)) - -> ((Array A) -> (Array A))))) -(define-syntax-rule (make-matrix-pointwise1 name array-op1) - (λ (arr) - (unless (array-matrix? arr) (raise-type-error name "matrix" arr)) - (array-op1 arr))) - -#;(: make-matrix-pointwise2 (All (A) (Symbol - ((Array A) (Array A) -> (Array A)) - -> ((Array A) (Array A) -> (Array A))))) -(define-syntax-rule (make-matrix-pointwise2 name array-op2) - (λ (arr brr) - (unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr)) - (unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr)) - (array-op2 arr brr))) - -#;(: make-matrix-pointwise1/2 - (All (A) (Symbol - (case-> ((Array A) -> (Array A)) - ((Array A) (Array A) -> (Array A))) - -> - (case-> ((Array A) -> (Array A)) - ((Array A) (Array A) -> (Array A)))))) -(define-syntax-rule (make-matrix-pointwise1/2 name array-op1/2) - (case-lambda - [(arr) ((make-matrix-pointwise1 name array-op1/2) arr)] - [(arr brr) ((make-matrix-pointwise2 name array-op1/2) arr brr)])) - -;; --------------------------------------------------------------------------------------------------- - -(: matrix+ (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Number) (Matrix Number) -> (Matrix Number)))) -(: matrix- (case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Number) -> (Matrix Number)) - ((Matrix Real) (Matrix Real) -> (Matrix Real)) - ((Matrix Number) (Matrix Number) -> (Matrix Number)))) -(: matrix.sqr (case-> ((Matrix Real) -> (Matrix Real)) - ((Matrix Number) -> (Matrix Number)))) -(: matrix.magnitude ((Matrix Number) -> (Matrix Real))) - -(define matrix+ (make-matrix-pointwise2 'matrix+ array+)) -(define matrix- (make-matrix-pointwise1/2 'matrix- array-)) -(define matrix.sqr (make-matrix-pointwise1 'matrix.sqr array-sqr)) -(define matrix.magnitude (make-matrix-pointwise1 'matrix.magnitude array-magnitude)) diff --git a/collects/math/private/matrix/matrix-sequences.rkt b/collects/math/private/matrix/matrix-sequences.rkt index 9d164c9a20..a0c2be15cf 100644 --- a/collects/math/private/matrix/matrix-sequences.rkt +++ b/collects/math/private/matrix/matrix-sequences.rkt @@ -1,100 +1,16 @@ #lang racket -(provide for/matrix - for*/matrix - in-row + +(provide in-row in-column) (require math/array - (except-in math/matrix in-row in-column)) - -;;; COMPREHENSIONS - -; (for/matrix m n (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first n values produced becomes the first row. -; The next n values becomes the second row and so on. -; The bindings in clauses run in parallel. -(define-syntax (for/matrix stx) - (syntax-case stx () - [(_ m-expr n-expr (clause ...) . defs+exprs) - (syntax/loc stx - (let ([m m-expr] [n n-expr]) - (define flat-vector - (for/vector #:length (* m n) - (clause ...) . defs+exprs)) - ; TODO (efficiency): Use a flat-vector->array instead - (flat-vector->matrix m n flat-vector)))])) - -; (for*/matrix m n (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first n values produced becomes the first row. -; The next n values becomes the second row and so on. -; The bindings in clauses run nested. -; (for*/matrix m n #:column (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first m values produced becomes the first column. -; The next m values becomes the second column and so on. -; The bindings in clauses run nested. -(define-syntax (for*/matrix stx) - (syntax-case stx () - [(_ m-expr n-expr #:column (clause ...) . defs+exprs) - (syntax/loc stx - (let* ([m m-expr] - [n n-expr] - [v (make-vector (* m n) 0)] - [w (for*/vector #:length (* m n) (clause ...) . defs+exprs)]) - (for* ([i (in-range m)] [j (in-range n)]) - (vector-set! v (+ (* i n) j) - (vector-ref w (+ (* j m) i)))) - (flat-vector->matrix m n v)))] - [(_ m-expr n-expr (clause ...) . defs+exprs) - (syntax/loc stx - (let ([m m-expr] [n n-expr]) - (flat-vector->matrix - m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))])) - -; TODO: The following is uncommented until matrix+ can be imported. - -; (for/matrix-sum (clause ...) . defs+exprs) -; Return the matrix sum of all matrices produced by the last expr. -; The bindings in clauses are parallel. - -;(define-syntax (for/matrix-sum stx) -; (syntax-case stx () -; [(_ (clause ...) . defs+exprs) -; (syntax/loc stx -; (let ([ms (for/list (clause ...) . defs+exprs)]) -; (foldl matrix+ (first ms) (rest ms))))])) -; -;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))]) -; (for/matrix-sum ([i 3]) M)) -; (flat-vector->matrix 2 2 #(3 6 9 12))) -;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))]) -; (for/matrix-sum ([i 2] [j 2]) M)) -; (flat-vector->matrix 2 2 #(2 4 6 8))) - -; (for*/matrix-sum (clause ...) . defs+exprs) -; Return the matrix sum of all matrices produced by the last expr. -; The bindings in clauses are in nested. - -;(define-syntax (for*/matrix-sum stx) -; (syntax-case stx () -; [(_ (clause ...) . defs+exprs) -; (syntax/loc stx -; (let ([ms (for*/list (clause ...) . defs+exprs)]) -; (foldl matrix+ (first ms) (rest ms))))])) -; -;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))]) -; (for*/matrix-sum ([i 2] [j 2]) M)) -; (flat-vector->matrix 2 2 #(4 8 12 16))) - - -;;; -;;; SEQUENCES -;;; + "matrix-types.rkt" + "matrix-basic.rkt" + "matrix-constructors.rkt" + ) (define (in-row/proc M r) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (make-do-sequence (λ () (values @@ -120,12 +36,12 @@ (:do-in ([(M r n d) (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) + (define-values (rd cd) (matrix-shape M1)) (values M1 r-expr rd (mutable-array-data (array->mutable-array M1))))]) (begin - (unless (array-matrix? M) + (unless (matrix? M) (raise-type-error 'in-row "expected matrix, got ~a" M)) (unless (integer? r) (raise-type-error 'in-row "expected row number, got ~a" r)) @@ -142,12 +58,12 @@ (:do-in ([(M r n d) (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) + (define-values (rd cd) (matrix-shape M1)) (values M1 r-expr rd (mutable-array-data (array->mutable-array M1))))]) (begin - (unless (array-matrix? M) + (unless (matrix? M) (raise-type-error 'in-row "expected matrix, got ~a" M)) (unless (integer? r) (raise-type-error 'in-row "expected row number, got ~a" r))) @@ -167,7 +83,7 @@ (define (in-column/proc M s) - (define-values (m n) (matrix-dimensions M)) + (define-values (m n) (matrix-shape M)) (make-do-sequence (λ () (values @@ -190,12 +106,12 @@ (:do-in ([(M s n m d) (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) + (define-values (rd cd) (matrix-shape M1)) (values M1 s-expr rd cd (mutable-array-data (array->mutable-array M1))))]) (begin - (unless (array-matrix? M) + (unless (matrix? M) (raise-type-error 'in-row "expected matrix, got ~a" M)) (unless (integer? s) (raise-type-error 'in-row "expected col number, got ~a" s)) @@ -212,12 +128,12 @@ (:do-in ([(M s n m d) (let ([M1 M-expr]) - (define-values (rd cd) (matrix-dimensions M1)) + (define-values (rd cd) (matrix-shape M1)) (values M1 s-expr rd cd (mutable-array-data (array->mutable-array M1))))]) (begin - (unless (array-matrix? M) + (unless (matrix? M) (raise-type-error 'in-column "expected matrix, got ~a" M)) (unless (integer? s) (raise-type-error 'in-column "expected col number, got ~a" s)) @@ -233,28 +149,192 @@ [[_ clause] (raise-syntax-error 'in-column "expected (in-column )" #'clause #'clause)]))) +(define-syntax (for/matrix-sum: stx) + (syntax-case stx () + [(_ : type (for:-clause ...) . defs+exprs) + (syntax/loc stx + (let () + (define: sum : (U False (Matrix Number)) #f) + (for: (for:-clause ...) + (define a (let () . defs+exprs)) + (set! sum (if sum (array+ (assert sum) a) a))) + (assert sum)))])) +#| +;;; +;;; SEQUENCES +;;; + +(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number)) +(define (in-row/proc M r) + (define-values (m n) (matrix-shape M)) + (make-do-sequence + (λ () + (values + ; pos->element + (λ: ([j : Index]) (matrix-ref M r j)) + ; next-pos + (λ: ([j : Index]) (assert (+ j 1) index?)) + ; initial-pos + 0 + ; continue-with-pos? + (λ: ([j : Index ]) (< j n)) + #f #f)))) + +(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number)) +(define (in-column/proc M s) + (define-values (m n) (matrix-shape M)) + (make-do-sequence + (λ () + (values + ; pos->element + (λ: ([i : Index]) (matrix-ref M i s)) + ; next-pos + (λ: ([i : Index]) (assert (+ i 1) index?)) + ; initial-pos + 0 + ; continue-with-pos? + (λ: ([i : Index]) (< i m)) + #f #f)))) + +; (in-row M i] +; Returns a sequence of all elements of row i, +; that is xi0, xi1, xi2, ... +(define-sequence-syntax in-row + (λ () #'in-row/proc) + (λ (stx) + (syntax-case stx () + [[(x) (_ M-expr r-expr)] + #'((x) + (:do-in + ([(M r n d) + (let ([M1 M-expr]) + (define-values (rd cd) (matrix-shape M1)) + (values M1 r-expr rd + (mutable-array-data + (array->mutable-array M1))))]) + (begin + (unless (matrix? M) + (raise-type-error 'in-row "expected matrix, got ~a" M)) + (unless (integer? r) + (raise-type-error 'in-row "expected row number, got ~a" r)) + (unless (and (integer? r) (and (<= 0 r ) (< r n))) + (raise-type-error 'in-row "expected row number, got ~a" r))) + ([j 0]) + (< j n) + ([(x) (vector-ref d (+ (* r n) j))]) + #true + #true + [(+ j 1)]))] + [[(i x) (_ M-expr r-expr)] + #'((i x) + (:do-in + ([(M r n d) + (let ([M1 M-expr]) + (define-values (rd cd) (matrix-shape M1)) + (values M1 r-expr rd + (mutable-array-data + (array->mutable-array M1))))]) + (begin + (unless (matrix? M) + (raise-type-error 'in-row "expected matrix, got ~a" M)) + (unless (integer? r) + (raise-type-error 'in-row "expected row number, got ~a" r))) + ([j 0]) + (< j n) + ([(x) (vector-ref d (+ (* r n) j))] + [(i) j]) + #true + #true + [(+ j 1)]))] + [[_ clause] (raise-syntax-error + 'in-row "expected (in-row )" #'clause #'clause)]))) + +; (in-column M j] +; Returns a sequence of all elements of column j, +; that is x0j, x1j, x2j, ... + +(define-sequence-syntax in-column + (λ () #'in-column/proc) + (λ (stx) + (syntax-case stx () + ; M-expr evaluates to column + [[(x) (_ M-expr)] + #'((x) + (:do-in + ([(M n m d) + (let ([M1 M-expr]) + (define-values (rd cd) (matrix-shape M1)) + (values M1 rd cd + (mutable-array-data + (array->mutable-array M1))))]) + (unless (matrix? M) + (raise-type-error 'in-row "expected matrix, got ~a" M)) + ([j 0]) + (< j n) + ([(x) (vector-ref d j)]) + #true + #true + [(+ j 1)]))] + ; M-expr evaluats to matrix, s-expr to the column index + [[(x) (_ M-expr s-expr)] + #'((x) + (:do-in + ([(M s n m d) + (let ([M1 M-expr]) + (define-values (rd cd) (matrix-shape M1)) + (values M1 s-expr rd cd + (mutable-array-data + (array->mutable-array M1))))]) + (begin + (unless (matrix? M) + (raise-type-error 'in-row "expected matrix, got ~a" M)) + (unless (integer? s) + (raise-type-error 'in-row "expected col number, got ~a" s)) + (unless (and (integer? s) (and (<= 0 s ) (< s m))) + (raise-type-error 'in-col "expected col number, got ~a" s))) + ([j 0]) + (< j m) + ([(x) (vector-ref d (+ (* j n) s))]) + #true + #true + [(+ j 1)]))] + [[(i x) (_ M-expr s-expr)] + #'((x) + (:do-in + ([(M s n m d) + (let ([M1 M-expr]) + (define-values (rd cd) (matrix-shape M1)) + (values M1 s-expr rd cd + (mutable-array-data + (array->mutable-array M1))))]) + (begin + (unless (matrix? M) + (raise-type-error 'in-column "expected matrix, got ~a" M)) + (unless (integer? s) + (raise-type-error 'in-column "expected col number, got ~a" s)) + (unless (and (integer? s) (and (<= 0 s ) (< s m))) + (raise-type-error 'in-column "expected col number, got ~a" s))) + ([j 0]) + (< j m) + ([(x) (vector-ref d (+ (* j n) s))] + [(i) j]) + #true + #true + [(+ j 1)]))] + [[_ clause] (raise-syntax-error + 'in-column "expected (in-column )" #'clause #'clause)]))) +|# +#; (module* test #f - (require (except-in math/matrix in-row in-column) - rackunit) + (require rackunit) ; "matrix-sequences.rkt" - ; These work in racket not in typed racket - (check-equal? (matrix->list (for*/matrix 2 3 ([i 2] [j 3]) (+ i j))) - '[[0 1 2] [1 2 3]]) - (check-equal? (matrix->list (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j))) - '[[0 2 2] [1 1 3]]) - (check-equal? (matrix->list (for*/matrix 2 2 #:column ([i 4]) i)) - '[[0 2] [1 3]]) - (check-equal? (matrix->list (for/matrix 2 2 ([i 4]) i)) - '[[0 1] [2 3]]) - (check-equal? (matrix->list (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j))) - '[[6 8 10] [12 14 16]]) - (check-equal? (for/list ([x (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x) + (check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x) '(3 4)) - (check-equal? (for/list ([(i x) (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) + (check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) (list i x)) '((0 3) (1 4))) - (check-equal? (for/list ([x (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x) + (check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x) '(2 4)) - (check-equal? (for/list ([(i x) (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) + (check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) (list i x)) '((0 2) (1 4)))) diff --git a/collects/math/private/matrix/matrix-types.rkt b/collects/math/private/matrix/matrix-types.rkt index 4bb1332680..ae1950ff04 100644 --- a/collects/math/private/matrix/matrix-types.rkt +++ b/collects/math/private/matrix/matrix-types.rkt @@ -1,16 +1,21 @@ -#lang typed/racket +#lang typed/racket/base + +(require "../array/array-struct.rkt" + "../array/array-fold.rkt" + "../array/array-pointwise.rkt" + "../unsafe.rkt") + (provide Matrix Column Result-Column Column-Matrix - array-matrix? - matrix-all= - square-matrix? + matrix? + square-matrix? + row-matrix? + col-matrix? + matrix-shape square-matrix-size - matrix-dimensions - matrix-row-dimension - matrix-column-dimension) - -(require math/array) + matrix-num-rows + matrix-num-cols) (define-type (Matrix A) (Array A)) ; matrices are represented as arrays @@ -22,39 +27,44 @@ (define-type (Column-Matrix A) (Matrix A)) ; a column vector represented as a matrix -(: array-matrix? (All (A) ((Array A) -> Boolean))) -(define (array-matrix? x) - (= 2 (array-dims x))) +(define matrix? + (plambda: (A) ([arr : (Array A)]) + (and (> (array-size arr) 0) + (= (array-dims arr) 2)))) -(: square-matrix? : (All (A) (Matrix A) -> Boolean)) -(define (square-matrix? a) - (and (array-matrix? a) - (let ([sh (array-shape a)]) - (= (vector-ref sh 0) (vector-ref sh 1))))) +(define square-matrix? + (plambda: (A) ([arr : (Array A)]) + (and (matrix? arr) + (let ([sh (array-shape arr)]) + (= (vector-ref sh 0) (vector-ref sh 1)))))) -(: square-matrix-size : (All (A) (Matrix A) -> Index)) -(define (square-matrix-size a) - (vector-ref (array-shape a) 0)) +(define row-matrix? + (plambda: (A) ([arr : (Array A)]) + (and (matrix? arr) + (= (vector-ref (array-shape arr) 0) 1)))) -(: matrix-all= : (Matrix Number) (Matrix Number) -> Boolean) -(define (matrix-all= arr0 arr1) - (array-all-and (array= arr0 arr1))) +(define col-matrix? + (plambda: (A) ([arr : (Array A)]) + (and (matrix? arr) + (= (vector-ref (array-shape arr) 1) 1)))) -(: matrix-dimensions : (Matrix Number) -> (Values Index Index)) -(define (matrix-dimensions a) - (define sh (array-shape a)) - ; TODO: Remove list conversion when trbug1 is fixed - (define sh-tmp (vector->list sh)) - (values (car sh-tmp) (cadr sh-tmp)) - ; (values (vector-ref sh 0) (vector-ref sh 1)) - ) +(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index))) +(define (matrix-shape a) + (cond [(matrix? a) (define sh (array-shape a)) + (values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))] + [else (raise-argument-error 'matrix-shape "matrix?" a)])) -(: matrix-row-dimension : (Matrix Number) -> Index) -(define (matrix-row-dimension a) - (define sh (array-shape a)) - (vector-ref sh 0)) +(: square-matrix-size (All (A) ((Matrix A) -> Index))) +(define (square-matrix-size arr) + (cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)] + [else (raise-argument-error 'square-matrix-size "square-matrix?" arr)])) -(: matrix-column-dimension : (Matrix Number) -> Index) -(define (matrix-column-dimension a) - (define sh (array-shape a)) - (vector-ref sh 1)) +(: matrix-num-rows (All (A) ((Matrix A) -> Index))) +(define (matrix-num-rows a) + (cond [(matrix? a) (vector-ref (array-shape a) 0)] + [else (raise-argument-error 'matrix-col-length "matrix?" a)])) + +(: matrix-num-cols (All (A) ((Matrix A) -> Index))) +(define (matrix-num-cols a) + (cond [(matrix? a) (vector-ref (array-shape a) 1)] + [else (raise-argument-error 'matrix-row-length "matrix?" a)])) diff --git a/collects/math/private/matrix/typed-matrix-arithmetic.rkt b/collects/math/private/matrix/typed-matrix-arithmetic.rkt new file mode 100644 index 0000000000..0f53c02252 --- /dev/null +++ b/collects/math/private/matrix/typed-matrix-arithmetic.rkt @@ -0,0 +1,93 @@ +#lang typed/racket/base + +(require racket/list + math/array + "matrix-types.rkt" + "utils.rkt" + (except-in "untyped-matrix-arithmetic.rkt" matrix-map)) + +(provide matrix-map + matrix= + matrix* + matrix+ + matrix- + matrix-scale + matrix-sum) + +(: matrix-map (All (R A B T ...) + (case-> ((A -> R) (Array A) -> (Array R)) + ((A B T ... T -> R) (Array A) (Array B) (Array T) ... T -> (Array R))))) +(define matrix-map + (case-lambda: + [([f : (A -> R)] [arr : (Array A)]) + (inline-matrix-map f arr)] + [([f : (A B -> R)] [arr0 : (Array A)] [arr1 : (Array B)]) + (inline-matrix-map f arr0 arr1)] + [([f : (A B T ... T -> R)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T]) + (define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs)) + (define g0 (unsafe-array-proc arr0)) + (define g1 (unsafe-array-proc arr1)) + (define gs (map unsafe-array-proc arrs)) + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) + (map (λ: ([g : (Indexes -> T)]) (g js)) gs))))])) + +(: matrix=? ((Array Number) (Array Number) -> Boolean)) +(define (matrix=? arr0 arr1) + (define-values (m0 n0) (matrix-shape arr0)) + (define-values (m1 n1) (matrix-shape arr1)) + (and (= m0 m1) + (= n0 n1) + (let ([proc0 (unsafe-array-proc arr0)] + [proc1 (unsafe-array-proc arr1)]) + (array-all-and (unsafe-build-array + ((inst vector Index) m0 n0) + (λ: ([js : Indexes]) + (= (proc0 js) (proc1 js)))))))) + +(: matrix= (case-> ((Array Number) (Array Number) -> Boolean) + ((Array Number) (Array Number) (Array Number) (Array Number) * -> Boolean))) +(define matrix= + (case-lambda: + [([arr0 : (Array Number)] [arr1 : (Array Number)]) (matrix=? arr0 arr1)] + [([arr0 : (Array Number)] [arr1 : (Array Number)] . [arrs : (Array Number) *]) + (and (matrix=? arr0 arr1) + (let: loop : Boolean ([arr1 : (Array Number) arr1] + [arrs : (Listof (Array Number)) arrs]) + (cond [(empty? arrs) #t] + [else (and (matrix=? arr1 (first arrs)) + (loop (first arrs) (rest arrs)))])))])) + +(: matrix* (case-> ((Array Real) (Array Real) * -> (Array Real)) + ((Array Number) (Array Number) * -> (Array Number)))) +(define (matrix* a . as) + (let loop ([a a] [as as]) + (cond [(empty? as) a] + [else (loop (inline-matrix* a (first as)) (rest as))]))) + +(: matrix+ (case-> ((Array Real) (Array Real) * -> (Array Real)) + ((Array Number) (Array Number) * -> (Array Number)))) +(define (matrix+ a . as) + (let loop ([a a] [as as]) + (cond [(empty? as) a] + [else (loop (inline-matrix+ a (first as)) (rest as))]))) + +(: matrix- (case-> ((Array Real) (Array Real) * -> (Array Real)) + ((Array Number) (Array Number) * -> (Array Number)))) +(define (matrix- a . as) + (cond [(empty? as) (inline-matrix- a)] + [else + (let loop ([a a] [as as]) + (cond [(empty? as) a] + [else (loop (inline-matrix- a (first as)) (rest as))]))])) + +(: matrix-scale (case-> ((Array Real) Real -> (Array Real)) + ((Array Number) Number -> (Array Number)))) +(define (matrix-scale a x) (inline-matrix-scale a x)) + +(: matrix-sum (case-> ((Listof (Array Real)) -> (Array Real)) + ((Listof (Array Number)) -> (Array Number)))) +(define (matrix-sum lst) + (cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)] + [else (apply matrix+ lst)])) diff --git a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt new file mode 100644 index 0000000000..9dceed073f --- /dev/null +++ b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt @@ -0,0 +1,105 @@ +#lang racket/base + +(provide inline-matrix* + inline-matrix+ + inline-matrix- + inline-matrix-scale + inline-matrix-map + matrix-map) + +(module syntax-defs racket/base + (require (for-syntax racket/base) + (only-in typed/racket/base λ: : inst Index) + math/array + "matrix-types.rkt" + "utils.rkt") + + (provide (all-defined-out)) + + ;(: matrix-multiply ((Array Number) (Array Number) -> (Array Number))) + ;; This is a macro so the result can have as precise a type as possible + (define-syntax-rule (inline-matrix-multiply arr-expr brr-expr) + (let ([arr arr-expr] + [brr brr-expr]) + (let-values ([(m p n) (matrix-multiply-shape arr brr)] + ;; Make arr strict because its elements are reffed multiple times + [(_) (array-strict! arr)]) + (let (;; Extend arr in the center dimension + [arr-proc (unsafe-array-proc (array-axis-insert arr 1 n))] + ;; Transpose brr and extend in the leftmost dimension + [brr-proc (unsafe-array-proc + (array-axis-insert (array-strict (array-axis-swap brr 0 1)) 0 m))]) + ;; The *transpose* of brr is traversed in row-major order when this result is traversed + ;; in row-major order (which is why the transpose is strict, not brr) + (array-axis-sum + (unsafe-build-array + ((inst vector Index) m n p) + (λ: ([js : Indexes]) + (* (arr-proc js) (brr-proc js)))) + 2))))) + + (define-syntax (inline-matrix* stx) + (syntax-case stx () + [(_ arr) + (syntax/loc stx arr)] + [(_ arr brr crrs ...) + (syntax/loc stx (inline-matrix* (inline-matrix-multiply arr brr) crrs ...))])) + + (define-syntax (inline-matrix-map stx) + (syntax-case stx () + [(_ f arr-expr) + (syntax/loc stx + (let*-values ([(arr) arr-expr] + [(m n) (matrix-shape arr)] + [(proc) (unsafe-array-proc arr)]) + (unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js))))))] + [(_ f arr-expr brr-exprs ...) + (with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))] + [(procs ...) (generate-temporaries #'(brr-exprs ...))]) + (syntax/loc stx + (let ([arr arr-expr] + [brrs brr-exprs] ...) + (let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)] + [(proc) (unsafe-array-proc arr)] + [(procs) (unsafe-array-proc brrs)] ...) + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (f (proc js) (procs js) ...)))))))])) + + (define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...)) + (define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...)) + (define-syntax-rule (inline-matrix-scale arr x) (inline-matrix-map (λ (y) (* x y)) arr)) + + ) ; module + +(require 'syntax-defs) + +(module untyped-defs typed/racket/base + (require math/array + (submod ".." syntax-defs) + "utils.rkt") + + (provide matrix-map) + + (: matrix-map (All (R A) (case-> ((A -> R) (Array A) -> (Array R)) + ((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R))))) + (define matrix-map + (case-lambda: + [([f : (A -> R)] [arr : (Array A)]) + (inline-matrix-map f arr)] + [([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)]) + (inline-matrix-map f arr0 arr1)] + [([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *]) + (define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs)) + (define g0 (unsafe-array-proc arr0)) + (define g1 (unsafe-array-proc arr1)) + (define gs (map (inst unsafe-array-proc A) arrs)) + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) (apply f (g0 js) (g1 js) + (map (λ: ([g : (Indexes -> A)]) (g js)) gs))))])) + + ) ; module + +(require 'untyped-defs) diff --git a/collects/math/private/matrix/utils.rkt b/collects/math/private/matrix/utils.rkt index 80813ecd42..fefced18be 100644 --- a/collects/math/private/matrix/utils.rkt +++ b/collects/math/private/matrix/utils.rkt @@ -1,6 +1,38 @@ #lang typed/racket/base -(require math/array) +(require racket/match + racket/string + math/array + "matrix-types.rkt") (provide (all-defined-out)) +(: format-matrices/error ((Listof (Array Any)) -> String)) +(define (format-matrices/error as) + (string-join (map (λ: ([a : (Array Any)]) (format "~e" a)) as))) + +(: matrix-shapes (Symbol (Matrix Any) (Matrix Any) * -> (Values Index Index))) +(define (matrix-shapes name arr . brrs) + (define-values (m n) (matrix-shape arr)) + (unless (andmap (λ: ([brr : (Matrix Any)]) + (match-define (vector bm bn) (array-shape brr)) + (and (= bm m) (= bn n))) + brrs) + (error name + "matrices must have the same shape; given ~a" + (format-matrices/error (cons arr brrs)))) + (values m n)) + +(: matrix-multiply-shape ((Matrix Any) (Matrix Any) -> (Values Index Index Index))) +(define (matrix-multiply-shape arr brr) + (define-values (ad0 ad1) (matrix-shape arr)) + (define-values (bd0 bd1) (matrix-shape brr)) + (unless (= ad1 bd0) + (error 'matrix-multiply + "1st argument column size and 2nd argument row size are not equal; given ~e and ~e" + arr brr)) + (values ad0 ad1 bd1)) + +(: ensure-matrix (All (A) Symbol (Array A) -> (Array A))) +(define (ensure-matrix name a) + (if (matrix? a) a (raise-argument-error name "matrix?" a))) diff --git a/collects/math/scribblings/math-array.scrbl b/collects/math/scribblings/math-array.scrbl index c0f13ba618..9f2abdb385 100644 --- a/collects/math/scribblings/math-array.scrbl +++ b/collects/math/scribblings/math-array.scrbl @@ -1385,7 +1385,7 @@ Returns an array with shape @racket[(vector (array-size arr))], with the element @;{==================================================================================================} -@section{Folds and Other Axis Reductions} +@section{Folds, Reductions and Expansions} @defproc*[([(array-axis-fold [arr (Array A)] [k Integer] [f (A A -> A)]) (Array A)] [(array-axis-fold [arr (Array A)] [k Integer] [f (A B -> B)] [init B]) (Array B)])]{ @@ -1547,18 +1547,6 @@ and @racket[array-all-or] is defined similarly. (array-all-or (array= arr (array 0)))] } -@defproc[(array->list-array [arr (Array A)] [k Integer]) (Array (Listof A))]{ -Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of -axis @racket[k]. -@examples[#:eval typed-eval - (define arr (index-array #(3 3))) - arr - (array->list-array arr 1) - (array-ref (array->list-array (array->list-array arr 1) 0) #())] -See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows -how @racket[array->list-array] is implemented. -} - @defproc[(array-axis-reduce [arr (Array A)] [k Integer] [h (Index (Integer -> A) -> B)]) (Array B)]{ Like @racket[array-axis-fold], but allows evaluation control (such as short-cutting @racket[and] and @racket[or]) by moving the loop into @racket[h]. The result has the shape of @racket[arr], but with @@ -1591,6 +1579,60 @@ Every fold, including @racket[array-axis-fold], is ultimately defined using @racket[array-axis-reduce] or its unsafe counterpart. } +@defproc[(array-axis-expand [arr (Array A)] [k Integer] [dk Integer] [g (A Index -> B)]) (Array B)]{ +Inserts a new axis number @racket[k] of length @racket[dk], using @racket[g] to generate values; +@racket[k] must be @italic{no greater than} the dimension of @racket[arr], and @racket[dk] must be +nonnegative. + +Conceptually, @racket[g] is applied @racket[dk] times to each element in each row of axis @racket[k], +once for each nonnegative index @racket[jk < dk]. (In reality, @racket[g] is applied only when the +resulting array is indexed.) + +Turning vector elements into rows of a new last axis using @racket[array-axis-expand] and +@racket[vector-ref]: +@interaction[#:eval typed-eval + (define arr (array #['#(a b c) '#(d e f) '#(g h i)])) + (array-axis-expand arr 1 3 vector-ref)] +Creating a @racket[vandermonde-matrix]: +@interaction[#:eval typed-eval + (array-axis-expand (list->array '(1 2 3 4)) 1 5 expt)] + +This function is a dual of @racket[array-axis-reduce] in that it can be used to invert applications +of @racket[array-axis-reduce]. +To do so, @racket[g] should be a destructuring function that is dual to the constructor passed to +@racket[array-axis-reduce]. +Example dual pairs are @racket[vector-ref] and @racket[build-vector], and @racket[list-ref] and +@racket[build-list]. + +(Do not pass @racket[list-ref] to @racket[array-axis-expand] if you care about performance, though. +See @racket[list-array->array] for a more efficient solution.) +} + +@defproc[(array->list-array [arr (Array A)] [k Integer 0]) (Array (Listof A))]{ +Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of +axis @racket[k]. +@examples[#:eval typed-eval + (define arr (index-array #(3 3))) + arr + (array->list-array arr 1) + (array-ref (array->list-array (array->list-array arr 1) 0) #())] +See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows +how @racket[array->list-array] is implemented. +} + +@defproc[(list-array->array [arr (Array (Listof A))] [k Integer 0]) (Array A)]{ +Returns an array in which the list elements of @racket[arr] comprise a new axis @racket[k]. +Equivalent to @racket[(array-axis-expand arr k n list-ref)] where @racket[n] is the +length of the lists in @racket[arr], but with O(1) indexing. + +@examples[#:eval typed-eval + (define arr (array->list-array (index-array #(3 3)) 1)) + arr + (list-array->array arr 1)] + +For fixed @racket[k], this function and @racket[array->list-array] are mutual inverses with respect +to their array arguments. +} @;{==================================================================================================} diff --git a/collects/math/tests/matrix-tests.rkt b/collects/math/tests/matrix-tests.rkt index 3663945fc0..24b8eb1a62 100644 --- a/collects/math/tests/matrix-tests.rkt +++ b/collects/math/tests/matrix-tests.rkt @@ -3,28 +3,86 @@ (require math/array math/base math/flonum - math/matrix) + math/matrix + "../private/matrix/matrix-column.rkt" + "test-utils.rkt") + +(: random-matrix (Integer Integer Integer -> (Matrix Integer))) +;; Generates a random matrix with integer elements < k. Useful to test properties. +(define (random-matrix m n k) + (array-strict (build-array (vector m n) (λ (_) (random k))))) + +;; =================================================================================================== +;; Types + +(check-true (matrix? (array #[#[1]]))) +(check-false (matrix? (array #[1]))) +(check-false (matrix? (array 1))) +(check-false (matrix? (array #[]))) + +(check-true (row-matrix? (matrix [[1 2 3 4]]))) +(check-false (row-matrix? (matrix [[1] [2] [3] [4]]))) + +(check-true (col-matrix? (matrix [[1] [2] [3] [4]]))) +(check-false (col-matrix? (matrix [[1 2 3 4]]))) + +;; =================================================================================================== +;; Matrix multiplication + +(check-equal? (matrix* (identity-matrix 2) + (matrix [[1 20] [300 4000]])) + (matrix [[1 20] [300 4000]])) + +(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]]) + (matrix [[1 2 3] [4 5 6] [7 8 9]])) + (matrix [[30 36 42] [66 81 96] [102 126 150]])) + +(let ([m0 (random-matrix 4 5 100)] + [m1 (random-matrix 5 2 100)] + [m2 (random-matrix 2 10 100)]) + (check-equal? (matrix* (matrix* m0 m1) m2) + (matrix* m0 (matrix* m1 m2)))) + +;; =================================================================================================== +;; Construction + +(check-equal? + (block-diagonal-matrix + (list + (matrix [[1 2] [3 4]]) + (matrix [[1 2 3] [4 5 6]]) + (matrix [[1] [3] [5]]))) + (matrix + [[1 2 0 0 0 0] + [3 4 0 0 0 0] + [0 0 1 2 3 0] + [0 0 4 5 6 0] + [0 0 0 0 0 1] + [0 0 0 0 0 3] + [0 0 0 0 0 5]])) + +;; =================================================================================================== (begin (begin "matrix-types.rkt" (list - 'array-matrix? - (array-matrix? (list*->array '[[1 2] [3 4]] real? )) - (not (array-matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? )))) + 'matrix? + (matrix? (list*->array '[[1 2] [3 4]] real? )) + (not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? )))) (list 'square-matrix? (square-matrix? (list*->array '[[1 2] [3 4]] real? )) (not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? )))) (list 'square-matrix-size - (= 2 (square-matrix-size (list*->array '[[1 2 3] [4 5 6]] real? )))) + (= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?)))) (list - 'matrix-all=- - (matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? )) - (not (matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? )))) + 'matrix= + (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? )) + #;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? )))) (list - 'matrix-dimensions - (let-values ([(m n) (matrix-dimensions (list->matrix '[[1 2 3] [4 5 6]]))]) + 'matrix-shape + (let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))]) (equal? (list m n) '(2 3))))) (begin "matrix-constructors.rkt" @@ -32,47 +90,46 @@ 'identity-matrix (equal? (array->list* (identity-matrix 1)) '[[1]]) (equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]]) - (equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]]) - (equal? (array->list* (flidentity-matrix 1)) '[[1.]]) - (equal? (array->list* (flidentity-matrix 2)) '[[1. 0.] [0. 1.]]) - (equal? (array->list* (flidentity-matrix 3)) '[[1. 0. 0.] [0. 1. 0.] [0. 0. 1.]])) + (equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]])) (list 'const-matrix (equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0))) (equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.)))) (list 'matrix->list - (equal? (matrix->list (list->matrix '((1 2) (3 4)))) '((1 2) (3 4))) - (equal? (matrix->list (fllist->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.)))) + (equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4))) + (equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.)))) (list 'matrix->vector - (equal? (matrix->vector (vector->matrix '#(#(1 2) #(3 4)))) '#(#(1 2) #(3 4))) - (equal? (matrix->vector (flvector->matrix '#(#(1. 2.) #(3. 4.)))) '#(#(1. 2.) #(3. 4.)))) + (equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4)))) + '#(#(1 2) #(3 4))) + (equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.)))) + '#(#(1. 2.) #(3. 4.)))) (list 'matrix-row - (equal? (matrix-row (identity-matrix 3) 0) (list->matrix '[[1 0 0]])) - (equal? (matrix-row (identity-matrix 3) 1) (list->matrix '[[0 1 0]])) - (equal? (matrix-row (identity-matrix 3) 2) (list->matrix '[[0 0 1]]))) + (equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]])) + (equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]])) + (equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]]))) (list 'matrix-col - (equal? (matrix-column (identity-matrix 3) 0) (list->matrix '[[1] [0] [0]])) - (equal? (matrix-column (identity-matrix 3) 1) (list->matrix '[[0] [1] [0]])) - (equal? (matrix-column (identity-matrix 3) 2) (list->matrix '[[0] [0] [1]]))) + (equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]])) + (equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]])) + (equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]]))) (list 'submatrix (equal? (submatrix (identity-matrix 3) - (in-range 0 1) (in-range 0 2)) (list->matrix '[[1 0]])) + (in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]])) (equal? (submatrix (identity-matrix 3) - (in-range 0 2) (in-range 0 3)) (list->matrix '[[1 0 0] [0 1 0]])))) - + (in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]])))) + (begin "matrix-pointwise.rkt" (let () - (define A (list->matrix '[[1 2] [3 4]])) - (define ~A (list->matrix '[[-1 -2] [-3 -4]])) - (define B (list->matrix '[[5 6] [7 8]])) - (define A+B (list->matrix '[[6 8] [10 12]])) - (define A-B (list->matrix '[[-4 -4] [-4 -4]])) + (define A (list*->matrix '[[1 2] [3 4]])) + (define ~A (list*->matrix '[[-1 -2] [-3 -4]])) + (define B (list*->matrix '[[5 6] [7 8]])) + (define A+B (list*->matrix '[[6 8] [10 12]])) + (define A-B (list*->matrix '[[-4 -4] [-4 -4]])) (list 'matrix+ (equal? (matrix+ A B) A+B)) (list 'matrix- (equal? (matrix- A B) A-B) @@ -81,102 +138,102 @@ (begin "matrix-expt.rkt" (let () - (define A (list->matrix '[[1 2] [3 4]])) + (define A (list*->matrix '[[1 2] [3 4]])) (list 'matrix-expt (equal? (matrix-expt A 0) (identity-matrix 2)) (equal? (matrix-expt A 1) A) - (equal? (matrix-expt A 2) (list->matrix '[[7 10] [15 22]])) - (equal? (matrix-expt A 3) (list->matrix '[[37 54] [81 118]])) - (equal? (matrix-expt A 8) (list->matrix '[[165751 241570] [362355 528106]])))) - #;(list - (define A (fllist->matrix '[[1. 2.] [3. 4.]])) - (check-equal? (matrix->list (flmatrix-expt A 0)) (matrix->list (flidentity-matrix 2))) - (check-equal? (matrix->list (flmatrix-expt A 1)) (matrix->list A)) - (check-equal? (matrix->list (flmatrix-expt A 2)) '[[7. 10.] [15. 22.]]) - (check-equal? (matrix->list (flmatrix-expt A 3)) '[[37. 54.] [81. 118.]]) - (check-equal? (matrix->list (flmatrix-expt A 8)) '[[165751. 241570.] [362355. 528106.]]))) - + (equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]])) + (equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]])) + (equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]]))))) + (begin "matrix-operations.rkt" (list 'vandermonde-matrix (equal? (vandermonde-matrix '(1 2 3) 5) - (list->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]]))) + (list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]]))) + #; (list 'in-column - (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 0)]) x) + (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)]) + x) '(1 3)) - (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 1)]) x) + (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)]) + x) '(2 4)) - (equal? (for/list: : (Listof Number) ([x (in-column (column 5 2 3))]) x) + (equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x) '(5 2 3))) + #; (list 'in-row - (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 0)]) x) + (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)]) + x) '(1 2)) - (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 1)]) x) + (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)]) + x) '(3 4))) (list 'for/matrix: (equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i) - (matrix/dim 2 4 - 0 1 2 3 - 4 5 6 7)) + (matrix [[0 1 2 3] [4 5 6 7]])) (equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i) - (matrix/dim 2 4 - 0 2 4 6 - 1 3 5 7)) + (matrix [[0 2 4 6] [1 3 5 7]])) (equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i) - (matrix/dim 3 3 10 11 12 13 14 15 16 17 18))) + (matrix [[10 11 12] [13 14 15] [16 17 18]]))) (list 'for*/matrix: (equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j)) - (matrix/dim 3 3 0 1 2 10 11 12 20 21 22))) + (matrix [[0 1 2] [10 11 12] [20 21 22]]))) (list 'matrix-block-diagonal - (equal? (matrix-block-diagonal (list (matrix/dim 2 2 1 2 3 4) (matrix/dim 1 3 5 6 7))) - (list->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]]))) + (equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]]))) + (list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]]))) (list 'matrix-augment - (equal? (matrix-augment (column 1 2 3) (column 4 5 6) (column 7 8 9)) - (matrix/dim 3 3 1 4 7 2 5 8 3 6 9))) + (equal? (matrix-augment (list (col-matrix [1 2 3]) + (col-matrix [4 5 6]) + (col-matrix [7 8 9]))) + (matrix [[1 4 7] [2 5 8] [3 6 9]]))) (list 'matrix-stack - (equal? (matrix-stack (column 1 2 3) (column 4 5 6) (column 7 8 9)) - (column 1 2 3 4 5 6 7 8 9))) + (equal? (matrix-stack (list (col-matrix [1 2 3]) + (col-matrix [4 5 6]) + (col-matrix [7 8 9]))) + (col-matrix [1 2 3 4 5 6 7 8 9]))) + #; (list 'column-dimension (= (column-dimension #(1 2 3)) 3) - (= (column-dimension (flat-vector->matrix 1 2 #(1 2))) 1)) - (let ([matrix: flat-vector->matrix]) + (= (column-dimension (vector->matrix 1 2 #(1 2))) 1)) + (let ([matrix: vector->matrix]) (list 'column-dot - (= (column-dot (column 1 2) (column 1 2)) 5) - (= (column-dot (column 1 2) (column 3 4)) 11) - (= (column-dot (column 3 4) (column 3 4)) 25) - (= (column-dot (column 1 2 3) (column 4 5 6)) + (= (column-dot (col-matrix [1 2]) (col-matrix [1 2])) 5) + (= (column-dot (col-matrix [1 2]) (col-matrix [3 4])) 11) + (= (column-dot (col-matrix [3 4]) (col-matrix [3 4])) 25) + (= (column-dot (col-matrix [1 2 3]) (col-matrix [4 5 6])) (+ (* 1 4) (* 2 5) (* 3 6))) - (= (column-dot (column +3i +4i) (column +3i +4i)) + (= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i])) 25))) (list 'matrix-trace - (equal? (matrix-trace (flat-vector->matrix 2 2 #(1 2 3 4))) 5)) - (let ([matrix: flat-vector->matrix]) + (equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5)) + (let ([matrix: vector->matrix]) (list 'column-norm - (= (column-norm (column 2 4)) (sqrt 20)))) - (list 'column-projection - (equal? (column-projection #(1 2 3) #(4 5 6)) (column 128/77 160/77 192/77)) - (equal? (column-projection (column 1 2 3) (column 2 4 3)) - (matrix-scale 19/29 (column 2 4 3)))) + (= (column-norm (col-matrix [2 4])) (sqrt 20)))) + (list 'column-project + (equal? (column-project #(1 2 3) #(4 5 6)) (col-matrix [128/77 160/77 192/77])) + (equal? (column-project (col-matrix [1 2 3]) (col-matrix [2 4 3])) + (matrix-scale (col-matrix [2 4 3]) 19/29))) (list 'projection-on-orthogonal-basis (equal? (projection-on-orthogonal-basis #(3 -2 2) (list #(-1 0 2) #( 2 5 1))) - (column -1/3 -1/3 1/3)) + (col-matrix [-1/3 -1/3 1/3])) (equal? (projection-on-orthogonal-basis - (column 3 -2 2) (list #(-1 0 2) (column 2 5 1))) - (column -1/3 -1/3 1/3))) + (col-matrix [3 -2 2]) (list #(-1 0 2) (col-matrix [2 5 1]))) + (col-matrix [-1/3 -1/3 1/3]))) (list 'projection-on-orthonormal-basis (equal? (projection-on-orthonormal-basis #(1 2 3 4) - (list (matrix-scale 1/2 (column 1 1 1 1)) - (matrix-scale 1/2 (column -1 1 -1 1)) - (matrix-scale 1/2 (column 1 -1 -1 1)))) - (column 2 3 2 3))) + (list (matrix-scale (col-matrix [ 1 1 1 1]) 1/2) + (matrix-scale (col-matrix [-1 1 -1 1]) 1/2) + (matrix-scale (col-matrix [ 1 -1 -1 1]) 1/2))) + (col-matrix [2 3 2 3]))) (list 'gram-schmidt-orthogonal (equal? (gram-schmidt-orthogonal (list #(3 1) #(2 2))) - (list (column 3 1) (column -2/5 6/5)))) + (list (col-matrix [3 1]) (col-matrix [-2/5 6/5])))) (list 'vector-normalize (equal? (column-normalize #(3 4)) - (column 3/5 4/5))) + (col-matrix [3/5 4/5]))) (list 'gram-schmidt-orthonormal (equal? (gram-schmidt-orthonormal (ann '(#(3 1) #(2 2)) (Listof (Column Number)))) (list (column-normalize #(3 1)) @@ -184,74 +241,81 @@ (list 'projection-on-subspace (equal? (projection-on-subspace #(1 2 3) '(#(2 4 3))) - (matrix-scale 19/29 (column 2 4 3)))) + (matrix-scale (col-matrix [2 4 3]) 19/29))) (list 'unit-vector - (equal? (unit-column 4 1) (column 0 1 0 0))) + (equal? (unit-column 4 1) (col-matrix [0 1 0 0]))) (list 'matrix-qr - (let-values ([(Q R) (matrix-qr (matrix/dim 3 2 1 1 0 1 1 1))]) + (let-values ([(Q R) (matrix-qr (matrix [[1 1] [0 1] [1 1]]))]) (equal? (list Q R) - (list (matrix/dim 3 2 0.7071067811865475 0 0 1 0.7071067811865475 0) - (matrix/dim 2 2 1.414213562373095 1.414213562373095 0 1)))) + (list (matrix [[0.7071067811865475 0] + [0 1] + [0.7071067811865475 0]]) + (matrix [[1.414213562373095 1.414213562373095] + [0 1]])))) (let () - (define A (matrix/dim 4 4 1 2 3 4 1 2 4 5 1 2 5 6 1 2 6 7)) + (define A (matrix [[1 2 3 4] [1 2 4 5] [1 2 5 6] [1 2 6 7]])) (define-values (Q R) (matrix-qr A)) (equal? (list Q R) (list - (flat-vector->matrix - 4 4 (ann #(1/2 -0.6708203932499369 0.5477225575051662 -0.0 - 1/2 -0.22360679774997896 -0.7302967433402214 0.4082482904638629 - 1/2 0.22360679774997896 -0.18257418583505536 -0.8164965809277259 - 1/2 0.6708203932499369 0.3651483716701107 0.408248290463863) - (Vectorof Number))) - (flat-vector->matrix - 4 4 (ann #(2 4 9 11 0 0.0 2.23606797749979 2.23606797749979 - 0 0 0.0 4.440892098500626e-16 0 0 0 0.0) - (Vectorof Number))))))) + (vector->matrix + 4 4 ((inst vector Number) + 1/2 -0.6708203932499369 0.5477225575051662 -0.0 + 1/2 -0.22360679774997896 -0.7302967433402214 0.4082482904638629 + 1/2 0.22360679774997896 -0.18257418583505536 -0.8164965809277259 + 1/2 0.6708203932499369 0.3651483716701107 0.408248290463863)) + (vector->matrix + 4 4 ((inst vector Number) + 2 4 9 11 0 0.0 2.23606797749979 2.23606797749979 + 0 0 0.0 4.440892098500626e-16 0 0 0 0.0)))))) (list 'matrix-solve - (let* ([M (list->matrix '[[1 5] [2 3]])] - [b (list->matrix '[[5] [5]])]) + (let* ([M (list*->matrix '[[1 5] [2 3]])] + [b (list*->matrix '[[5] [5]])]) (equal? (matrix* M (matrix-solve M b)) b))) (list 'matrix-inverse - (equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M))) + (equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M))) (identity-matrix 2)) - (equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M)) + (equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M)) (identity-matrix 2))) (list 'matrix-determinant - (equal? (matrix-determinant (list->matrix '[[3]])) 3) - (equal? (matrix-determinant (list->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3))) - (equal? (matrix-determinant (list->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0) - (equal? (matrix-determinant (list->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120) - (equal? (matrix-determinant (list->matrix '[[1 2 3 4] [-5 6 7 8] [9 10 -11 12] [13 14 15 16]])) 5280)) + (equal? (matrix-determinant (list*->matrix '[[3]])) 3) + (equal? (matrix-determinant (list*->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3))) + (equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0) + (equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120) + (equal? (matrix-determinant (list*->matrix '[[1 2 3 4] + [-5 6 7 8] + [9 10 -11 12] + [13 14 15 16]])) + 5280)) (list 'matrix-scale - (equal? (matrix-scale 2 (list->matrix '[[1 2] [3 4]])) - (list->matrix '[[2 4] [6 8]]))) + (equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2) + (list*->matrix '[[2 4] [6 8]]))) (list 'matrix-transpose - (equal? (matrix-transpose (list->matrix '[[1 2] [3 4]])) - (list->matrix '[[1 3] [2 4]]))) + (equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]])) + (list*->matrix '[[1 3] [2 4]]))) (list 'matrix-hermitian - (equal? (matrix-hermitian (list->matrix '[[1+i 2-i] [3+i 4-i]])) - (list->matrix '[[1-i 3-i] [2+i 4+i]]))) + (equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]])) + (list*->matrix '[[1-i 3-i] [2+i 4+i]]))) (let () (: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number)) (define (gauss-eliminate M u? p?) (let-values ([(M wp) (matrix-gauss-eliminate M u? p?)]) M)) (list 'matrix-gauss-eliminate - (equal? (let ([M (list->matrix '[[1 2] [3 4]])]) + (equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (gauss-eliminate M #f #f)) - (list->matrix '[[1 2] [0 -2]])) - (equal? (let ([M (list->matrix '[[2 4] [3 4]])]) + (list*->matrix '[[1 2] [0 -2]])) + (equal? (let ([M (list*->matrix '[[2 4] [3 4]])]) (gauss-eliminate M #t #f)) - (list->matrix '[[1 2] [0 1]])) - (equal? (let ([M (list->matrix '[[2. 4.] [3. 4.]])]) + (list*->matrix '[[1 2] [0 1]])) + (equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])]) (gauss-eliminate M #t #t)) - (list->matrix '[[1. 1.3333333333333333] [0. 1.]])) - (equal? (let ([M (list->matrix '[[1 4] [2 4]])]) + (list*->matrix '[[1. 1.3333333333333333] [0. 1.]])) + (equal? (let ([M (list*->matrix '[[1 4] [2 4]])]) (gauss-eliminate M #t #t)) - (list->matrix '[[1 2] [0 1]])) - (equal? (let ([M (list->matrix '[[1 2] [2 4]])]) + (list*->matrix '[[1 2] [0 1]])) + (equal? (let ([M (list*->matrix '[[1 2] [2 4]])]) (gauss-eliminate M #f #t)) - (list->matrix '[[2 4] [0 0]])))) + (list*->matrix '[[2 4] [0 0]])))) (list 'matrix-scale-row (equal? (matrix-scale-row (identity-matrix 3) 0 2) @@ -265,10 +329,10 @@ (equal? (matrix-add-scaled-row (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real? ) 0 2 1) (list*->array '[[9 12 15] [4 5 6] [7 8 9]] real? ))) (let () - (define M (list->matrix '[[1 1 0 3] - [2 1 -1 1] - [3 -1 -1 2] - [-1 2 3 -1]])) + (define M (list*->matrix '[[1 1 0 3] + [2 1 -1 1] + [3 -1 -1 2] + [-1 2 3 -1]])) (define LU (matrix-lu M)) (if (eq? LU #f) (list 'matrix-lu #f) @@ -277,12 +341,12 @@ (define V (if (list? LU) (second LU) #f)) (list 'matrix-lu - (equal? L (list->matrix + (equal? L (list*->matrix '[[1 0 0 0] [2 1 0 0] [3 4 1 0] [-1 -3 0 1]])) - (equal? V (list->matrix + (equal? V (list*->matrix '[[1 1 0 3] [0 -1 -1 -5] [0 0 3 13] @@ -290,34 +354,34 @@ (equal? (matrix* L V) M))))) (list 'matrix-rank - (equal? (matrix-rank (list->matrix '[[0 0] [0 0]])) 0) - (equal? (matrix-rank (list->matrix '[[1 0] [0 0]])) 1) - (equal? (matrix-rank (list->matrix '[[1 0] [0 3]])) 2) - (equal? (matrix-rank (list->matrix '[[1 2] [2 4]])) 1) - (equal? (matrix-rank (list->matrix '[[1 2] [3 4]])) 2)) + (equal? (matrix-rank (list*->matrix '[[0 0] [0 0]])) 0) + (equal? (matrix-rank (list*->matrix '[[1 0] [0 0]])) 1) + (equal? (matrix-rank (list*->matrix '[[1 0] [0 3]])) 2) + (equal? (matrix-rank (list*->matrix '[[1 2] [2 4]])) 1) + (equal? (matrix-rank (list*->matrix '[[1 2] [3 4]])) 2)) (list 'matrix-nullity - (equal? (matrix-nullity (list->matrix '[[0 0] [0 0]])) 2) - (equal? (matrix-nullity (list->matrix '[[1 0] [0 0]])) 1) - (equal? (matrix-nullity (list->matrix '[[1 0] [0 3]])) 0) - (equal? (matrix-nullity (list->matrix '[[1 2] [2 4]])) 1) - (equal? (matrix-nullity (list->matrix '[[1 2] [3 4]])) 0)) + (equal? (matrix-nullity (list*->matrix '[[0 0] [0 0]])) 2) + (equal? (matrix-nullity (list*->matrix '[[1 0] [0 0]])) 1) + (equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0) + (equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1) + (equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0)) #;(let () (define-values (c1 n1) - (matrix-column+null-space (list->matrix '[[0 0] [0 0]]))) + (matrix-column+null-space (list*rix '[[0 0] [0 0]]))) (define-values (c2 n2) - (matrix-column+null-space (list->matrix '[[1 2] [2 4]]))) + (matrix-column+null-space (list*->matrix '[[1 2] [2 4]]))) (define-values (c3 n3) - (matrix-column+null-space (list->matrix '[[1 2] [2 5]]))) + (matrix-column+null-space (list*atrix '[[1 2] [2 5]]))) (list 'matrix-column+null-space (equal? c1 '()) - (equal? n1 (list (list->matrix '[[0] [0]]) - (list->matrix '[[0] [0]]))) - (equal? c2 (list (list->matrix '[[1] [2]]))) + (equal? n1 (list (list*->matrix '[[0] [0]]) + (list*->matrix '[[0] [0]]))) + (equal? c2 (list (list*->matrix '[[1] [2]]))) ;(equal? n2 '([0 0])) - (equal? c3 (list (list->matrix '[[1] [2]]) - (list->matrix '[[2] [5]]))) + (equal? c3 (list (list*->matrix '[[1] [2]]) + (list*->matrix '[[2] [5]]))) (equal? n3 '())))) @@ -327,48 +391,50 @@ #;(begin - "matrix-multiply.rkt" - (list 'matrix* - (let () - (define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]])) - (equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB))) - (let () - (define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6 7] [8 9 10]] '[[21 24 27] [47 54 61]])) - (equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB))))) + "matrix-multiply.rkt" + (list 'matrix* + (let () + (define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]])) + (equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB))) + (let () + (define-values (A B AB) (values '[[1 2] [3 4]] + '[[5 6 7] [8 9 10]] + '[[21 24 27] [47 54 61]])) + (equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB))))) #;(begin - "matrix-2d.rkt" - (let () - (define e1 (matrix-transpose (vector->matrix #(#( 1 0))))) - (define e2 (matrix-transpose (vector->matrix #(#( 0 1))))) - (define -e1 (matrix-transpose (vector->matrix #(#(-1 0))))) - (define -e2 (matrix-transpose (vector->matrix #(#( 0 -1))))) - (define O (matrix-transpose (vector->matrix #(#( 0 0))))) - (define 2*e1 (matrix-scale 2 e1)) - (define 4*e1 (matrix-scale 4 e1)) - (define 3*e2 (matrix-scale 3 e2)) - (define 4*e2 (matrix-scale 4 e2)) - (begin - (list 'matrix-2d-rotation - (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0) - (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0)) - (list - 'matrix-2d-scaling - (equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2))) - (list - 'matrix-2d-shear-x - (equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2))) - (list - 'matrix-2d-shear-y - (equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2))) - (list - 'matrix-2d-reflection - (equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1) - (equal? (matrix* (matrix-2d-reflection 0 1) e2) e2) - (equal? (matrix* (matrix-2d-reflection 1 0) e1) e1) - (equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2)) - (list - 'matrix-2d-orthogonal-projection - (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1) - (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O) - (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O) - (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2)))))) + "matrix-2d.rkt" + (let () + (define e1 (matrix-transpose (vector->matrix #(#( 1 0))))) + (define e2 (matrix-transpose (vector->matrix #(#( 0 1))))) + (define -e1 (matrix-transpose (vector->matrix #(#(-1 0))))) + (define -e2 (matrix-transpose (vector->matrix #(#( 0 -1))))) + (define O (matrix-transpose (vector->matrix #(#( 0 0))))) + (define 2*e1 (matrix-scale e1 2)) + (define 4*e1 (matrix-scale e1 4)) + (define 3*e2 (matrix-scale e2 3)) + (define 4*e2 (matrix-scale e2 4)) + (begin + (list 'matrix-2d-rotation + (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0) + (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0)) + (list + 'matrix-2d-scaling + (equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2))) + (list + 'matrix-2d-shear-x + (equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2))) + (list + 'matrix-2d-shear-y + (equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2))) + (list + 'matrix-2d-reflection + (equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1) + (equal? (matrix* (matrix-2d-reflection 0 1) e2) e2) + (equal? (matrix* (matrix-2d-reflection 1 0) e1) e1) + (equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2)) + (list + 'matrix-2d-orthogonal-projection + (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1) + (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O) + (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O) + (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2)))))) diff --git a/collects/math/tests/test-utils.rkt b/collects/math/tests/test-utils.rkt new file mode 100644 index 0000000000..3e51d4013e --- /dev/null +++ b/collects/math/tests/test-utils.rkt @@ -0,0 +1,11 @@ +#lang typed/racket + +(require (except-in typed/rackunit check-equal?)) + +(provide (all-from-out typed/rackunit) + check-equal?) + +;; This gets around the fact that typed/rackunit can no longer test higher-order values for equality, +;; since TR has firmed up its rules on passing `Any' types in and out of untyped code +(define-syntax-rule (check-equal? a b . message) + (check-true (equal? a b) . message))