From 8d5a069d41baf0ce30ecb4fa2a0bc1845733ec0c Mon Sep 17 00:00:00 2001 From: Neil Toronto Date: Thu, 20 Dec 2012 17:31:07 -0700 Subject: [PATCH] Moar `math/matrix' review/refactoring * Split "matrix-constructors.rkt" into three parts: * "matrix-constructors.rkt" * "matrix-conversion.rkt" * "matrix-syntax.rkt" * Made `matrix-map' automatically inline (it's dirt simple) * Renamed a few things, changed some type signatures * Fixed error in `matrix-dot' caught by testing (it was broadcasting) * Rewrote matrix comprehensions in terms of array comprehensions * Removed `in-column' and `in-row' (can use `in-array', `matrix-col' and `matrix-row') * Tons of new rackunit tests: only "matrix-2d.rkt" and "matrix-operations.rkt" are left (though the latter is large) --- collects/math/matrix.rkt | 11 +- collects/math/private/matrix/matrix-2d.rkt | 2 +- .../math/private/matrix/matrix-arithmetic.rkt | 21 +- collects/math/private/matrix/matrix-basic.rkt | 22 +- .../math/private/matrix/matrix-column.rkt | 2 +- .../private/matrix/matrix-comprehension.rkt | 191 ++-- .../private/matrix/matrix-constructors.rkt | 517 +++------- .../math/private/matrix/matrix-conversion.rkt | 202 ++++ .../math/private/matrix/matrix-operations.rkt | 10 +- .../math/private/matrix/matrix-sequences.rkt | 340 ------- .../math/private/matrix/matrix-syntax.rkt | 35 + collects/math/private/matrix/matrix-types.rkt | 57 +- .../matrix/untyped-matrix-arithmetic.rkt | 5 +- collects/math/private/matrix/utils.rkt | 2 +- collects/math/tests/matrix-tests.rkt | 931 +++++++++++++----- 15 files changed, 1203 insertions(+), 1145 deletions(-) create mode 100644 collects/math/private/matrix/matrix-conversion.rkt delete mode 100644 collects/math/private/matrix/matrix-sequences.rkt create mode 100644 collects/math/private/matrix/matrix-syntax.rkt diff --git a/collects/math/matrix.rkt b/collects/math/matrix.rkt index 3dc1c82ca1..2447f34ffb 100644 --- a/collects/math/matrix.rkt +++ b/collects/math/matrix.rkt @@ -2,21 +2,24 @@ (require "private/matrix/matrix-arithmetic.rkt" "private/matrix/matrix-constructors.rkt" + "private/matrix/matrix-conversion.rkt" + "private/matrix/matrix-syntax.rkt" "private/matrix/matrix-basic.rkt" "private/matrix/matrix-operations.rkt" "private/matrix/matrix-comprehension.rkt" - "private/matrix/matrix-sequences.rkt" "private/matrix/matrix-expt.rkt" "private/matrix/matrix-types.rkt" + "private/matrix/matrix-2d.rkt" "private/matrix/utils.rkt") (provide (all-from-out "private/matrix/matrix-arithmetic.rkt" "private/matrix/matrix-constructors.rkt" + "private/matrix/matrix-conversion.rkt" + "private/matrix/matrix-syntax.rkt" "private/matrix/matrix-basic.rkt" "private/matrix/matrix-operations.rkt" "private/matrix/matrix-comprehension.rkt" - "private/matrix/matrix-sequences.rkt" "private/matrix/matrix-expt.rkt" - "private/matrix/matrix-types.rkt") - matrix?) + "private/matrix/matrix-types.rkt" + "private/matrix/matrix-2d.rkt")) diff --git a/collects/math/private/matrix/matrix-2d.rkt b/collects/math/private/matrix/matrix-2d.rkt index 9efdde259e..9f30616bcd 100644 --- a/collects/math/private/matrix/matrix-2d.rkt +++ b/collects/math/private/matrix/matrix-2d.rkt @@ -2,7 +2,7 @@ (require math/array "matrix-types.rkt" - "matrix-constructors.rkt") + "matrix-syntax.rkt") (provide matrix-2d-rotation matrix-2d-scaling diff --git a/collects/math/private/matrix/matrix-arithmetic.rkt b/collects/math/private/matrix/matrix-arithmetic.rkt index e0d5e85499..cf65194af2 100644 --- a/collects/math/private/matrix/matrix-arithmetic.rkt +++ b/collects/math/private/matrix/matrix-arithmetic.rkt @@ -19,21 +19,18 @@ [(_ . es) (syntax/loc inner-stx (typed:fun . es))] [_ (syntax/loc inner-stx typed:fun)])))])) -(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*) -(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+) -(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-) +(define/inline-macro matrix* (a as ...) inline-matrix* typed:matrix*) +(define/inline-macro matrix+ (a as ...) inline-matrix+ typed:matrix+) +(define/inline-macro matrix- (a as ...) inline-matrix- typed:matrix-) (define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale) +(define/inline-macro do-matrix-map (f a as ...) inline-matrix-map matrix-map) + (provide - ;; Equality - (rename-out [typed:matrix= matrix=]) - ;; Mapping - inline-matrix-map - matrix-map - ;; Multiplication + (rename-out [do-matrix-map matrix-map] + [typed:matrix= matrix=] + [typed:matrix-sum matrix-sum]) matrix* - ;; Pointwise operators matrix+ matrix- - matrix-scale - (rename-out [typed:matrix-sum matrix-sum])) + matrix-scale) diff --git a/collects/math/private/matrix/matrix-basic.rkt b/collects/math/private/matrix/matrix-basic.rkt index c5f4b04e43..b84ad8d55d 100644 --- a/collects/math/private/matrix/matrix-basic.rkt +++ b/collects/math/private/matrix/matrix-basic.rkt @@ -5,6 +5,7 @@ math/array math/flonum "matrix-types.rkt" + "matrix-arithmetic.rkt" "utils.rkt" "../unsafe.rkt") @@ -18,7 +19,7 @@ matrix-rows matrix-cols ;; Predicates - zero-matrix? + matrix-zero? ;; Embiggenment matrix-augment matrix-stack @@ -73,7 +74,7 @@ (unsafe-vector-set! ij 0 0) res))])) -(: matrix-col (All (A) (Matrix A) Index -> (Matrix A))) +(: matrix-col (All (A) (Matrix A) Integer -> (Matrix A))) (define (matrix-col a j) (define-values (m n) (matrix-shape a)) (cond [(or (j . < . 0) (j . >= . n)) @@ -99,9 +100,9 @@ ;; =================================================================================================== ;; Predicates -(: zero-matrix? ((Array Number) -> Boolean)) -(define (zero-matrix? a) - (array-all-and (array-map zero? a))) +(: matrix-zero? ((Array Number) -> Boolean)) +(define (matrix-zero? a) + (array-all-and (matrix-map zero? a))) ;; =================================================================================================== ;; Embiggenment (this is a perfectly cromulent word) @@ -179,9 +180,14 @@ ((Array Number) (Array Number) -> Number))) ;; Computes the Frobenius inner product of two matrices (define (matrix-dot a b) - (cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)] - [(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)] - [else (array-all-sum (array* a (array-conjugate b)))])) + (define-values (m n) (matrix-shapes 'matrix-dot a b)) + (define aproc (unsafe-array-proc a)) + (define bproc (unsafe-array-proc b)) + (array-all-sum + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (* (aproc js) (conjugate (bproc js))))))) ;; =================================================================================================== ;; Operators diff --git a/collects/math/private/matrix/matrix-column.rkt b/collects/math/private/matrix/matrix-column.rkt index ab864dfed7..9f452ac512 100644 --- a/collects/math/private/matrix/matrix-column.rkt +++ b/collects/math/private/matrix/matrix-column.rkt @@ -3,7 +3,7 @@ (require math/array math/base "matrix-types.rkt" - "matrix-constructors.rkt" + "matrix-conversion.rkt" "matrix-arithmetic.rkt" "../unsafe.rkt") diff --git a/collects/math/private/matrix/matrix-comprehension.rkt b/collects/math/private/matrix/matrix-comprehension.rkt index f9ac60b731..ed07e4d971 100644 --- a/collects/math/private/matrix/matrix-comprehension.rkt +++ b/collects/math/private/matrix/matrix-comprehension.rkt @@ -1,140 +1,63 @@ -#lang racket +#lang racket/base -(require math/array - typed/racket/base - "matrix-types.rkt" - "matrix-constructors.rkt") +(require (for-syntax racket/base + syntax/parse) + math/array) -(provide for/matrix - for*/matrix - for/matrix: - for*/matrix:) +(provide for/matrix: + for*/matrix: + for/matrix + for*/matrix) -;;; COMPREHENSIONS +(module typed-defs typed/racket/base + (require (for-syntax racket/base + syntax/parse) + math/array) + + (provide (all-defined-out)) + + (: ensure-matrix-dims (Symbol Any Any -> (Values Positive-Index Positive-Index))) + (define (ensure-matrix-dims name m n) + (cond [(or (not (index? m)) (zero? m)) (raise-argument-error name "Positive-Index" 0 m n)] + [(or (not (index? n)) (zero? n)) (raise-argument-error name "Positive-Index" 1 m n)] + [else (values m n)])) + + (define-syntax (base-for/matrix: stx) + (syntax-parse stx #:literals (:) + [(_ name:id for/array:id + m-expr:expr n-expr:expr + (~optional (~seq #:fill fill-expr:expr)) + (clause ...) + (~optional (~seq : A:expr)) + body:expr ...+) + (with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())] + [(maybe-type ...) (if (attribute A) #'(: A) #'())]) + (syntax/loc stx + (let-values ([(m n) (ensure-matrix-dims 'name + (ann m-expr Integer) + (ann n-expr Integer))]) + (for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...) maybe-type ... + body ...))))])) + + (define-syntax-rule (for/matrix: e ...) (base-for/matrix: for/matrix: for/array: e ...)) + (define-syntax-rule (for*/matrix: e ...) (base-for/matrix: for*/matrix: for*/array: e ...)) + + ) -; (for/matrix m n (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first n values produced becomes the first row. -; The next n values becomes the second row and so on. -; The bindings in clauses run in parallel. -(define-syntax (for/matrix stx) - (syntax-case stx () - [(_ m-expr n-expr (clause ...) . defs+exprs) - (syntax/loc stx - (let ([m m-expr] [n n-expr]) - (define flat-vector - (for/vector #:length (* m n) - (clause ...) . defs+exprs)) - (vector->matrix m n flat-vector)))])) +(require (submod "." typed-defs)) -; (for*/matrix m n (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first n values produced becomes the first row. -; The next n values becomes the second row and so on. -; The bindings in clauses run nested. -; (for*/matrix m n #:column (clause ...) . defs+exprs) -; Return an m x n matrix with elements from the last expr. -; The first m values produced becomes the first column. -; The next m values becomes the second column and so on. -; The bindings in clauses run nested. -(define-syntax (for*/matrix stx) - (syntax-case stx () - [(_ m-expr n-expr #:column (clause ...) . defs+exprs) - (syntax/loc stx - (let* ([m m-expr] - [n n-expr] - [v (make-vector (* m n) 0)] - [w (for*/vector #:length (* m n) (clause ...) . defs+exprs)]) - (for* ([i (in-range m)] [j (in-range n)]) - (vector-set! v (+ (* i n) j) - (vector-ref w (+ (* j m) i)))) - (vector->matrix m n v)))] - [(_ m-expr n-expr (clause ...) . defs+exprs) - (syntax/loc stx - (let ([m m-expr] [n n-expr]) - (vector->matrix - m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))])) +(define-syntax (base-for/matrix stx) + (syntax-parse stx + [(_ name:id for/array:id + m-expr:expr n-expr:expr + (~optional (~seq #:fill fill-expr:expr)) + (clause ...) + body:expr ...+) + (with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())]) + (syntax/loc stx + (let-values ([(m n) (ensure-matrix-dims 'name m-expr n-expr)]) + (for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...) + body ...))))])) - -(define-syntax (for/column: stx) - (syntax-case stx () - [(_ : type m-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: flat-vector : (Vectorof Number) (make-vector m 0)) - (for: ([i (in-range m)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! flat-vector i x)) - (vector->col-matrix flat-vector)))])) - -(define-syntax (for/matrix: stx) - (syntax-case stx () - [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: k : Index 0) - (for: ([i (in-range m*n)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) - (set! k (assert (+ k 1) index?))) - (vector->matrix m n v)))] - [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (for: ([i (in-range m*n)] for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v i x)) - (vector->matrix m n v)))])) - -(define-syntax (for*/matrix: stx) - (syntax-case stx () - [(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: k : Index 0) - (for*: (for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v (+ (* n (remainder k m)) (quotient k m)) x) - (set! k (assert (+ k 1) index?))) - (vector->matrix m n v)))] - [(_ : type m-expr n-expr (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: m : Index m-expr) - (define: n : Index n-expr) - (define: m*n : Index (assert (* m n) index?)) - (define: v : (Vectorof Number) (make-vector m*n 0)) - (define: i : Index 0) - (for*: (for:-clause ...) - (define x (let () . defs+exprs)) - (vector-set! v i x) - (set! i (assert (+ i 1) index?))) - (vector->matrix m n v)))])) -#; -(module* test #f - (require rackunit) - ; "matrix-sequences.rkt" - ; These work in racket not in typed racket - (check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j))) - '[[0 1 2] [1 2 3]]) - (check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j))) - '[[0 2 2] [1 1 3]]) - (check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i)) - '[[0 2] [1 3]]) - (check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i)) - '[[0 1] [2 3]]) - (check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j))) - '[[6 8 10] [12 14 16]])) +(define-syntax-rule (for/matrix e ...) (base-for/matrix for/matrix for/array e ...)) +(define-syntax-rule (for*/matrix e ...) (base-for/matrix for*/matrix for*/array e ...)) diff --git a/collects/math/private/matrix/matrix-constructors.rkt b/collects/math/private/matrix/matrix-constructors.rkt index c4fc391b2f..e920990ac5 100644 --- a/collects/math/private/matrix/matrix-constructors.rkt +++ b/collects/math/private/matrix/matrix-constructors.rkt @@ -1,376 +1,151 @@ -#lang racket/base +#lang typed/racket/base -(provide - ;; Constructors - identity-matrix - make-matrix - build-matrix - diagonal-matrix/zero - diagonal-matrix - block-diagonal-matrix/zero - block-diagonal-matrix - vandermonde-matrix - ;; Basic conversion - list->matrix - matrix->list - vector->matrix - matrix->vector - ->row-matrix - ->col-matrix - ;; Nested conversion - list*->matrix - matrix->list* - vector*->matrix - matrix->vector* - ;; Syntax - matrix - row-matrix - col-matrix) - -(module typed-defs typed/racket/base - (require racket/fixnum - racket/list - racket/vector - math/array - "../array/utils.rkt" - "matrix-types.rkt" - "utils.rkt" - "../unsafe.rkt") - - (provide (all-defined-out)) - - ;; ================================================================================================= - ;; Constructors - - (: identity-matrix (Integer -> (Matrix (U 0 1)))) - (define (identity-matrix m) (diagonal-array 2 m 1 0)) - - (: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) - (define (make-matrix m n x) - (make-array (vector m n) x)) - - (: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A)))) - (define (build-matrix m n proc) - (cond [(or (not (index? m)) (= m 0)) - (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)] - [(or (not (index? n)) (= n 0)) - (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] - [else - (unsafe-build-array - ((inst vector Index) m n) - (λ: ([js : Indexes]) - (proc (unsafe-vector-ref js 0) - (unsafe-vector-ref js 1))))])) - - (: diagonal-matrix/zero (All (A) (Array A) A -> (Array A))) - (define (diagonal-matrix/zero a zero) - (define ds (array-shape a)) - (cond [(= 1 (vector-length ds)) - (define m (unsafe-vector-ref ds 0)) - (define proc (unsafe-array-proc a)) - (unsafe-build-array - ((inst vector Index) m m) - (λ: ([js : Indexes]) - (define i (unsafe-vector-ref js 0)) - (cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))] - [else zero])))] - [else - (raise-argument-error 'diagonal-matrix "Array with one dimension" a)])) - - (: diagonal-matrix (case-> ((Array Real) -> (Array Real)) - ((Array Number) -> (Array Number)))) - (define (diagonal-matrix a) - (diagonal-matrix/zero a 0)) - - (: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A))) - (define (block-diagonal-matrix/zero* as zero) - (define num (vector-length as)) - (define-values (ms ns) - (let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty] - [ns : (Listof Index) empty] - ) ([a (in-vector as)]) - (define-values (m n) (matrix-shape a)) - (values (cons m ms) (cons n ns)))]) - (values (reverse ms) (reverse ns)))) - (define res-m (assert (apply + ms) index?)) - (define res-n (assert (apply + ns) index?)) - (define vs ((inst make-vector Index) res-m 0)) - (define hs ((inst make-vector Index) res-n 0)) - (define is ((inst make-vector Index) res-m 0)) - (define js ((inst make-vector Index) res-n 0)) - (define-values (_res-i _res-j) - (for/fold: ([res-i : Nonnegative-Fixnum 0] - [res-j : Nonnegative-Fixnum 0] - ) ([m (in-list ms)] - [n (in-list ns)] - [k : Nonnegative-Fixnum (in-range num)]) - (let ([k (assert k index?)]) - (for: ([i : Nonnegative-Fixnum (in-range m)]) - (vector-set! vs (unsafe-fx+ res-i i) k) - (vector-set! is (unsafe-fx+ res-i i) (assert i index?))) - (for: ([j : Nonnegative-Fixnum (in-range n)]) - (vector-set! hs (unsafe-fx+ res-j j) k) - (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) - (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) - (define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as)) - (unsafe-build-array - ((inst vector Index) res-m res-n) - (λ: ([ij : Indexes]) - (define i (unsafe-vector-ref ij 0)) - (define j (unsafe-vector-ref ij 1)) - (define v (unsafe-vector-ref vs i)) - (cond [(fx= v (vector-ref hs j)) - (define proc (unsafe-vector-ref procs v)) - (define iv (unsafe-vector-ref is i)) - (define jv (unsafe-vector-ref js j)) - (unsafe-vector-set! ij 0 iv) - (unsafe-vector-set! ij 1 jv) - (define res (proc ij)) - (unsafe-vector-set! ij 0 i) - (unsafe-vector-set! ij 1 j) - res] - [else - zero])))) - - (: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A))) - (define (block-diagonal-matrix/zero as zero) - (let ([as (list->vector as)]) - (define num (vector-length as)) - (cond [(= num 0) - (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)] - [(= num 1) - (unsafe-vector-ref as 0)] - [else - (block-diagonal-matrix/zero* as zero)]))) - - (: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real)) - ((Listof (Array Number)) -> (Array Number)))) - (define (block-diagonal-matrix as) - (block-diagonal-matrix/zero as 0)) - - (: expt-hack (case-> (Real Integer -> Real) - (Number Integer -> Number))) - ;; Stop using this when TR correctly derives expt : Real Integer -> Real - (define (expt-hack x n) - (cond [(real? x) (assert (expt x n) real?)] - [else (expt x n)])) - - (: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real)) - ((Listof Number) Integer -> (Array Number)))) - (define (vandermonde-matrix xs n) - (cond [(empty? xs) - (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)] - [(or (not (index? n)) (zero? n)) - (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)] - [else - (array-axis-expand (list->array xs) 1 n expt-hack)])) - - ;; ================================================================================================= - ;; Flat conversion - - (: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A)))) - (define (list->matrix m n xs) - (cond [(or (not (index? m)) (= m 0)) - (raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)] - [(or (not (index? n)) (= n 0)) - (raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)] - [else (list->array (vector m n) xs)])) - - (: matrix->list (All (A) ((Array A) -> (Listof A)))) - (define (matrix->list a) - (array->list (ensure-matrix 'matrix->list a))) - - (: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A)))) - (define (vector->matrix m n v) - (cond [(or (not (index? m)) (= m 0)) - (raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)] - [(or (not (index? n)) (= n 0)) - (raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)] - [else (vector->array (vector m n) v)])) - - (: matrix->vector (All (A) ((Array A) -> (Vectorof A)))) - (define (matrix->vector a) - (array->vector (ensure-matrix 'matrix->vector a))) - - (: list->row-matrix (All (A) ((Listof A) -> (Array A)))) - (define (list->row-matrix xs) - (cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)] - [else (list->array ((inst vector Index) 1 (length xs)) xs)])) - - (: list->col-matrix (All (A) ((Listof A) -> (Array A)))) - (define (list->col-matrix xs) - (cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)] - [else (list->array ((inst vector Index) (length xs) 1) xs)])) - - (: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) - (define (vector->row-matrix xs) - (define n (vector-length xs)) - (cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)] - [else (vector->array ((inst vector Index) 1 n) xs)])) - - (: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) - (define (vector->col-matrix xs) - (define n (vector-length xs)) - (cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)] - [else (vector->array ((inst vector Index) n 1) xs)])) - - (: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index))) - (define (find-nontrivial-axis ds) - (define dims (vector-length ds)) - (let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0]) - (cond [(k . < . dims) (define dk (unsafe-vector-ref ds k)) - (if (dk . > . 1) (values k dk) (loop (fx+ k 1)))] - [else (values 0 0)]))) - - (: array->row-matrix (All (A) ((Array A) -> (Array A)))) - (define (array->row-matrix arr) - (define (fail) - (raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr)) - (define ds (array-shape arr)) - (define dims (vector-length ds)) - (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) - (cond [(zero? (array-size arr)) (fail)] - [(row-matrix? arr) arr] - [(= num-ones dims) - (define: js : (Vectorof Index) (make-vector dims 0)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 1) - (λ: ([ij : Indexes]) (proc js)))] - [(= num-ones (- dims 1)) - (define-values (k n) (find-nontrivial-axis ds)) - (define js (make-thread-local-indexes dims)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 n) - (λ: ([ij : Indexes]) - (let ([js (js)]) - (unsafe-vector-set! js k (unsafe-vector-ref ij 1)) - (proc js))))] - [else (fail)])) - - (: array->col-matrix (All (A) ((Array A) -> (Array A)))) - (define (array->col-matrix arr) - (define (fail) - (raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr)) - (define ds (array-shape arr)) - (define dims (vector-length ds)) - (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) - (cond [(zero? (array-size arr)) (fail)] - [(col-matrix? arr) arr] - [(= num-ones dims) - (define: js : (Vectorof Index) (make-vector dims 0)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) 1 1) - (λ: ([ij : Indexes]) (proc js)))] - [(= num-ones (- dims 1)) - (define-values (k m) (find-nontrivial-axis ds)) - (define js (make-thread-local-indexes dims)) - (define proc (unsafe-array-proc arr)) - (unsafe-build-array ((inst vector Index) m 1) - (λ: ([ij : Indexes]) - (let ([js (js)]) - (unsafe-vector-set! js k (unsafe-vector-ref ij 0)) - (proc js))))] - [else (fail)])) - - (: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) - (define (->row-matrix xs) - (cond [(list? xs) (list->row-matrix xs)] - [(array? xs) (array->row-matrix xs)] - [else (vector->row-matrix xs)])) - - (: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) - (define (->col-matrix xs) - (cond [(list? xs) (list->col-matrix xs)] - [(array? xs) (array->col-matrix xs)] - [else (vector->col-matrix xs)])) - - ;; ================================================================================================= - ;; Nested conversion - - (: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index))) - (define (list*-shape xss fail) - (define m (length xss)) - (cond [(m . > . 0) - (define n (length (first xss))) - (cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss))) - (values m n)] - [else (fail)])] - [else (fail)])) - - (: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing) - -> (Values Positive-Index Positive-Index))) - (define (vector*-shape xss fail) - (define m (vector-length xss)) - (cond [(m . > . 0) - (define ns ((inst vector-map Index (Vectorof A)) vector-length xss)) - (define n (vector-length (unsafe-vector-ref xss 0))) - (cond [(and (n . > . 0) - (let: loop : Boolean ([i : Nonnegative-Fixnum 1]) - (cond [(i . fx< . m) - (if (= n (vector-length (unsafe-vector-ref xss i))) - (loop (fx+ i 1)) - #f)] - [else #t]))) - (values m n)] - [else (fail)])] - [else (fail)])) - - (: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A))) - (define (list*->matrix xss) - (define (fail) - (raise-argument-error 'list*->matrix - "nested lists with rectangular shape and at least one matrix element" - xss)) - (define-values (m n) (list*-shape xss fail)) - (list->array ((inst vector Index) m n) (apply append xss))) - - (: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A)))) - (define (matrix->list* a) - (cond [(matrix? a) (array->list (array->list-array a 1))] - [else (raise-argument-error 'matrix->list* "matrix?" a)])) - - (: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A))) - (define (vector*->matrix xss) - (define (fail) - (raise-argument-error 'vector*->matrix - "nested vectors with rectangular shape and at least one matrix element" - xss)) - (define-values (m n) (vector*-shape xss fail)) - (vector->matrix m n (apply vector-append (vector->list xss)))) - - (: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A)))) - (define (matrix->vector* a) - (cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))] - [else (raise-argument-error 'matrix->vector* "matrix?" a)])) - ) ; module - -(require (for-syntax racket/base - syntax/parse) - (only-in typed/racket/base :) +(require racket/fixnum + racket/list + racket/vector math/array - (submod "." typed-defs)) + "matrix-types.rkt" + "../unsafe.rkt") -(define-syntax (matrix stx) - (syntax-parse stx #:literals (:) - [(_ [[x0 xs0 ...] [x xs ...] ...]) - (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))] - [(_ [[x0 xs0 ...] [x xs ...] ...] : T) - (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))] - [(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T))) - (raise-syntax-error 'matrix "given empty row" stx #'r)] - [(_ (~and [] c) (~optional (~seq : T))) - (raise-syntax-error 'matrix "given empty matrix" stx #'c)])) +(provide identity-matrix + make-matrix + build-matrix + diagonal-matrix/zero + diagonal-matrix + block-diagonal-matrix/zero + block-diagonal-matrix + vandermonde-matrix) -(define-syntax (row-matrix stx) - (syntax-parse stx #:literals (:) - [(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))] - [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))] - [(_ (~and [] r) (~optional (~seq : T))) - (raise-syntax-error 'row-matrix "given empty row" stx #'r)])) +;; =================================================================================================== +;; Basic constructors -(define-syntax (col-matrix stx) - (syntax-parse stx #:literals (:) - [(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))] - [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))] - [(_ (~and [] c) (~optional (~seq : T))) - (raise-syntax-error 'row-matrix "given empty column" stx #'c)])) +(: identity-matrix (Integer -> (Matrix (U 0 1)))) +(define (identity-matrix m) (diagonal-array 2 m 1 0)) + +(: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) +(define (make-matrix m n x) + (make-array (vector m n) x)) + +(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A)))) +(define (build-matrix m n proc) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] + [else + (unsafe-build-array + ((inst vector Index) m n) + (λ: ([js : Indexes]) + (proc (unsafe-vector-ref js 0) + (unsafe-vector-ref js 1))))])) + +;; =================================================================================================== +;; Diagonal matrices + +(: diagonal-matrix/zero (All (A) (Listof A) A -> (Array A))) +(define (diagonal-matrix/zero xs zero) + (cond [(empty? xs) + (raise-argument-error 'diagonal-matrix "nonempty List" xs)] + [else + (define vs (list->vector xs)) + (define m (vector-length vs)) + (unsafe-build-array + ((inst vector Index) m m) + (λ: ([js : Indexes]) + (define i (unsafe-vector-ref js 0)) + (cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)] + [else zero])))])) + +(: diagonal-matrix (case-> ((Listof Real) -> (Array Real)) + ((Listof Number) -> (Array Number)))) +(define (diagonal-matrix xs) + (diagonal-matrix/zero xs 0)) + +;; =================================================================================================== +;; Block diagonal matrices + +(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A))) +(define (block-diagonal-matrix/zero* as zero) + (define num (vector-length as)) + (define-values (ms ns) + (let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty] + [ns : (Listof Index) empty] + ) ([a (in-vector as)]) + (define-values (m n) (matrix-shape a)) + (values (cons m ms) (cons n ns)))]) + (values (reverse ms) (reverse ns)))) + (define res-m (assert (apply + ms) index?)) + (define res-n (assert (apply + ns) index?)) + (define vs ((inst make-vector Index) res-m 0)) + (define hs ((inst make-vector Index) res-n 0)) + (define is ((inst make-vector Index) res-m 0)) + (define js ((inst make-vector Index) res-n 0)) + (define-values (_res-i _res-j) + (for/fold: ([res-i : Nonnegative-Fixnum 0] + [res-j : Nonnegative-Fixnum 0] + ) ([m (in-list ms)] + [n (in-list ns)] + [k : Nonnegative-Fixnum (in-range num)]) + (let ([k (assert k index?)]) + (for: ([i : Nonnegative-Fixnum (in-range m)]) + (vector-set! vs (unsafe-fx+ res-i i) k) + (vector-set! is (unsafe-fx+ res-i i) (assert i index?))) + (for: ([j : Nonnegative-Fixnum (in-range n)]) + (vector-set! hs (unsafe-fx+ res-j j) k) + (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) + (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) + (define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as)) + (unsafe-build-array + ((inst vector Index) res-m res-n) + (λ: ([ij : Indexes]) + (define i (unsafe-vector-ref ij 0)) + (define j (unsafe-vector-ref ij 1)) + (define v (unsafe-vector-ref vs i)) + (cond [(fx= v (vector-ref hs j)) + (define proc (unsafe-vector-ref procs v)) + (define iv (unsafe-vector-ref is i)) + (define jv (unsafe-vector-ref js j)) + (unsafe-vector-set! ij 0 iv) + (unsafe-vector-set! ij 1 jv) + (define res (proc ij)) + (unsafe-vector-set! ij 0 i) + (unsafe-vector-set! ij 1 j) + res] + [else + zero])))) + +(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A))) +(define (block-diagonal-matrix/zero as zero) + (let ([as (list->vector as)]) + (define num (vector-length as)) + (cond [(= num 0) + (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)] + [(= num 1) + (unsafe-vector-ref as 0)] + [else + (block-diagonal-matrix/zero* as zero)]))) + +(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real)) + ((Listof (Array Number)) -> (Array Number)))) +(define (block-diagonal-matrix as) + (block-diagonal-matrix/zero as 0)) + +;; =================================================================================================== +;; Special matrices + +(: expt-hack (case-> (Real Integer -> Real) + (Number Integer -> Number))) +;; Stop using this when TR correctly derives expt : Real Integer -> Real +(define (expt-hack x n) + (cond [(real? x) (assert (expt x n) real?)] + [else (expt x n)])) + +(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real)) + ((Listof Number) Integer -> (Array Number)))) +(define (vandermonde-matrix xs n) + (cond [(empty? xs) + (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)] + [(or (not (index? n)) (zero? n)) + (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)] + [else + (array-axis-expand (list->array xs) 1 n expt-hack)])) diff --git a/collects/math/private/matrix/matrix-conversion.rkt b/collects/math/private/matrix/matrix-conversion.rkt new file mode 100644 index 0000000000..57c08c7281 --- /dev/null +++ b/collects/math/private/matrix/matrix-conversion.rkt @@ -0,0 +1,202 @@ +#lang typed/racket/base + +(require racket/fixnum + racket/list + racket/vector + math/array + "matrix-types.rkt" + "utils.rkt" + "../array/utils.rkt" + "../unsafe.rkt") + +(provide + ;; Flat conversion + list->matrix + matrix->list + vector->matrix + matrix->vector + ->row-matrix + ->col-matrix + ;; Nested conversion + list*->matrix + matrix->list* + vector*->matrix + matrix->vector*) + +;; =================================================================================================== +;; Flat conversion + +(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A)))) +(define (list->matrix m n xs) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)] + [else (list->array (vector m n) xs)])) + +(: matrix->list (All (A) ((Array A) -> (Listof A)))) +(define (matrix->list a) + (array->list (ensure-matrix 'matrix->list a))) + +(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A)))) +(define (vector->matrix m n v) + (cond [(or (not (index? m)) (= m 0)) + (raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)] + [(or (not (index? n)) (= n 0)) + (raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)] + [else (vector->array (vector m n) v)])) + +(: matrix->vector (All (A) ((Array A) -> (Vectorof A)))) +(define (matrix->vector a) + (array->vector (ensure-matrix 'matrix->vector a))) + +(: list->row-matrix (All (A) ((Listof A) -> (Array A)))) +(define (list->row-matrix xs) + (cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)] + [else (list->array ((inst vector Index) 1 (length xs)) xs)])) + +(: list->col-matrix (All (A) ((Listof A) -> (Array A)))) +(define (list->col-matrix xs) + (cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)] + [else (list->array ((inst vector Index) (length xs) 1) xs)])) + +(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) +(define (vector->row-matrix xs) + (define n (vector-length xs)) + (cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)] + [else (vector->array ((inst vector Index) 1 n) xs)])) + +(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A)))) +(define (vector->col-matrix xs) + (define n (vector-length xs)) + (cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)] + [else (vector->array ((inst vector Index) n 1) xs)])) + +(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index))) +(define (find-nontrivial-axis ds) + (define dims (vector-length ds)) + (let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0]) + (cond [(k . < . dims) (define dk (unsafe-vector-ref ds k)) + (if (dk . > . 1) (values k dk) (loop (fx+ k 1)))] + [else (values 0 0)]))) + +(: array->row-matrix (All (A) ((Array A) -> (Array A)))) +(define (array->row-matrix arr) + (define (fail) + (raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr)) + (define ds (array-shape arr)) + (define dims (vector-length ds)) + (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) + (cond [(zero? (array-size arr)) (fail)] + [(row-matrix? arr) arr] + [(= num-ones dims) + (define: js : (Vectorof Index) (make-vector dims 0)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 1) + (λ: ([ij : Indexes]) (proc js)))] + [(= num-ones (- dims 1)) + (define-values (k n) (find-nontrivial-axis ds)) + (define js (make-thread-local-indexes dims)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 n) + (λ: ([ij : Indexes]) + (let ([js (js)]) + (unsafe-vector-set! js k (unsafe-vector-ref ij 1)) + (proc js))))] + [else (fail)])) + +(: array->col-matrix (All (A) ((Array A) -> (Array A)))) +(define (array->col-matrix arr) + (define (fail) + (raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr)) + (define ds (array-shape arr)) + (define dims (vector-length ds)) + (define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds)) + (cond [(zero? (array-size arr)) (fail)] + [(col-matrix? arr) arr] + [(= num-ones dims) + (define: js : (Vectorof Index) (make-vector dims 0)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) 1 1) + (λ: ([ij : Indexes]) (proc js)))] + [(= num-ones (- dims 1)) + (define-values (k m) (find-nontrivial-axis ds)) + (define js (make-thread-local-indexes dims)) + (define proc (unsafe-array-proc arr)) + (unsafe-build-array ((inst vector Index) m 1) + (λ: ([ij : Indexes]) + (let ([js (js)]) + (unsafe-vector-set! js k (unsafe-vector-ref ij 0)) + (proc js))))] + [else (fail)])) + +(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) +(define (->row-matrix xs) + (cond [(list? xs) (list->row-matrix xs)] + [(array? xs) (array->row-matrix xs)] + [else (vector->row-matrix xs)])) + +(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A)))) +(define (->col-matrix xs) + (cond [(list? xs) (list->col-matrix xs)] + [(array? xs) (array->col-matrix xs)] + [else (vector->col-matrix xs)])) + +;; =================================================================================================== +;; Nested conversion + +(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index))) +(define (list*-shape xss fail) + (define m (length xss)) + (cond [(m . > . 0) + (define n (length (first xss))) + (cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss))) + (values m n)] + [else (fail)])] + [else (fail)])) + +(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing) + -> (Values Positive-Index Positive-Index))) +(define (vector*-shape xss fail) + (define m (vector-length xss)) + (cond [(m . > . 0) + (define ns ((inst vector-map Index (Vectorof A)) vector-length xss)) + (define n (vector-length (unsafe-vector-ref xss 0))) + (cond [(and (n . > . 0) + (let: loop : Boolean ([i : Nonnegative-Fixnum 1]) + (cond [(i . fx< . m) + (if (= n (vector-length (unsafe-vector-ref xss i))) + (loop (fx+ i 1)) + #f)] + [else #t]))) + (values m n)] + [else (fail)])] + [else (fail)])) + +(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A))) +(define (list*->matrix xss) + (define (fail) + (raise-argument-error 'list*->matrix + "nested lists with rectangular shape and at least one matrix element" + xss)) + (define-values (m n) (list*-shape xss fail)) + (list->array ((inst vector Index) m n) (apply append xss))) + +(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A)))) +(define (matrix->list* a) + (cond [(matrix? a) (array->list (array->list-array a 1))] + [else (raise-argument-error 'matrix->list* "matrix?" a)])) + +(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A))) +(define (vector*->matrix xss) + (define (fail) + (raise-argument-error 'vector*->matrix + "nested vectors with rectangular shape and at least one matrix element" + xss)) + (define-values (m n) (vector*-shape xss fail)) + (vector->matrix m n (apply vector-append (vector->list xss)))) + +(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A)))) +(define (matrix->vector* a) + (cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))] + [else (raise-argument-error 'matrix->vector* "matrix?" a)])) diff --git a/collects/math/private/matrix/matrix-operations.rkt b/collects/math/private/matrix/matrix-operations.rkt index f620b9bcef..1b757e52da 100644 --- a/collects/math/private/matrix/matrix-operations.rkt +++ b/collects/math/private/matrix/matrix-operations.rkt @@ -6,6 +6,7 @@ "../unsafe.rkt" "matrix-types.rkt" "matrix-constructors.rkt" + "matrix-conversion.rkt" "matrix-arithmetic.rkt" "matrix-basic.rkt" "matrix-column.rkt" @@ -19,13 +20,6 @@ ; 4. Pseudo inverse ; 5. Eigenvalues and eigenvectors -; 6. "Bug" -; (for*/matrix : Number 2 3 ([i (in-naturals)]) i) -; ought to generate a matrix with numbers from 0 to 5. -; Problem: In expansion of for/matrix an extra [i (in-range (* m n))] -; is added to make sure the comprehension stops. -; But TR has problems with #:when so what is the proper expansion ? - (provide matrix-inverse ; row and column @@ -582,7 +576,7 @@ ; Note: We project onto vs (not on the original ws) ; in order to get numerical stability. (let ([w-minus-proj (array-strict (array- w w-proj))]) - (if (zero-matrix? w-minus-proj) + (if (matrix-zero? w-minus-proj) (loop vs (cdr ws)) ; w in span{vs} => omit it (loop (cons w-minus-proj vs) (cdr ws)))))])) (reverse (loop (list (car ws)) (cdr ws)))])) diff --git a/collects/math/private/matrix/matrix-sequences.rkt b/collects/math/private/matrix/matrix-sequences.rkt deleted file mode 100644 index a0c2be15cf..0000000000 --- a/collects/math/private/matrix/matrix-sequences.rkt +++ /dev/null @@ -1,340 +0,0 @@ -#lang racket - -(provide in-row - in-column) - -(require math/array - "matrix-types.rkt" - "matrix-basic.rkt" - "matrix-constructors.rkt" - ) - -(define (in-row/proc M r) - (define-values (m n) (matrix-shape M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ (j) (matrix-ref M r j)) - ; next-pos - (λ (j) (+ j 1)) - ; initial-pos - 0 - ; continue-with-pos? - (λ (j) (< j n)) - #f #f )))) - -; (in-row M i] -; Returns a sequence of all elements of row i, -; that is xi0, xi1, xi2, ... -(define-sequence-syntax in-row - (λ () #'in-row/proc) - (λ (stx) - (syntax-case stx () - [[(x) (_ M-expr r-expr)] - #'((x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r)) - (unless (and (integer? r) (and (<= 0 r ) (< r n))) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr r-expr)] - #'((i x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-row "expected (in-row )" #'clause #'clause)]))) - -; (in-column M j] -; Returns a sequence of all elements of column j, -; that is x0j, x1j, x2j, ... - - -(define (in-column/proc M s) - (define-values (m n) (matrix-shape M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ (i) (matrix-ref M i s)) - ; next-pos - (λ (i) (+ i 1)) - ; initial-pos - 0 - ; continue-with-pos? - (λ (i) (< i m)) - #f #f )))) - -(define-sequence-syntax in-column - (λ () #'in-column/proc) - (λ (stx) - (syntax-case stx () - [[(x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-row "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-col "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-column "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-column "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-column "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-column "expected (in-column )" #'clause #'clause)]))) - -(define-syntax (for/matrix-sum: stx) - (syntax-case stx () - [(_ : type (for:-clause ...) . defs+exprs) - (syntax/loc stx - (let () - (define: sum : (U False (Matrix Number)) #f) - (for: (for:-clause ...) - (define a (let () . defs+exprs)) - (set! sum (if sum (array+ (assert sum) a) a))) - (assert sum)))])) -#| -;;; -;;; SEQUENCES -;;; - -(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number)) -(define (in-row/proc M r) - (define-values (m n) (matrix-shape M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ: ([j : Index]) (matrix-ref M r j)) - ; next-pos - (λ: ([j : Index]) (assert (+ j 1) index?)) - ; initial-pos - 0 - ; continue-with-pos? - (λ: ([j : Index ]) (< j n)) - #f #f)))) - -(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number)) -(define (in-column/proc M s) - (define-values (m n) (matrix-shape M)) - (make-do-sequence - (λ () - (values - ; pos->element - (λ: ([i : Index]) (matrix-ref M i s)) - ; next-pos - (λ: ([i : Index]) (assert (+ i 1) index?)) - ; initial-pos - 0 - ; continue-with-pos? - (λ: ([i : Index]) (< i m)) - #f #f)))) - -; (in-row M i] -; Returns a sequence of all elements of row i, -; that is xi0, xi1, xi2, ... -(define-sequence-syntax in-row - (λ () #'in-row/proc) - (λ (stx) - (syntax-case stx () - [[(x) (_ M-expr r-expr)] - #'((x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r)) - (unless (and (integer? r) (and (<= 0 r ) (< r n))) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr r-expr)] - #'((i x) - (:do-in - ([(M r n d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 r-expr rd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? r) - (raise-type-error 'in-row "expected row number, got ~a" r))) - ([j 0]) - (< j n) - ([(x) (vector-ref d (+ (* r n) j))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-row "expected (in-row )" #'clause #'clause)]))) - -; (in-column M j] -; Returns a sequence of all elements of column j, -; that is x0j, x1j, x2j, ... - -(define-sequence-syntax in-column - (λ () #'in-column/proc) - (λ (stx) - (syntax-case stx () - ; M-expr evaluates to column - [[(x) (_ M-expr)] - #'((x) - (:do-in - ([(M n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - ([j 0]) - (< j n) - ([(x) (vector-ref d j)]) - #true - #true - [(+ j 1)]))] - ; M-expr evaluats to matrix, s-expr to the column index - [[(x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-row "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-row "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-col "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))]) - #true - #true - [(+ j 1)]))] - [[(i x) (_ M-expr s-expr)] - #'((x) - (:do-in - ([(M s n m d) - (let ([M1 M-expr]) - (define-values (rd cd) (matrix-shape M1)) - (values M1 s-expr rd cd - (mutable-array-data - (array->mutable-array M1))))]) - (begin - (unless (matrix? M) - (raise-type-error 'in-column "expected matrix, got ~a" M)) - (unless (integer? s) - (raise-type-error 'in-column "expected col number, got ~a" s)) - (unless (and (integer? s) (and (<= 0 s ) (< s m))) - (raise-type-error 'in-column "expected col number, got ~a" s))) - ([j 0]) - (< j m) - ([(x) (vector-ref d (+ (* j n) s))] - [(i) j]) - #true - #true - [(+ j 1)]))] - [[_ clause] (raise-syntax-error - 'in-column "expected (in-column )" #'clause #'clause)]))) -|# -#; -(module* test #f - (require rackunit) - ; "matrix-sequences.rkt" - (check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x) - '(3 4)) - (check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) - (list i x)) - '((0 3) (1 4))) - (check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x) - '(2 4)) - (check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) - (list i x)) - '((0 2) (1 4)))) diff --git a/collects/math/private/matrix/matrix-syntax.rkt b/collects/math/private/matrix/matrix-syntax.rkt new file mode 100644 index 0000000000..62ffc52d8c --- /dev/null +++ b/collects/math/private/matrix/matrix-syntax.rkt @@ -0,0 +1,35 @@ +#lang racket/base + +(require (for-syntax racket/base + syntax/parse) + (only-in typed/racket/base :) + math/array) + +(provide matrix row-matrix col-matrix) + +(define-syntax (matrix stx) + (syntax-parse stx #:literals (:) + [(_ [[x0 xs0 ...] [x xs ...] ...]) + (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))] + [(_ [[x0 xs0 ...] [x xs ...] ...] : T) + (syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))] + [(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T))) + (raise-syntax-error 'matrix "given empty row" stx #'r)] + [(_ (~and [] c) (~optional (~seq : T))) + (raise-syntax-error 'matrix "given empty matrix" stx #'c)] + [(_ x (~optional (~seq : T))) + (raise-syntax-error 'matrix "expected two-dimensional data" stx)])) + +(define-syntax (row-matrix stx) + (syntax-parse stx #:literals (:) + [(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))] + [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))] + [(_ (~and [] r) (~optional (~seq : T))) + (raise-syntax-error 'row-matrix "given empty row" stx #'r)])) + +(define-syntax (col-matrix stx) + (syntax-parse stx #:literals (:) + [(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))] + [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))] + [(_ (~and [] c) (~optional (~seq : T))) + (raise-syntax-error 'row-matrix "given empty column" stx #'c)])) diff --git a/collects/math/private/matrix/matrix-types.rkt b/collects/math/private/matrix/matrix-types.rkt index ae1950ff04..74925601fa 100644 --- a/collects/math/private/matrix/matrix-types.rkt +++ b/collects/math/private/matrix/matrix-types.rkt @@ -27,44 +27,53 @@ (define-type (Column-Matrix A) (Matrix A)) ; a column vector represented as a matrix -(define matrix? - (plambda: (A) ([arr : (Array A)]) - (and (> (array-size arr) 0) - (= (array-dims arr) 2)))) +(: matrix? (All (A) ((Array A) -> Boolean))) +(define (matrix? arr) + (and (> (array-size arr) 0) + (= (array-dims arr) 2))) -(define square-matrix? - (plambda: (A) ([arr : (Array A)]) - (and (matrix? arr) - (let ([sh (array-shape arr)]) - (= (vector-ref sh 0) (vector-ref sh 1)))))) +(: square-matrix? (All (A) ((Array A) -> Boolean))) +(define (square-matrix? arr) + (define ds (array-shape arr)) + (and (= (vector-length ds) 2) + (let ([d0 (unsafe-vector-ref ds 0)] + [d1 (unsafe-vector-ref ds 1)]) + (and (> d0 0) (> d1 0) (= d0 d1))))) -(define row-matrix? - (plambda: (A) ([arr : (Array A)]) - (and (matrix? arr) - (= (vector-ref (array-shape arr) 0) 1)))) +(: row-matrix? (All (A) ((Array A) -> Boolean))) +(define (row-matrix? arr) + (define ds (array-shape arr)) + (and (= (vector-length ds) 2) + (= (unsafe-vector-ref ds 0) 1) + (> (unsafe-vector-ref ds 1) 0))) -(define col-matrix? - (plambda: (A) ([arr : (Array A)]) - (and (matrix? arr) - (= (vector-ref (array-shape arr) 1) 1)))) +(: col-matrix? (All (A) ((Array A) -> Boolean))) +(define (col-matrix? arr) + (define ds (array-shape arr)) + (and (= (vector-length ds) 2) + (> (unsafe-vector-ref ds 0) 0) + (= (unsafe-vector-ref ds 1) 1))) -(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index))) +(: matrix-shape (All (A) ((Array A) -> (Values Index Index)))) (define (matrix-shape a) - (cond [(matrix? a) (define sh (array-shape a)) - (values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))] - [else (raise-argument-error 'matrix-shape "matrix?" a)])) + (define ds (array-shape a)) + (if (and (> (array-size a) 0) + (= (vector-length ds) 2)) + (values (unsafe-vector-ref ds 0) + (unsafe-vector-ref ds 1)) + (raise-argument-error 'matrix-shape "matrix?" a))) -(: square-matrix-size (All (A) ((Matrix A) -> Index))) +(: square-matrix-size (All (A) ((Array A) -> Index))) (define (square-matrix-size arr) (cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)] [else (raise-argument-error 'square-matrix-size "square-matrix?" arr)])) -(: matrix-num-rows (All (A) ((Matrix A) -> Index))) +(: matrix-num-rows (All (A) ((Array A) -> Index))) (define (matrix-num-rows a) (cond [(matrix? a) (vector-ref (array-shape a) 0)] [else (raise-argument-error 'matrix-col-length "matrix?" a)])) -(: matrix-num-cols (All (A) ((Matrix A) -> Index))) +(: matrix-num-cols (All (A) ((Array A) -> Index))) (define (matrix-num-cols a) (cond [(matrix? a) (vector-ref (array-shape a) 1)] [else (raise-argument-error 'matrix-row-length "matrix?" a)])) diff --git a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt index 9dceed073f..f7cd81d5b9 100644 --- a/collects/math/private/matrix/untyped-matrix-arithmetic.rkt +++ b/collects/math/private/matrix/untyped-matrix-arithmetic.rkt @@ -73,8 +73,6 @@ ) ; module -(require 'syntax-defs) - (module untyped-defs typed/racket/base (require math/array (submod ".." syntax-defs) @@ -102,4 +100,5 @@ ) ; module -(require 'untyped-defs) +(require 'syntax-defs + 'untyped-defs) diff --git a/collects/math/private/matrix/utils.rkt b/collects/math/private/matrix/utils.rkt index fefced18be..016e1ff0a7 100644 --- a/collects/math/private/matrix/utils.rkt +++ b/collects/math/private/matrix/utils.rkt @@ -15,7 +15,7 @@ (define (matrix-shapes name arr . brrs) (define-values (m n) (matrix-shape arr)) (unless (andmap (λ: ([brr : (Matrix Any)]) - (match-define (vector bm bn) (array-shape brr)) + (define-values (bm bn) (matrix-shape brr)) (and (= bm m) (= bn n))) brrs) (error name diff --git a/collects/math/tests/matrix-tests.rkt b/collects/math/tests/matrix-tests.rkt index 24b8eb1a62..96dc77d514 100644 --- a/collects/math/tests/matrix-tests.rkt +++ b/collects/math/tests/matrix-tests.rkt @@ -7,192 +7,673 @@ "../private/matrix/matrix-column.rkt" "test-utils.rkt") -(: random-matrix (Integer Integer Integer -> (Matrix Integer))) -;; Generates a random matrix with integer elements < k. Useful to test properties. -(define (random-matrix m n k) +(: random-matrix (case-> (Integer Integer -> (Matrix Integer)) + (Integer Integer Integer -> (Matrix Integer)))) +;; Generates a random matrix with Natural elements < k. Useful to test properties. +(define (random-matrix m n [k 100]) (array-strict (build-array (vector m n) (λ (_) (random k))))) +(define nonmatrices + (list (make-array #() 0) + (make-array #(1) 0) + (make-array #(1 0) 0) + (make-array #(0 1) 0) + (make-array #(0 0) 0) + (make-array #(1 1 1) 0))) + ;; =================================================================================================== -;; Types +;; Literal syntax + +(check-equal? (matrix [[1]]) + (array #[#[1]])) + +(check-equal? (matrix [[1 2 3 4]]) + (array #[#[1 2 3 4]])) + +(check-equal? (matrix [[1 2] [3 4]]) + (array #[#[1 2] #[3 4]])) + +(check-equal? (matrix [[1] [2] [3] [4]]) + (array #[#[1] #[2] #[3] #[4]])) + +(check-equal? (row-matrix [1 2 3 4]) + (matrix [[1 2 3 4]])) + +(check-equal? (col-matrix [1 2 3 4]) + (matrix [[1] [2] [3] [4]])) + +;; =================================================================================================== +;; Predicates (check-true (matrix? (array #[#[1]]))) (check-false (matrix? (array #[1]))) (check-false (matrix? (array 1))) (check-false (matrix? (array #[]))) +(for: ([a (in-list nonmatrices)]) + (check-false (matrix? a))) + +(check-true (square-matrix? (matrix [[1]]))) +(check-true (square-matrix? (matrix [[1 1] [1 1]]))) +(check-false (square-matrix? (matrix [[1 2]]))) +(check-false (square-matrix? (matrix [[1] [2]]))) +(for: ([a (in-list nonmatrices)]) + (check-false (square-matrix? a))) (check-true (row-matrix? (matrix [[1 2 3 4]]))) +(check-true (row-matrix? (matrix [[1]]))) (check-false (row-matrix? (matrix [[1] [2] [3] [4]]))) +(for: ([a (in-list nonmatrices)]) + (check-false (row-matrix? a))) (check-true (col-matrix? (matrix [[1] [2] [3] [4]]))) +(check-true (col-matrix? (matrix [[1]]))) (check-false (col-matrix? (matrix [[1 2 3 4]]))) +(check-false (col-matrix? (array #[1]))) +(check-false (col-matrix? (array 1))) +(check-false (col-matrix? (array #[]))) +(for: ([a (in-list nonmatrices)]) + (check-false (col-matrix? a))) + +(check-true (matrix-zero? (make-matrix 4 3 0))) +(check-true (matrix-zero? (make-matrix 4 3 0.0))) +(check-true (matrix-zero? (make-matrix 4 3 0+0.0i))) +(check-false (matrix-zero? (row-matrix [0 0 0 0 1]))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-zero? a)))) ;; =================================================================================================== -;; Matrix multiplication +;; Accessors -(check-equal? (matrix* (identity-matrix 2) - (matrix [[1 20] [300 4000]])) - (matrix [[1 20] [300 4000]])) +;; matrix-shape -(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]]) - (matrix [[1 2 3] [4 5 6] [7 8 9]])) - (matrix [[30 36 42] [66 81 96] [102 126 150]])) +(check-equal? (let-values ([(m n) (matrix-shape (matrix [[1 2 3] [4 5 6]]))]) + (list m n)) + (list 2 3)) -(let ([m0 (random-matrix 4 5 100)] - [m1 (random-matrix 5 2 100)] - [m2 (random-matrix 2 10 100)]) - (check-equal? (matrix* (matrix* m0 m1) m2) - (matrix* m0 (matrix* m1 m2)))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (let-values ([(m n) (matrix-shape a)]) + (void))))) + +;; square-matrix-size + +(check-equal? (square-matrix-size (matrix [[1 2] [3 4]])) + 2) + +(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1 2]])))) +(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1] [2]])))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (square-matrix-size a)))) + +;; matrix-num-rows + +(check-equal? (matrix-num-rows (matrix [[1 2 3] [4 5 6]])) + 2) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-num-rows a)))) + +;; matrix-num-cols + +(check-equal? (matrix-num-cols (matrix [[1 2 3] [4 5 6]])) + 3) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-num-cols a)))) ;; =================================================================================================== -;; Construction +;; Constructors + +;; identity-matrix + +(check-equal? (identity-matrix 1) (matrix [[1]])) +(check-equal? (identity-matrix 2) (matrix [[1 0] [0 1]])) +(check-equal? (identity-matrix 3) (matrix [[1 0 0] [0 1 0] [0 0 1]])) +(check-exn exn:fail:contract? (λ () (identity-matrix 0))) + +;; make-matrix + +(check-equal? (make-matrix 1 1 4) (matrix [[4]])) +(check-equal? (make-matrix 2 2 3) (matrix [[3 3] [3 3]])) +(check-exn exn:fail:contract? (λ () (make-matrix 1 0 4))) +(check-exn exn:fail:contract? (λ () (make-matrix 0 1 4))) + +;; build-matrix + +(check-equal? (build-matrix 4 4 (λ: ([i : Index] [j : Index]) + (+ i j))) + (build-array #(4 4) (λ: ([js : Indexes]) + (+ (vector-ref js 0) (vector-ref js 1))))) +(check-exn exn:fail:contract? (λ () (build-matrix 1 0 (λ: ([i : Index] [j : Index]) (+ i j))))) +(check-exn exn:fail:contract? (λ () (build-matrix 0 1 (λ: ([i : Index] [j : Index]) (+ i j))))) + +;; diagonal-matrix + +(check-equal? (diagonal-matrix '(1 2 3 4)) + (matrix [[1 0 0 0] + [0 2 0 0] + [0 0 3 0] + [0 0 0 4]])) + +(check-exn exn:fail:contract? (λ () (diagonal-matrix '()))) + +;; block-diagonal-matrix + +(let ([m (random-matrix 4 4 100)]) + (check-equal? (block-diagonal-matrix (list m)) + m)) (check-equal? (block-diagonal-matrix - (list - (matrix [[1 2] [3 4]]) - (matrix [[1 2 3] [4 5 6]]) - (matrix [[1] [3] [5]]))) - (matrix - [[1 2 0 0 0 0] - [3 4 0 0 0 0] - [0 0 1 2 3 0] - [0 0 4 5 6 0] - [0 0 0 0 0 1] - [0 0 0 0 0 3] - [0 0 0 0 0 5]])) + (list (matrix [[1 2] [3 4]]) + (matrix [[1 2 3] [4 5 6]]) + (matrix [[1] [3] [5]]) + (matrix [[2 4 6]]))) + (matrix [[1 2 0 0 0 0 0 0 0] + [3 4 0 0 0 0 0 0 0] + [0 0 1 2 3 0 0 0 0] + [0 0 4 5 6 0 0 0 0] + [0 0 0 0 0 1 0 0 0] + [0 0 0 0 0 3 0 0 0] + [0 0 0 0 0 5 0 0 0] + [0 0 0 0 0 0 2 4 6]])) + +(check-equal? + (block-diagonal-matrix (map (λ: ([i : Integer]) (matrix [[i]])) '(1 2 3 4))) + (diagonal-matrix '(1 2 3 4))) + +(check-exn exn:fail:contract? (λ () (block-diagonal-matrix '()))) + +;; Vandermonde matrix + +(check-equal? (vandermonde-matrix '(10) 1) + (matrix [[1]])) +(check-equal? (vandermonde-matrix '(10) 4) + (matrix [[1 10 100 1000]])) +(check-equal? (vandermonde-matrix '(1 2 3 4) 3) + (matrix [[1 1 1] [1 2 4] [1 3 9] [1 4 16]])) +(check-exn exn:fail:contract? (λ () (vandermonde-matrix '() 1))) +(check-exn exn:fail:contract? (λ () (vandermonde-matrix '(1) 0))) ;; =================================================================================================== +;; Flat conversion -(begin - (begin "matrix-types.rkt" - (list - 'matrix? - (matrix? (list*->array '[[1 2] [3 4]] real? )) - (not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? )))) - (list - 'square-matrix? - (square-matrix? (list*->array '[[1 2] [3 4]] real? )) - (not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? )))) - (list - 'square-matrix-size - (= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?)))) - (list - 'matrix= - (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? )) - #;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? )))) - (list - 'matrix-shape - (let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))]) - (equal? (list m n) '(2 3))))) - - (begin "matrix-constructors.rkt" - (list - 'identity-matrix - (equal? (array->list* (identity-matrix 1)) '[[1]]) - (equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]]) - (equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]])) - (list - 'const-matrix - (equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0))) - (equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.)))) - (list - 'matrix->list - (equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4))) - (equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.)))) - (list - 'matrix->vector - (equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4)))) - '#(#(1 2) #(3 4))) - (equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.)))) - '#(#(1. 2.) #(3. 4.)))) - (list - 'matrix-row - (equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]])) - (equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]])) - (equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]]))) - (list - 'matrix-col - (equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]])) - (equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]])) - (equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]]))) - (list - 'submatrix - (equal? (submatrix (identity-matrix 3) - (in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]])) - (equal? (submatrix (identity-matrix 3) - (in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]])))) - - (begin - "matrix-pointwise.rkt" - (let () - (define A (list*->matrix '[[1 2] [3 4]])) - (define ~A (list*->matrix '[[-1 -2] [-3 -4]])) - (define B (list*->matrix '[[5 6] [7 8]])) - (define A+B (list*->matrix '[[6 8] [10 12]])) - (define A-B (list*->matrix '[[-4 -4] [-4 -4]])) - (list 'matrix+ (equal? (matrix+ A B) A+B)) - (list 'matrix- - (equal? (matrix- A B) A-B) - (equal? (matrix- A) ~A)))) - - (begin - "matrix-expt.rkt" - (let () - (define A (list*->matrix '[[1 2] [3 4]])) - (list - 'matrix-expt - (equal? (matrix-expt A 0) (identity-matrix 2)) - (equal? (matrix-expt A 1) A) - (equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]])) - (equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]])) - (equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]]))))) +(check-equal? (list->matrix 1 3 '(1 2 3)) (row-matrix [1 2 3])) +(check-equal? (list->matrix 3 1 '(1 2 3)) (col-matrix [1 2 3])) +(check-exn exn:fail:contract? (λ () (list->matrix 0 1 '()))) +(check-exn exn:fail:contract? (λ () (list->matrix 1 0 '()))) +(check-exn exn:fail:contract? (λ () (list->matrix 1 1 '(1 2)))) + +(check-equal? (vector->matrix 1 3 #(1 2 3)) (row-matrix [1 2 3])) +(check-equal? (vector->matrix 3 1 #(1 2 3)) (col-matrix [1 2 3])) +(check-exn exn:fail:contract? (λ () (vector->matrix 0 1 #()))) +(check-exn exn:fail:contract? (λ () (vector->matrix 1 0 #()))) +(check-exn exn:fail:contract? (λ () (vector->matrix 1 1 #(1 2)))) + +(check-equal? (->row-matrix '(1 2 3)) (row-matrix [1 2 3])) +(check-equal? (->row-matrix #(1 2 3)) (row-matrix [1 2 3])) +(check-equal? (->row-matrix (row-matrix [1 2 3])) (row-matrix [1 2 3])) +(check-equal? (->row-matrix (col-matrix [1 2 3])) (row-matrix [1 2 3])) +(check-equal? (->row-matrix (make-array #() 1)) (row-matrix [1])) +(check-equal? (->row-matrix (make-array #(3) 1)) (row-matrix [1 1 1])) +(check-equal? (->row-matrix (make-array #(1 3 1) 1)) (row-matrix [1 1 1])) +(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(2 3 1) 1)))) +(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(1 3 2) 1)))) +(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(0 3) 1)))) +(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(3 0) 1)))) + +(check-equal? (->col-matrix '(1 2 3)) (col-matrix [1 2 3])) +(check-equal? (->col-matrix #(1 2 3)) (col-matrix [1 2 3])) +(check-equal? (->col-matrix (col-matrix [1 2 3])) (col-matrix [1 2 3])) +(check-equal? (->col-matrix (row-matrix [1 2 3])) (col-matrix [1 2 3])) +(check-equal? (->col-matrix (make-array #() 1)) (col-matrix [1])) +(check-equal? (->col-matrix (make-array #(3) 1)) (col-matrix [1 1 1])) +(check-equal? (->col-matrix (make-array #(1 3 1) 1)) (col-matrix [1 1 1])) +(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(2 3 1) 1)))) +(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(1 3 2) 1)))) +(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(0 3) 1)))) +(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(3 0) 1)))) + +(check-equal? (matrix->list (matrix [[1 2 3] [4 5 6]])) '(1 2 3 4 5 6)) +(check-equal? (matrix->list (row-matrix [1 2 3])) '(1 2 3)) +(check-equal? (matrix->list (col-matrix [1 2 3])) '(1 2 3)) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix->list a)))) + +(check-equal? (matrix->vector (matrix [[1 2 3] [4 5 6]])) #(1 2 3 4 5 6)) +(check-equal? (matrix->vector (row-matrix [1 2 3])) #(1 2 3)) +(check-equal? (matrix->vector (col-matrix [1 2 3])) #(1 2 3)) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix->vector a)))) + +;; =================================================================================================== +;; Nested conversion + +(check-equal? (list*->matrix '((1 2 3) (4 5 6))) (matrix [[1 2 3] [4 5 6]])) +(check-exn exn:fail:contract? (λ () (list*->matrix '((1 2 3) (4 5))))) +(check-exn exn:fail:contract? (λ () (list*->matrix '(() () ())))) +(check-exn exn:fail:contract? (λ () (list*->matrix '()))) + +(check-equal? ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5 6))) (matrix [[1 2 3] [4 5 6]])) +(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5))))) +(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#() #() #())))) +(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #()))) + +(check-equal? (matrix->list* (matrix [[1 2 3] [4 5 6]])) '((1 2 3) (4 5 6))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix->list* a)))) + +(check-equal? (matrix->vector* (matrix [[1 2 3] [4 5 6]])) #(#(1 2 3) #(4 5 6))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix->vector* a)))) + +;; =================================================================================================== +;; Equality + +(check-true (matrix= (matrix [[1 2 3] + [4 5 6]]) + (matrix [[1.0 2.0 3.0] + [4.0 5.0 6.0]]))) + +(check-true (matrix= (matrix [[1 2 3] + [4 5 6]]) + (matrix [[1.0 2.0 3.0] + [4.0 5.0 6.0]]) + (matrix [[1.0+0.0i 2.0+0.0i 3.0+0.0i] + [4.0+0.0i 5.0+0.0i 6.0+0.0i]]))) + +(check-false (matrix= (matrix [[1 2 3] [4 5 6]]) + (matrix [[1 2 3] [4 5 7]]))) + +(check-false (matrix= (matrix [[0 2 3] [4 5 6]]) + (matrix [[1 2 3] [4 5 7]]))) + +(check-false (matrix= (matrix [[1 2 3] [4 5 6]]) + (matrix [[1 4] [2 5] [3 6]]))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix= a (matrix [[1]])))) + (check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) a))) + (check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) (matrix [[1]]) a)))) + +;; =================================================================================================== +;; Pointwise operations + +(define-syntax-rule (test-matrix-map (matrix-map ...) (array-map ...)) + (begin + (for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-map ... a))) + (check-exn exn:fail:contract? (λ () (matrix-map ... (matrix [[1]]) a)))) + + (for*: ([m '(2 3 4)] + [n '(2 3 4)]) + (define a0 (random-matrix m n)) + (define a1 (random-matrix m n)) + (define a2 (random-matrix m n)) + (check-equal? (matrix-map ... a0) + (array-map ... a0)) + (check-equal? (matrix-map ... a0 a1) + (array-map ... a0 a1)) + (check-equal? (matrix-map ... a0 a1 a2) + (array-map ... a0 a1 a2)) + ;; Don't know why this (void) is necessary, but TR complains without it + (void)))) + +(test-matrix-map (matrix-map -) (array-map -)) +(test-matrix-map ((values matrix-map) -) (array-map -)) + +(test-matrix-map (matrix+) (array+)) +(test-matrix-map ((values matrix+)) (array+)) + +(test-matrix-map (matrix-) (array-)) +(test-matrix-map ((values matrix-)) (array-)) + +(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]]))) + (matrix [[1 2 3] [4 5 6]])) +(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]]) + (matrix [[0 1 2] [3 4 5]]))) + (matrix+ (matrix [[1 2 3] [4 5 6]]) + (matrix [[0 1 2] [3 4 5]]))) +(check-exn exn:fail:contract? (λ () (matrix-sum '()))) + +(check-equal? (matrix-scale (matrix [[1 2 3] [4 5 6]]) 10) + (matrix [[10 20 30] [40 50 60]])) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-scale a 0)))) + +;; =================================================================================================== +;; Multiplication + +(define-syntax-rule (test-matrix* matrix*) + (begin + (for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix* a (matrix [[1]]))))) + + (check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]]) + (matrix [[1 2 3] [4 5 6] [7 8 9]])) + (matrix [[30 36 42] [66 81 96] [102 126 150]])) + + (check-equal? (matrix* (row-matrix [1 2 3 4]) + (col-matrix [1 2 3 4])) + (matrix [[30]])) + + (check-equal? (matrix* (col-matrix [1 2 3 4]) + (row-matrix [1 2 3 4])) + (matrix [[1 2 3 4] + [2 4 6 8] + [3 6 9 12] + [4 8 12 16]])) + + (check-equal? (matrix* (matrix [[3]]) (matrix [[7]])) + (matrix [[21]])) + + ;; Left/right identity + (let ([m (random-matrix 2 2)]) + (check-equal? (matrix* (identity-matrix 2) m) + m) + (check-equal? (matrix* m (identity-matrix 2)) + m)) + + ;; Shape + (let ([m0 (random-matrix 4 5)] + [m1 (random-matrix 5 2)] + [m2 (random-matrix 2 10)]) + (check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1))]) + (list m n)) + (list 4 2)) + (check-equal? (let-values ([(m n) (matrix-shape (matrix* m1 m2))]) + (list m n)) + (list 5 10)) + (check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1 m2))]) + (list m n)) + (list 4 10))) + + (check-exn exn:fail? (λ () (matrix* (random-matrix 1 2) (random-matrix 3 2)))) + + ;; Associativity + (let ([m0 (random-matrix 4 5)] + [m1 (random-matrix 5 2)] + [m2 (random-matrix 2 10)]) + (check-equal? (matrix* m0 m1 m2) + (matrix* (matrix* m0 m1) m2)) + (check-equal? (matrix* (matrix* m0 m1) m2) + (matrix* m0 (matrix* m1 m2)))) + )) + +(test-matrix* matrix*) +;; `matrix*' is an inlining macro, so we need to check the function version as well +(test-matrix* (values matrix*)) + +;; =================================================================================================== +;; Exponentiation + +(let ([A (matrix [[1 2] [3 4]])]) + (check-equal? (matrix-expt A 0) (identity-matrix 2)) + (check-equal? (matrix-expt A 1) A) + (check-equal? (matrix-expt A 2) (matrix [[7 10] [15 22]])) + (check-equal? (matrix-expt A 3) (matrix [[37 54] [81 118]])) + (check-equal? (matrix-expt A 8) (matrix [[165751 241570] [362355 528106]]))) + +(check-equal? (matrix-expt (matrix [[2]]) 10) (matrix [[(expt 2 10)]])) + +(check-exn exn:fail:contract? (λ () (matrix-expt (row-matrix [1 2 3]) 0))) +(check-exn exn:fail:contract? (λ () (matrix-expt (col-matrix [1 2 3]) 0))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-expt a 0)))) + +;; =================================================================================================== +;; Comprehensions + +;; for/matrix and friends are defined in terms of for/array and friends, so we only need to test that +;; it works for one case each, and that they properly raise exceptions when given zero-length axes + +(check-equal? + (for/matrix 2 2 ([i (in-range 4)]) i) + (matrix [[0 1] [2 3]])) + +#;; TR can't type this, but it's defined using exactly the same wrapper as `for/matrix' +(check-equal? + (for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j)) + (matrix [[0 1] [1 2]])) + +(check-equal? + (for/matrix: 2 2 ([i (in-range 4)]) i) + (matrix [[0 1] [2 3]])) + +(check-equal? + (for*/matrix: 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j)) + (matrix [[0 1] [1 2]])) + +(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0))) + +(check-exn exn:fail:contract? (λ () (for/matrix: 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for/matrix: 0 2 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix: 2 0 () 0))) +(check-exn exn:fail:contract? (λ () (for*/matrix: 0 2 () 0))) + +;; =================================================================================================== +;; Extraction + +;; matrix-ref + +(let ([a (matrix [[10 11] [12 13]])]) + (check-equal? (matrix-ref a 0 0) 10) + (check-equal? (matrix-ref a 0 1) 11) + (check-equal? (matrix-ref a 1 0) 12) + (check-equal? (matrix-ref a 1 1) 13) + (check-exn exn:fail? (λ () (matrix-ref a 2 0))) + (check-exn exn:fail? (λ () (matrix-ref a 0 2))) + (check-exn exn:fail? (λ () (matrix-ref a -1 0))) + (check-exn exn:fail? (λ () (matrix-ref a 0 -1)))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-ref a 0 0)))) + +;; matrix-diagonal + +(check-equal? (matrix-diagonal (diagonal-matrix '(1 2 3 4))) + (array #[1 2 3 4])) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-diagonal a)))) + +;; submatrix + +(check-equal? (submatrix (identity-matrix 8) (:: 2 4) (:: 2 4)) + (identity-matrix 2)) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (submatrix a '(0) '(0))))) + +;; matrix-row + +(let ([a (matrix [[1 2 3] [4 5 6]])]) + (check-equal? (matrix-row a 0) (row-matrix [1 2 3])) + (check-equal? (matrix-row a 1) (row-matrix [4 5 6])) + (check-exn exn:fail? (λ () (matrix-row a -1))) + (check-exn exn:fail? (λ () (matrix-row a 2)))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-row a 0)))) + +;; matrix-col + +(let ([a (matrix [[1 2 3] [4 5 6]])]) + (check-equal? (matrix-col a 0) (col-matrix [1 4])) + (check-equal? (matrix-col a 1) (col-matrix [2 5])) + (check-equal? (matrix-col a 2) (col-matrix [3 6])) + (check-exn exn:fail? (λ () (matrix-col a -1))) + (check-exn exn:fail? (λ () (matrix-col a 3)))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-col a 0)))) + +;; matrix-rows + +(check-equal? (matrix-rows (matrix [[1 2 3] [4 5 6]])) + (list (row-matrix [1 2 3]) + (row-matrix [4 5 6]))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-rows a)))) + +;; matrix-cols + +(check-equal? (matrix-cols (matrix [[1 2 3] [4 5 6]])) + (list (col-matrix [1 4]) + (col-matrix [2 5]) + (col-matrix [3 6]))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-cols a)))) + +;; =================================================================================================== +;; Embiggenment (it's a perfectly cromulent word) + +;; matrix-augment + +(let ([a (random-matrix 3 5)]) + (check-equal? (matrix-augment (list a)) a) + (check-equal? (matrix-augment (matrix-cols a)) a)) + +(check-equal? (matrix-augment (list (col-matrix [1 2 3]) (col-matrix [4 5 6]))) + (matrix [[1 4] [2 5] [3 6]])) + +(check-equal? (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3 6]))) + (matrix [[1 2 3] [4 5 6]])) + +(check-exn exn:fail? (λ () (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3]))))) +(check-exn exn:fail:contract? (λ () (matrix-augment '()))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-augment (list a)))) + (check-exn exn:fail:contract? (λ () (matrix-augment (list (matrix [[1]]) a))))) + +;; matrix-stack + +(let ([a (random-matrix 5 3)]) + (check-equal? (matrix-stack (list a)) a) + (check-equal? (matrix-stack (matrix-rows a)) a)) + +(check-equal? (matrix-stack (list (row-matrix [1 2 3]) (row-matrix [4 5 6]))) + (matrix [[1 2 3] [4 5 6]])) + +(check-equal? (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8 9]))) + (matrix [[1 2 3] [4 5 6] [7 8 9]])) + +(check-exn exn:fail? (λ () (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8]))))) +(check-exn exn:fail:contract? (λ () (matrix-stack '()))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-stack (list a)))) + (check-exn exn:fail:contract? (λ () (matrix-stack (list (matrix [[1]]) a))))) + +;; =================================================================================================== +;; Inner product space + +;; matrix-norm + +(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]])) + (sqrt (+ (* 1 1) (* 2 2) (* 3 3) (* 4 4) (* 5 5) (* 6 6)))) + +;; Default norm is Frobenius norm +(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]])) + (matrix-norm (matrix [[1 2 3] [4 5 6]]) 2)) + +;; This shouldn't overflow (so we check against `flhypot', which also shouldn't overflow) +(check-equal? (matrix-norm (matrix [[1e200 1e199]])) + (flhypot 1e200 1e199)) + +;; Taxicab (Manhattan) norm +(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1) + (+ 1 2 3 4 5 6)) + +;; Infinity (maximum) norm +(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) +inf.0) + (max 1 2 3 4 5 6)) + +;; The actual norm is indistinguishable from floating-point 6 +(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1000) + 6.0) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-norm a 1))) + (check-exn exn:fail:contract? (λ () (matrix-norm a))) + (check-exn exn:fail:contract? (λ () (matrix-norm a 5))) + (check-exn exn:fail:contract? (λ () (matrix-norm a +inf.0)))) + +(check-equal? (matrix-norm (row-matrix [1+1i])) + (sqrt 2)) + +(check-equal? (matrix-norm (row-matrix [1+1i 2+2i 3+3i])) + (matrix-norm (row-matrix [(magnitude 1+1i) (magnitude 2+2i) (magnitude 3+3i)]))) + +;; matrix-dot (induces the Frobenius norm) + +(check-equal? (matrix-dot (matrix [[1 -2 3] [-4 5 -6]]) + (matrix [[-1 2 -3] [4 -5 6]])) + (+ (* 1 -1) (* -2 2) (* 3 -3) (* -4 4) (* 5 -5) (* -6 6))) + +(check-equal? (matrix-dot (row-matrix [1 2 3]) + (row-matrix [0+4i 0-5i 0+6i])) + (+ (* 1 0-4i) (* 2 0+5i) (* 3 0-6i))) + +(check-exn exn:fail? (λ () (matrix-dot (random-matrix 1 3) (random-matrix 3 1)))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-dot a (matrix [[1]])))) + (check-exn exn:fail:contract? (λ () (matrix-dot (matrix [[1]]) a)))) + +;; =================================================================================================== +;; Simple operators + +;; matrix-transpose + +(check-equal? (matrix-transpose (matrix [[1 2 3] [4 5 6]])) + (matrix [[1 4] [2 5] [3 6]])) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-transpose a)))) + +;; matrix-conjugate + +(check-equal? (matrix-conjugate (matrix [[1+i 2-i] [3+i 4-i]])) + (matrix [[1-i 2+i] [3-i 4+i]])) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-conjugate a)))) + +;; matrix-hermitian + +(let ([a (array-make-rectangular (random-matrix 5 6) + (random-matrix 5 6))]) + (check-equal? (matrix-hermitian a) + (matrix-conjugate (matrix-transpose a))) + (check-equal? (matrix-hermitian a) + (matrix-transpose (matrix-conjugate a)))) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-hermitian a)))) + +;; matrix-trace + +(check-equal? (matrix-trace (matrix [[1 2 3] [4 5 6] [7 8 9]])) + (+ 1 5 9)) + +(check-exn exn:fail:contract? (λ () (matrix-trace (row-matrix [1 2 3])))) +(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3])))) +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (matrix-trace a)))) + +;; =================================================================================================== +;; Tests not yet converted to rackunit + +(begin (begin "matrix-operations.rkt" - (list 'vandermonde-matrix - (equal? (vandermonde-matrix '(1 2 3) 5) - (list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]]))) - #; - (list 'in-column - (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)]) - x) - '(1 3)) - (equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)]) - x) - '(2 4)) - (equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x) - '(5 2 3))) - #; - (list 'in-row - (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)]) - x) - '(1 2)) - (equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)]) - x) - '(3 4))) - (list 'for/matrix: - (equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i) - (matrix [[0 1 2 3] [4 5 6 7]])) - (equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i) - (matrix [[0 2 4 6] [1 3 5 7]])) - (equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i) - (matrix [[10 11 12] [13 14 15] [16 17 18]]))) - (list 'for*/matrix: - (equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j)) - (matrix [[0 1 2] [10 11 12] [20 21 22]]))) - (list 'matrix-block-diagonal - (equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]]))) - (list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]]))) - (list 'matrix-augment - (equal? (matrix-augment (list (col-matrix [1 2 3]) - (col-matrix [4 5 6]) - (col-matrix [7 8 9]))) - (matrix [[1 4 7] [2 5 8] [3 6 9]]))) - (list 'matrix-stack - (equal? (matrix-stack (list (col-matrix [1 2 3]) - (col-matrix [4 5 6]) - (col-matrix [7 8 9]))) - (col-matrix [1 2 3 4 5 6 7 8 9]))) #; (list 'column-dimension (= (column-dimension #(1 2 3)) 3) @@ -206,8 +687,6 @@ (+ (* 1 4) (* 2 5) (* 3 6))) (= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i])) 25))) - (list 'matrix-trace - (equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5)) (let ([matrix: vector->matrix]) (list 'column-norm (= (column-norm (col-matrix [2 4])) (sqrt 20)))) @@ -286,15 +765,6 @@ [9 10 -11 12] [13 14 15 16]])) 5280)) - (list 'matrix-scale - (equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2) - (list*->matrix '[[2 4] [6 8]]))) - (list 'matrix-transpose - (equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]])) - (list*->matrix '[[1 3] [2 4]]))) - (list 'matrix-hermitian - (equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]])) - (list*->matrix '[[1-i 3-i] [2+i 4+i]]))) (let () (: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number)) (define (gauss-eliminate M u? p?) @@ -366,75 +836,60 @@ (equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0) (equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1) (equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0)) - #;(let () - (define-values (c1 n1) - (matrix-column+null-space (list*rix '[[0 0] [0 0]]))) - (define-values (c2 n2) - (matrix-column+null-space (list*->matrix '[[1 2] [2 4]]))) - (define-values (c3 n3) - (matrix-column+null-space (list*atrix '[[1 2] [2 5]]))) - (list - 'matrix-column+null-space - (equal? c1 '()) - (equal? n1 (list (list*->matrix '[[0] [0]]) - (list*->matrix '[[0] [0]]))) - (equal? c2 (list (list*->matrix '[[1] [2]]))) - ;(equal? n2 '([0 0])) - (equal? c3 (list (list*->matrix '[[1] [2]]) - (list*->matrix '[[2] [5]]))) - (equal? n3 '())))) + #; + (let () + (define-values (c1 n1) + (matrix-column+null-space (list*rix '[[0 0] [0 0]]))) + (define-values (c2 n2) + (matrix-column+null-space (list*->matrix '[[1 2] [2 4]]))) + (define-values (c3 n3) + (matrix-column+null-space (list*atrix '[[1 2] [2 5]]))) + (list + 'matrix-column+null-space + (equal? c1 '()) + (equal? n1 (list (list*->matrix '[[0] [0]]) + (list*->matrix '[[0] [0]]))) + (equal? c2 (list (list*->matrix '[[1] [2]]))) + ;(equal? n2 '([0 0])) + (equal? c3 (list (list*->matrix '[[1] [2]]) + (list*->matrix '[[2] [5]]))) + (equal? n3 '())))) - - - - - - - #;(begin - "matrix-multiply.rkt" - (list 'matrix* - (let () - (define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]])) - (equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB))) - (let () - (define-values (A B AB) (values '[[1 2] [3 4]] - '[[5 6 7] [8 9 10]] - '[[21 24 27] [47 54 61]])) - (equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB))))) - #;(begin - "matrix-2d.rkt" - (let () - (define e1 (matrix-transpose (vector->matrix #(#( 1 0))))) - (define e2 (matrix-transpose (vector->matrix #(#( 0 1))))) - (define -e1 (matrix-transpose (vector->matrix #(#(-1 0))))) - (define -e2 (matrix-transpose (vector->matrix #(#( 0 -1))))) - (define O (matrix-transpose (vector->matrix #(#( 0 0))))) - (define 2*e1 (matrix-scale e1 2)) - (define 4*e1 (matrix-scale e1 4)) - (define 3*e2 (matrix-scale e2 3)) - (define 4*e2 (matrix-scale e2 4)) - (begin - (list 'matrix-2d-rotation - (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0) - (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0)) - (list - 'matrix-2d-scaling - (equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2))) - (list - 'matrix-2d-shear-x - (equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2))) - (list - 'matrix-2d-shear-y - (equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2))) - (list - 'matrix-2d-reflection - (equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1) - (equal? (matrix* (matrix-2d-reflection 0 1) e2) e2) - (equal? (matrix* (matrix-2d-reflection 1 0) e1) e1) - (equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2)) - (list - 'matrix-2d-orthogonal-projection - (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1) - (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O) - (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O) - (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2)))))) + #; + (begin + "matrix-2d.rkt" + (let () + (define e1 (matrix-transpose (vector->matrix #(#( 1 0))))) + (define e2 (matrix-transpose (vector->matrix #(#( 0 1))))) + (define -e1 (matrix-transpose (vector->matrix #(#(-1 0))))) + (define -e2 (matrix-transpose (vector->matrix #(#( 0 -1))))) + (define O (matrix-transpose (vector->matrix #(#( 0 0))))) + (define 2*e1 (matrix-scale e1 2)) + (define 4*e1 (matrix-scale e1 4)) + (define 3*e2 (matrix-scale e2 3)) + (define 4*e2 (matrix-scale e2 4)) + (begin + (list 'matrix-2d-rotation + (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0) + (<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0)) + (list + 'matrix-2d-scaling + (equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2))) + (list + 'matrix-2d-shear-x + (equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2))) + (list + 'matrix-2d-shear-y + (equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2))) + (list + 'matrix-2d-reflection + (equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1) + (equal? (matrix* (matrix-2d-reflection 0 1) e2) e2) + (equal? (matrix* (matrix-2d-reflection 1 0) e1) e1) + (equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2)) + (list + 'matrix-2d-orthogonal-projection + (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1) + (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O) + (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O) + (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2))))))