Moar `math/matrix' review/refactoring
* Split "matrix-constructors.rkt" into three parts: * "matrix-constructors.rkt" * "matrix-conversion.rkt" * "matrix-syntax.rkt" * Made `matrix-map' automatically inline (it's dirt simple) * Renamed a few things, changed some type signatures * Fixed error in `matrix-dot' caught by testing (it was broadcasting) * Rewrote matrix comprehensions in terms of array comprehensions * Removed `in-column' and `in-row' (can use `in-array', `matrix-col' and `matrix-row') * Tons of new rackunit tests: only "matrix-2d.rkt" and "matrix-operations.rkt" are left (though the latter is large)
This commit is contained in:
parent
155ec7dc41
commit
8d5a069d41
|
@ -2,21 +2,24 @@
|
|||
|
||||
(require "private/matrix/matrix-arithmetic.rkt"
|
||||
"private/matrix/matrix-constructors.rkt"
|
||||
"private/matrix/matrix-conversion.rkt"
|
||||
"private/matrix/matrix-syntax.rkt"
|
||||
"private/matrix/matrix-basic.rkt"
|
||||
"private/matrix/matrix-operations.rkt"
|
||||
"private/matrix/matrix-comprehension.rkt"
|
||||
"private/matrix/matrix-sequences.rkt"
|
||||
"private/matrix/matrix-expt.rkt"
|
||||
"private/matrix/matrix-types.rkt"
|
||||
"private/matrix/matrix-2d.rkt"
|
||||
"private/matrix/utils.rkt")
|
||||
|
||||
(provide (all-from-out
|
||||
"private/matrix/matrix-arithmetic.rkt"
|
||||
"private/matrix/matrix-constructors.rkt"
|
||||
"private/matrix/matrix-conversion.rkt"
|
||||
"private/matrix/matrix-syntax.rkt"
|
||||
"private/matrix/matrix-basic.rkt"
|
||||
"private/matrix/matrix-operations.rkt"
|
||||
"private/matrix/matrix-comprehension.rkt"
|
||||
"private/matrix/matrix-sequences.rkt"
|
||||
"private/matrix/matrix-expt.rkt"
|
||||
"private/matrix/matrix-types.rkt")
|
||||
matrix?)
|
||||
"private/matrix/matrix-types.rkt"
|
||||
"private/matrix/matrix-2d.rkt"))
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
(require math/array
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt")
|
||||
"matrix-syntax.rkt")
|
||||
|
||||
(provide matrix-2d-rotation
|
||||
matrix-2d-scaling
|
||||
|
|
|
@ -19,21 +19,18 @@
|
|||
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
|
||||
[_ (syntax/loc inner-stx typed:fun)])))]))
|
||||
|
||||
(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*)
|
||||
(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+)
|
||||
(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-)
|
||||
(define/inline-macro matrix* (a as ...) inline-matrix* typed:matrix*)
|
||||
(define/inline-macro matrix+ (a as ...) inline-matrix+ typed:matrix+)
|
||||
(define/inline-macro matrix- (a as ...) inline-matrix- typed:matrix-)
|
||||
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
|
||||
|
||||
(define/inline-macro do-matrix-map (f a as ...) inline-matrix-map matrix-map)
|
||||
|
||||
(provide
|
||||
;; Equality
|
||||
(rename-out [typed:matrix= matrix=])
|
||||
;; Mapping
|
||||
inline-matrix-map
|
||||
matrix-map
|
||||
;; Multiplication
|
||||
(rename-out [do-matrix-map matrix-map]
|
||||
[typed:matrix= matrix=]
|
||||
[typed:matrix-sum matrix-sum])
|
||||
matrix*
|
||||
;; Pointwise operators
|
||||
matrix+
|
||||
matrix-
|
||||
matrix-scale
|
||||
(rename-out [typed:matrix-sum matrix-sum]))
|
||||
matrix-scale)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
math/array
|
||||
math/flonum
|
||||
"matrix-types.rkt"
|
||||
"matrix-arithmetic.rkt"
|
||||
"utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
|
@ -18,7 +19,7 @@
|
|||
matrix-rows
|
||||
matrix-cols
|
||||
;; Predicates
|
||||
zero-matrix?
|
||||
matrix-zero?
|
||||
;; Embiggenment
|
||||
matrix-augment
|
||||
matrix-stack
|
||||
|
@ -73,7 +74,7 @@
|
|||
(unsafe-vector-set! ij 0 0)
|
||||
res))]))
|
||||
|
||||
(: matrix-col (All (A) (Matrix A) Index -> (Matrix A)))
|
||||
(: matrix-col (All (A) (Matrix A) Integer -> (Matrix A)))
|
||||
(define (matrix-col a j)
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(cond [(or (j . < . 0) (j . >= . n))
|
||||
|
@ -99,9 +100,9 @@
|
|||
;; ===================================================================================================
|
||||
;; Predicates
|
||||
|
||||
(: zero-matrix? ((Array Number) -> Boolean))
|
||||
(define (zero-matrix? a)
|
||||
(array-all-and (array-map zero? a)))
|
||||
(: matrix-zero? ((Array Number) -> Boolean))
|
||||
(define (matrix-zero? a)
|
||||
(array-all-and (matrix-map zero? a)))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Embiggenment (this is a perfectly cromulent word)
|
||||
|
@ -179,9 +180,14 @@
|
|||
((Array Number) (Array Number) -> Number)))
|
||||
;; Computes the Frobenius inner product of two matrices
|
||||
(define (matrix-dot a b)
|
||||
(cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)]
|
||||
[(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)]
|
||||
[else (array-all-sum (array* a (array-conjugate b)))]))
|
||||
(define-values (m n) (matrix-shapes 'matrix-dot a b))
|
||||
(define aproc (unsafe-array-proc a))
|
||||
(define bproc (unsafe-array-proc b))
|
||||
(array-all-sum
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes])
|
||||
(* (aproc js) (conjugate (bproc js)))))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Operators
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
(require math/array
|
||||
math/base
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-conversion.rkt"
|
||||
"matrix-arithmetic.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
|
|
|
@ -1,140 +1,63 @@
|
|||
#lang racket
|
||||
#lang racket/base
|
||||
|
||||
(require math/array
|
||||
typed/racket/base
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt")
|
||||
(require (for-syntax racket/base
|
||||
syntax/parse)
|
||||
math/array)
|
||||
|
||||
(provide for/matrix
|
||||
for*/matrix
|
||||
for/matrix:
|
||||
for*/matrix:)
|
||||
(provide for/matrix:
|
||||
for*/matrix:
|
||||
for/matrix
|
||||
for*/matrix)
|
||||
|
||||
;;; COMPREHENSIONS
|
||||
(module typed-defs typed/racket/base
|
||||
(require (for-syntax racket/base
|
||||
syntax/parse)
|
||||
math/array)
|
||||
|
||||
(provide (all-defined-out))
|
||||
|
||||
(: ensure-matrix-dims (Symbol Any Any -> (Values Positive-Index Positive-Index)))
|
||||
(define (ensure-matrix-dims name m n)
|
||||
(cond [(or (not (index? m)) (zero? m)) (raise-argument-error name "Positive-Index" 0 m n)]
|
||||
[(or (not (index? n)) (zero? n)) (raise-argument-error name "Positive-Index" 1 m n)]
|
||||
[else (values m n)]))
|
||||
|
||||
(define-syntax (base-for/matrix: stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ name:id for/array:id
|
||||
m-expr:expr n-expr:expr
|
||||
(~optional (~seq #:fill fill-expr:expr))
|
||||
(clause ...)
|
||||
(~optional (~seq : A:expr))
|
||||
body:expr ...+)
|
||||
(with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())]
|
||||
[(maybe-type ...) (if (attribute A) #'(: A) #'())])
|
||||
(syntax/loc stx
|
||||
(let-values ([(m n) (ensure-matrix-dims 'name
|
||||
(ann m-expr Integer)
|
||||
(ann n-expr Integer))])
|
||||
(for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...) maybe-type ...
|
||||
body ...))))]))
|
||||
|
||||
(define-syntax-rule (for/matrix: e ...) (base-for/matrix: for/matrix: for/array: e ...))
|
||||
(define-syntax-rule (for*/matrix: e ...) (base-for/matrix: for*/matrix: for*/array: e ...))
|
||||
|
||||
)
|
||||
|
||||
; (for/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run in parallel.
|
||||
(define-syntax (for/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(define flat-vector
|
||||
(for/vector #:length (* m n)
|
||||
(clause ...) . defs+exprs))
|
||||
(vector->matrix m n flat-vector)))]))
|
||||
(require (submod "." typed-defs))
|
||||
|
||||
; (for*/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run nested.
|
||||
; (for*/matrix m n #:column (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first m values produced becomes the first column.
|
||||
; The next m values becomes the second column and so on.
|
||||
; The bindings in clauses run nested.
|
||||
(define-syntax (for*/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let* ([m m-expr]
|
||||
[n n-expr]
|
||||
[v (make-vector (* m n) 0)]
|
||||
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
|
||||
(for* ([i (in-range m)] [j (in-range n)])
|
||||
(vector-set! v (+ (* i n) j)
|
||||
(vector-ref w (+ (* j m) i))))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(vector->matrix
|
||||
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
|
||||
(define-syntax (base-for/matrix stx)
|
||||
(syntax-parse stx
|
||||
[(_ name:id for/array:id
|
||||
m-expr:expr n-expr:expr
|
||||
(~optional (~seq #:fill fill-expr:expr))
|
||||
(clause ...)
|
||||
body:expr ...+)
|
||||
(with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())])
|
||||
(syntax/loc stx
|
||||
(let-values ([(m n) (ensure-matrix-dims 'name m-expr n-expr)])
|
||||
(for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...)
|
||||
body ...))))]))
|
||||
|
||||
|
||||
(define-syntax (for/column: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: flat-vector : (Vectorof Number) (make-vector m 0))
|
||||
(for: ([i (in-range m)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! flat-vector i x))
|
||||
(vector->col-matrix flat-vector)))]))
|
||||
|
||||
(define-syntax (for/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x))
|
||||
(vector->matrix m n v)))]))
|
||||
|
||||
(define-syntax (for*/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: i : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x)
|
||||
(set! i (assert (+ i 1) index?)))
|
||||
(vector->matrix m n v)))]))
|
||||
#;
|
||||
(module* test #f
|
||||
(require rackunit)
|
||||
; "matrix-sequences.rkt"
|
||||
; These work in racket not in typed racket
|
||||
(check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 1 2] [1 2 3]])
|
||||
(check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 2 2] [1 1 3]])
|
||||
(check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i))
|
||||
'[[0 2] [1 3]])
|
||||
(check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i))
|
||||
'[[0 1] [2 3]])
|
||||
(check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
|
||||
'[[6 8 10] [12 14 16]]))
|
||||
(define-syntax-rule (for/matrix e ...) (base-for/matrix for/matrix for/array e ...))
|
||||
(define-syntax-rule (for*/matrix e ...) (base-for/matrix for*/matrix for*/array e ...))
|
||||
|
|
|
@ -1,376 +1,151 @@
|
|||
#lang racket/base
|
||||
#lang typed/racket/base
|
||||
|
||||
(provide
|
||||
;; Constructors
|
||||
identity-matrix
|
||||
make-matrix
|
||||
build-matrix
|
||||
diagonal-matrix/zero
|
||||
diagonal-matrix
|
||||
block-diagonal-matrix/zero
|
||||
block-diagonal-matrix
|
||||
vandermonde-matrix
|
||||
;; Basic conversion
|
||||
list->matrix
|
||||
matrix->list
|
||||
vector->matrix
|
||||
matrix->vector
|
||||
->row-matrix
|
||||
->col-matrix
|
||||
;; Nested conversion
|
||||
list*->matrix
|
||||
matrix->list*
|
||||
vector*->matrix
|
||||
matrix->vector*
|
||||
;; Syntax
|
||||
matrix
|
||||
row-matrix
|
||||
col-matrix)
|
||||
|
||||
(module typed-defs typed/racket/base
|
||||
(require racket/fixnum
|
||||
racket/list
|
||||
racket/vector
|
||||
math/array
|
||||
"../array/utils.rkt"
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide (all-defined-out))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Constructors
|
||||
|
||||
(: identity-matrix (Integer -> (Matrix (U 0 1))))
|
||||
(define (identity-matrix m) (diagonal-array 2 m 1 0))
|
||||
|
||||
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
|
||||
(define (make-matrix m n x)
|
||||
(make-array (vector m n) x))
|
||||
|
||||
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
|
||||
(define (build-matrix m n proc)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
|
||||
[else
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes])
|
||||
(proc (unsafe-vector-ref js 0)
|
||||
(unsafe-vector-ref js 1))))]))
|
||||
|
||||
(: diagonal-matrix/zero (All (A) (Array A) A -> (Array A)))
|
||||
(define (diagonal-matrix/zero a zero)
|
||||
(define ds (array-shape a))
|
||||
(cond [(= 1 (vector-length ds))
|
||||
(define m (unsafe-vector-ref ds 0))
|
||||
(define proc (unsafe-array-proc a))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m m)
|
||||
(λ: ([js : Indexes])
|
||||
(define i (unsafe-vector-ref js 0))
|
||||
(cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))]
|
||||
[else zero])))]
|
||||
[else
|
||||
(raise-argument-error 'diagonal-matrix "Array with one dimension" a)]))
|
||||
|
||||
(: diagonal-matrix (case-> ((Array Real) -> (Array Real))
|
||||
((Array Number) -> (Array Number))))
|
||||
(define (diagonal-matrix a)
|
||||
(diagonal-matrix/zero a 0))
|
||||
|
||||
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero* as zero)
|
||||
(define num (vector-length as))
|
||||
(define-values (ms ns)
|
||||
(let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty]
|
||||
[ns : (Listof Index) empty]
|
||||
) ([a (in-vector as)])
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(values (cons m ms) (cons n ns)))])
|
||||
(values (reverse ms) (reverse ns))))
|
||||
(define res-m (assert (apply + ms) index?))
|
||||
(define res-n (assert (apply + ns) index?))
|
||||
(define vs ((inst make-vector Index) res-m 0))
|
||||
(define hs ((inst make-vector Index) res-n 0))
|
||||
(define is ((inst make-vector Index) res-m 0))
|
||||
(define js ((inst make-vector Index) res-n 0))
|
||||
(define-values (_res-i _res-j)
|
||||
(for/fold: ([res-i : Nonnegative-Fixnum 0]
|
||||
[res-j : Nonnegative-Fixnum 0]
|
||||
) ([m (in-list ms)]
|
||||
[n (in-list ns)]
|
||||
[k : Nonnegative-Fixnum (in-range num)])
|
||||
(let ([k (assert k index?)])
|
||||
(for: ([i : Nonnegative-Fixnum (in-range m)])
|
||||
(vector-set! vs (unsafe-fx+ res-i i) k)
|
||||
(vector-set! is (unsafe-fx+ res-i i) (assert i index?)))
|
||||
(for: ([j : Nonnegative-Fixnum (in-range n)])
|
||||
(vector-set! hs (unsafe-fx+ res-j j) k)
|
||||
(vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
|
||||
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
|
||||
(define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) res-m res-n)
|
||||
(λ: ([ij : Indexes])
|
||||
(define i (unsafe-vector-ref ij 0))
|
||||
(define j (unsafe-vector-ref ij 1))
|
||||
(define v (unsafe-vector-ref vs i))
|
||||
(cond [(fx= v (vector-ref hs j))
|
||||
(define proc (unsafe-vector-ref procs v))
|
||||
(define iv (unsafe-vector-ref is i))
|
||||
(define jv (unsafe-vector-ref js j))
|
||||
(unsafe-vector-set! ij 0 iv)
|
||||
(unsafe-vector-set! ij 1 jv)
|
||||
(define res (proc ij))
|
||||
(unsafe-vector-set! ij 0 i)
|
||||
(unsafe-vector-set! ij 1 j)
|
||||
res]
|
||||
[else
|
||||
zero]))))
|
||||
|
||||
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero as zero)
|
||||
(let ([as (list->vector as)])
|
||||
(define num (vector-length as))
|
||||
(cond [(= num 0)
|
||||
(raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
|
||||
[(= num 1)
|
||||
(unsafe-vector-ref as 0)]
|
||||
[else
|
||||
(block-diagonal-matrix/zero* as zero)])))
|
||||
|
||||
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
|
||||
((Listof (Array Number)) -> (Array Number))))
|
||||
(define (block-diagonal-matrix as)
|
||||
(block-diagonal-matrix/zero as 0))
|
||||
|
||||
(: expt-hack (case-> (Real Integer -> Real)
|
||||
(Number Integer -> Number)))
|
||||
;; Stop using this when TR correctly derives expt : Real Integer -> Real
|
||||
(define (expt-hack x n)
|
||||
(cond [(real? x) (assert (expt x n) real?)]
|
||||
[else (expt x n)]))
|
||||
|
||||
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
|
||||
((Listof Number) Integer -> (Array Number))))
|
||||
(define (vandermonde-matrix xs n)
|
||||
(cond [(empty? xs)
|
||||
(raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
|
||||
[(or (not (index? n)) (zero? n))
|
||||
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
|
||||
[else
|
||||
(array-axis-expand (list->array xs) 1 n expt-hack)]))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Flat conversion
|
||||
|
||||
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
|
||||
(define (list->matrix m n xs)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
|
||||
[else (list->array (vector m n) xs)]))
|
||||
|
||||
(: matrix->list (All (A) ((Array A) -> (Listof A))))
|
||||
(define (matrix->list a)
|
||||
(array->list (ensure-matrix 'matrix->list a)))
|
||||
|
||||
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->matrix m n v)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
|
||||
[else (vector->array (vector m n) v)]))
|
||||
|
||||
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
|
||||
(define (matrix->vector a)
|
||||
(array->vector (ensure-matrix 'matrix->vector a)))
|
||||
|
||||
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->row-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
|
||||
|
||||
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->col-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
|
||||
|
||||
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->row-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) 1 n) xs)]))
|
||||
|
||||
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->col-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) n 1) xs)]))
|
||||
|
||||
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
|
||||
(define (find-nontrivial-axis ds)
|
||||
(define dims (vector-length ds))
|
||||
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
|
||||
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
|
||||
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
|
||||
[else (values 0 0)])))
|
||||
|
||||
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->row-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(row-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k n) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 n)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->col-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(col-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k m) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) m 1)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->row-matrix xs)
|
||||
(cond [(list? xs) (list->row-matrix xs)]
|
||||
[(array? xs) (array->row-matrix xs)]
|
||||
[else (vector->row-matrix xs)]))
|
||||
|
||||
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->col-matrix xs)
|
||||
(cond [(list? xs) (list->col-matrix xs)]
|
||||
[(array? xs) (array->col-matrix xs)]
|
||||
[else (vector->col-matrix xs)]))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Nested conversion
|
||||
|
||||
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
|
||||
(define (list*-shape xss fail)
|
||||
(define m (length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define n (length (first xss)))
|
||||
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
|
||||
-> (Values Positive-Index Positive-Index)))
|
||||
(define (vector*-shape xss fail)
|
||||
(define m (vector-length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
|
||||
(define n (vector-length (unsafe-vector-ref xss 0)))
|
||||
(cond [(and (n . > . 0)
|
||||
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
|
||||
(cond [(i . fx< . m)
|
||||
(if (= n (vector-length (unsafe-vector-ref xss i)))
|
||||
(loop (fx+ i 1))
|
||||
#f)]
|
||||
[else #t])))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
|
||||
(define (list*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'list*->matrix
|
||||
"nested lists with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (list*-shape xss fail))
|
||||
(list->array ((inst vector Index) m n) (apply append xss)))
|
||||
|
||||
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
|
||||
(define (matrix->list* a)
|
||||
(cond [(matrix? a) (array->list (array->list-array a 1))]
|
||||
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
|
||||
|
||||
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
|
||||
(define (vector*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'vector*->matrix
|
||||
"nested vectors with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (vector*-shape xss fail))
|
||||
(vector->matrix m n (apply vector-append (vector->list xss))))
|
||||
|
||||
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
|
||||
(define (matrix->vector* a)
|
||||
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
|
||||
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
|
||||
) ; module
|
||||
|
||||
(require (for-syntax racket/base
|
||||
syntax/parse)
|
||||
(only-in typed/racket/base :)
|
||||
(require racket/fixnum
|
||||
racket/list
|
||||
racket/vector
|
||||
math/array
|
||||
(submod "." typed-defs))
|
||||
"matrix-types.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(define-syntax (matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...])
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
|
||||
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty row" stx #'r)]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]))
|
||||
(provide identity-matrix
|
||||
make-matrix
|
||||
build-matrix
|
||||
diagonal-matrix/zero
|
||||
diagonal-matrix
|
||||
block-diagonal-matrix/zero
|
||||
block-diagonal-matrix
|
||||
vandermonde-matrix)
|
||||
|
||||
(define-syntax (row-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
|
||||
[(_ (~and [] r) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
|
||||
;; ===================================================================================================
|
||||
;; Basic constructors
|
||||
|
||||
(define-syntax (col-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))
|
||||
(: identity-matrix (Integer -> (Matrix (U 0 1))))
|
||||
(define (identity-matrix m) (diagonal-array 2 m 1 0))
|
||||
|
||||
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
|
||||
(define (make-matrix m n x)
|
||||
(make-array (vector m n) x))
|
||||
|
||||
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
|
||||
(define (build-matrix m n proc)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
|
||||
[else
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes])
|
||||
(proc (unsafe-vector-ref js 0)
|
||||
(unsafe-vector-ref js 1))))]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Diagonal matrices
|
||||
|
||||
(: diagonal-matrix/zero (All (A) (Listof A) A -> (Array A)))
|
||||
(define (diagonal-matrix/zero xs zero)
|
||||
(cond [(empty? xs)
|
||||
(raise-argument-error 'diagonal-matrix "nonempty List" xs)]
|
||||
[else
|
||||
(define vs (list->vector xs))
|
||||
(define m (vector-length vs))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m m)
|
||||
(λ: ([js : Indexes])
|
||||
(define i (unsafe-vector-ref js 0))
|
||||
(cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)]
|
||||
[else zero])))]))
|
||||
|
||||
(: diagonal-matrix (case-> ((Listof Real) -> (Array Real))
|
||||
((Listof Number) -> (Array Number))))
|
||||
(define (diagonal-matrix xs)
|
||||
(diagonal-matrix/zero xs 0))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Block diagonal matrices
|
||||
|
||||
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero* as zero)
|
||||
(define num (vector-length as))
|
||||
(define-values (ms ns)
|
||||
(let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty]
|
||||
[ns : (Listof Index) empty]
|
||||
) ([a (in-vector as)])
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(values (cons m ms) (cons n ns)))])
|
||||
(values (reverse ms) (reverse ns))))
|
||||
(define res-m (assert (apply + ms) index?))
|
||||
(define res-n (assert (apply + ns) index?))
|
||||
(define vs ((inst make-vector Index) res-m 0))
|
||||
(define hs ((inst make-vector Index) res-n 0))
|
||||
(define is ((inst make-vector Index) res-m 0))
|
||||
(define js ((inst make-vector Index) res-n 0))
|
||||
(define-values (_res-i _res-j)
|
||||
(for/fold: ([res-i : Nonnegative-Fixnum 0]
|
||||
[res-j : Nonnegative-Fixnum 0]
|
||||
) ([m (in-list ms)]
|
||||
[n (in-list ns)]
|
||||
[k : Nonnegative-Fixnum (in-range num)])
|
||||
(let ([k (assert k index?)])
|
||||
(for: ([i : Nonnegative-Fixnum (in-range m)])
|
||||
(vector-set! vs (unsafe-fx+ res-i i) k)
|
||||
(vector-set! is (unsafe-fx+ res-i i) (assert i index?)))
|
||||
(for: ([j : Nonnegative-Fixnum (in-range n)])
|
||||
(vector-set! hs (unsafe-fx+ res-j j) k)
|
||||
(vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
|
||||
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
|
||||
(define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) res-m res-n)
|
||||
(λ: ([ij : Indexes])
|
||||
(define i (unsafe-vector-ref ij 0))
|
||||
(define j (unsafe-vector-ref ij 1))
|
||||
(define v (unsafe-vector-ref vs i))
|
||||
(cond [(fx= v (vector-ref hs j))
|
||||
(define proc (unsafe-vector-ref procs v))
|
||||
(define iv (unsafe-vector-ref is i))
|
||||
(define jv (unsafe-vector-ref js j))
|
||||
(unsafe-vector-set! ij 0 iv)
|
||||
(unsafe-vector-set! ij 1 jv)
|
||||
(define res (proc ij))
|
||||
(unsafe-vector-set! ij 0 i)
|
||||
(unsafe-vector-set! ij 1 j)
|
||||
res]
|
||||
[else
|
||||
zero]))))
|
||||
|
||||
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero as zero)
|
||||
(let ([as (list->vector as)])
|
||||
(define num (vector-length as))
|
||||
(cond [(= num 0)
|
||||
(raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
|
||||
[(= num 1)
|
||||
(unsafe-vector-ref as 0)]
|
||||
[else
|
||||
(block-diagonal-matrix/zero* as zero)])))
|
||||
|
||||
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
|
||||
((Listof (Array Number)) -> (Array Number))))
|
||||
(define (block-diagonal-matrix as)
|
||||
(block-diagonal-matrix/zero as 0))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Special matrices
|
||||
|
||||
(: expt-hack (case-> (Real Integer -> Real)
|
||||
(Number Integer -> Number)))
|
||||
;; Stop using this when TR correctly derives expt : Real Integer -> Real
|
||||
(define (expt-hack x n)
|
||||
(cond [(real? x) (assert (expt x n) real?)]
|
||||
[else (expt x n)]))
|
||||
|
||||
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
|
||||
((Listof Number) Integer -> (Array Number))))
|
||||
(define (vandermonde-matrix xs n)
|
||||
(cond [(empty? xs)
|
||||
(raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
|
||||
[(or (not (index? n)) (zero? n))
|
||||
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
|
||||
[else
|
||||
(array-axis-expand (list->array xs) 1 n expt-hack)]))
|
||||
|
|
202
collects/math/private/matrix/matrix-conversion.rkt
Normal file
202
collects/math/private/matrix/matrix-conversion.rkt
Normal file
|
@ -0,0 +1,202 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require racket/fixnum
|
||||
racket/list
|
||||
racket/vector
|
||||
math/array
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt"
|
||||
"../array/utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide
|
||||
;; Flat conversion
|
||||
list->matrix
|
||||
matrix->list
|
||||
vector->matrix
|
||||
matrix->vector
|
||||
->row-matrix
|
||||
->col-matrix
|
||||
;; Nested conversion
|
||||
list*->matrix
|
||||
matrix->list*
|
||||
vector*->matrix
|
||||
matrix->vector*)
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Flat conversion
|
||||
|
||||
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
|
||||
(define (list->matrix m n xs)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
|
||||
[else (list->array (vector m n) xs)]))
|
||||
|
||||
(: matrix->list (All (A) ((Array A) -> (Listof A))))
|
||||
(define (matrix->list a)
|
||||
(array->list (ensure-matrix 'matrix->list a)))
|
||||
|
||||
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->matrix m n v)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
|
||||
[else (vector->array (vector m n) v)]))
|
||||
|
||||
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
|
||||
(define (matrix->vector a)
|
||||
(array->vector (ensure-matrix 'matrix->vector a)))
|
||||
|
||||
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->row-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
|
||||
|
||||
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->col-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
|
||||
|
||||
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->row-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) 1 n) xs)]))
|
||||
|
||||
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->col-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) n 1) xs)]))
|
||||
|
||||
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
|
||||
(define (find-nontrivial-axis ds)
|
||||
(define dims (vector-length ds))
|
||||
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
|
||||
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
|
||||
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
|
||||
[else (values 0 0)])))
|
||||
|
||||
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->row-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(row-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k n) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 n)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->col-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(col-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k m) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) m 1)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->row-matrix xs)
|
||||
(cond [(list? xs) (list->row-matrix xs)]
|
||||
[(array? xs) (array->row-matrix xs)]
|
||||
[else (vector->row-matrix xs)]))
|
||||
|
||||
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->col-matrix xs)
|
||||
(cond [(list? xs) (list->col-matrix xs)]
|
||||
[(array? xs) (array->col-matrix xs)]
|
||||
[else (vector->col-matrix xs)]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Nested conversion
|
||||
|
||||
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
|
||||
(define (list*-shape xss fail)
|
||||
(define m (length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define n (length (first xss)))
|
||||
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
|
||||
-> (Values Positive-Index Positive-Index)))
|
||||
(define (vector*-shape xss fail)
|
||||
(define m (vector-length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
|
||||
(define n (vector-length (unsafe-vector-ref xss 0)))
|
||||
(cond [(and (n . > . 0)
|
||||
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
|
||||
(cond [(i . fx< . m)
|
||||
(if (= n (vector-length (unsafe-vector-ref xss i)))
|
||||
(loop (fx+ i 1))
|
||||
#f)]
|
||||
[else #t])))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
|
||||
(define (list*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'list*->matrix
|
||||
"nested lists with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (list*-shape xss fail))
|
||||
(list->array ((inst vector Index) m n) (apply append xss)))
|
||||
|
||||
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
|
||||
(define (matrix->list* a)
|
||||
(cond [(matrix? a) (array->list (array->list-array a 1))]
|
||||
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
|
||||
|
||||
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
|
||||
(define (vector*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'vector*->matrix
|
||||
"nested vectors with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (vector*-shape xss fail))
|
||||
(vector->matrix m n (apply vector-append (vector->list xss))))
|
||||
|
||||
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
|
||||
(define (matrix->vector* a)
|
||||
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
|
||||
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
|
|
@ -6,6 +6,7 @@
|
|||
"../unsafe.rkt"
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-conversion.rkt"
|
||||
"matrix-arithmetic.rkt"
|
||||
"matrix-basic.rkt"
|
||||
"matrix-column.rkt"
|
||||
|
@ -19,13 +20,6 @@
|
|||
; 4. Pseudo inverse
|
||||
; 5. Eigenvalues and eigenvectors
|
||||
|
||||
; 6. "Bug"
|
||||
; (for*/matrix : Number 2 3 ([i (in-naturals)]) i)
|
||||
; ought to generate a matrix with numbers from 0 to 5.
|
||||
; Problem: In expansion of for/matrix an extra [i (in-range (* m n))]
|
||||
; is added to make sure the comprehension stops.
|
||||
; But TR has problems with #:when so what is the proper expansion ?
|
||||
|
||||
(provide
|
||||
matrix-inverse
|
||||
; row and column
|
||||
|
@ -582,7 +576,7 @@
|
|||
; Note: We project onto vs (not on the original ws)
|
||||
; in order to get numerical stability.
|
||||
(let ([w-minus-proj (array-strict (array- w w-proj))])
|
||||
(if (zero-matrix? w-minus-proj)
|
||||
(if (matrix-zero? w-minus-proj)
|
||||
(loop vs (cdr ws)) ; w in span{vs} => omit it
|
||||
(loop (cons w-minus-proj vs) (cdr ws)))))]))
|
||||
(reverse (loop (list (car ws)) (cdr ws)))]))
|
||||
|
|
|
@ -1,340 +0,0 @@
|
|||
#lang racket
|
||||
|
||||
(provide in-row
|
||||
in-column)
|
||||
|
||||
(require math/array
|
||||
"matrix-types.rkt"
|
||||
"matrix-basic.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
)
|
||||
|
||||
(define (in-row/proc M r)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ (j) (matrix-ref M r j))
|
||||
; next-pos
|
||||
(λ (j) (+ j 1))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ (j) (< j n))
|
||||
#f #f ))))
|
||||
|
||||
; (in-row M i]
|
||||
; Returns a sequence of all elements of row i,
|
||||
; that is xi0, xi1, xi2, ...
|
||||
(define-sequence-syntax in-row
|
||||
(λ () #'in-row/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
[[(x) (_ M-expr r-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr r-expr)]
|
||||
#'((i x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
||||
|
||||
; (in-column M j]
|
||||
; Returns a sequence of all elements of column j,
|
||||
; that is x0j, x1j, x2j, ...
|
||||
|
||||
|
||||
(define (in-column/proc M s)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ (i) (matrix-ref M i s))
|
||||
; next-pos
|
||||
(λ (i) (+ i 1))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ (i) (< i m))
|
||||
#f #f ))))
|
||||
|
||||
(define-sequence-syntax in-column
|
||||
(λ () #'in-column/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
[[(x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
||||
|
||||
(define-syntax (for/matrix-sum: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: sum : (U False (Matrix Number)) #f)
|
||||
(for: (for:-clause ...)
|
||||
(define a (let () . defs+exprs))
|
||||
(set! sum (if sum (array+ (assert sum) a) a)))
|
||||
(assert sum)))]))
|
||||
#|
|
||||
;;;
|
||||
;;; SEQUENCES
|
||||
;;;
|
||||
|
||||
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-row/proc M r)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([j : Index]) (matrix-ref M r j))
|
||||
; next-pos
|
||||
(λ: ([j : Index]) (assert (+ j 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([j : Index ]) (< j n))
|
||||
#f #f))))
|
||||
|
||||
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-column/proc M s)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([i : Index]) (matrix-ref M i s))
|
||||
; next-pos
|
||||
(λ: ([i : Index]) (assert (+ i 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([i : Index]) (< i m))
|
||||
#f #f))))
|
||||
|
||||
; (in-row M i]
|
||||
; Returns a sequence of all elements of row i,
|
||||
; that is xi0, xi1, xi2, ...
|
||||
(define-sequence-syntax in-row
|
||||
(λ () #'in-row/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
[[(x) (_ M-expr r-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr r-expr)]
|
||||
#'((i x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
||||
|
||||
; (in-column M j]
|
||||
; Returns a sequence of all elements of column j,
|
||||
; that is x0j, x1j, x2j, ...
|
||||
|
||||
(define-sequence-syntax in-column
|
||||
(λ () #'in-column/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
; M-expr evaluates to column
|
||||
[[(x) (_ M-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d j)])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
; M-expr evaluats to matrix, s-expr to the column index
|
||||
[[(x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
||||
|#
|
||||
#;
|
||||
(module* test #f
|
||||
(require rackunit)
|
||||
; "matrix-sequences.rkt"
|
||||
(check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
'(3 4))
|
||||
(check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(list i x))
|
||||
'((0 3) (1 4)))
|
||||
(check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
'(2 4))
|
||||
(check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(list i x))
|
||||
'((0 2) (1 4))))
|
35
collects/math/private/matrix/matrix-syntax.rkt
Normal file
35
collects/math/private/matrix/matrix-syntax.rkt
Normal file
|
@ -0,0 +1,35 @@
|
|||
#lang racket/base
|
||||
|
||||
(require (for-syntax racket/base
|
||||
syntax/parse)
|
||||
(only-in typed/racket/base :)
|
||||
math/array)
|
||||
|
||||
(provide matrix row-matrix col-matrix)
|
||||
|
||||
(define-syntax (matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...])
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
|
||||
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty row" stx #'r)]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]
|
||||
[(_ x (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "expected two-dimensional data" stx)]))
|
||||
|
||||
(define-syntax (row-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
|
||||
[(_ (~and [] r) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
|
||||
|
||||
(define-syntax (col-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))
|
|
@ -27,44 +27,53 @@
|
|||
(define-type (Column-Matrix A) (Matrix A))
|
||||
; a column vector represented as a matrix
|
||||
|
||||
(define matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (> (array-size arr) 0)
|
||||
(= (array-dims arr) 2))))
|
||||
(: matrix? (All (A) ((Array A) -> Boolean)))
|
||||
(define (matrix? arr)
|
||||
(and (> (array-size arr) 0)
|
||||
(= (array-dims arr) 2)))
|
||||
|
||||
(define square-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(let ([sh (array-shape arr)])
|
||||
(= (vector-ref sh 0) (vector-ref sh 1))))))
|
||||
(: square-matrix? (All (A) ((Array A) -> Boolean)))
|
||||
(define (square-matrix? arr)
|
||||
(define ds (array-shape arr))
|
||||
(and (= (vector-length ds) 2)
|
||||
(let ([d0 (unsafe-vector-ref ds 0)]
|
||||
[d1 (unsafe-vector-ref ds 1)])
|
||||
(and (> d0 0) (> d1 0) (= d0 d1)))))
|
||||
|
||||
(define row-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(= (vector-ref (array-shape arr) 0) 1))))
|
||||
(: row-matrix? (All (A) ((Array A) -> Boolean)))
|
||||
(define (row-matrix? arr)
|
||||
(define ds (array-shape arr))
|
||||
(and (= (vector-length ds) 2)
|
||||
(= (unsafe-vector-ref ds 0) 1)
|
||||
(> (unsafe-vector-ref ds 1) 0)))
|
||||
|
||||
(define col-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(= (vector-ref (array-shape arr) 1) 1))))
|
||||
(: col-matrix? (All (A) ((Array A) -> Boolean)))
|
||||
(define (col-matrix? arr)
|
||||
(define ds (array-shape arr))
|
||||
(and (= (vector-length ds) 2)
|
||||
(> (unsafe-vector-ref ds 0) 0)
|
||||
(= (unsafe-vector-ref ds 1) 1)))
|
||||
|
||||
(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index)))
|
||||
(: matrix-shape (All (A) ((Array A) -> (Values Index Index))))
|
||||
(define (matrix-shape a)
|
||||
(cond [(matrix? a) (define sh (array-shape a))
|
||||
(values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))]
|
||||
[else (raise-argument-error 'matrix-shape "matrix?" a)]))
|
||||
(define ds (array-shape a))
|
||||
(if (and (> (array-size a) 0)
|
||||
(= (vector-length ds) 2))
|
||||
(values (unsafe-vector-ref ds 0)
|
||||
(unsafe-vector-ref ds 1))
|
||||
(raise-argument-error 'matrix-shape "matrix?" a)))
|
||||
|
||||
(: square-matrix-size (All (A) ((Matrix A) -> Index)))
|
||||
(: square-matrix-size (All (A) ((Array A) -> Index)))
|
||||
(define (square-matrix-size arr)
|
||||
(cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)]
|
||||
[else (raise-argument-error 'square-matrix-size "square-matrix?" arr)]))
|
||||
|
||||
(: matrix-num-rows (All (A) ((Matrix A) -> Index)))
|
||||
(: matrix-num-rows (All (A) ((Array A) -> Index)))
|
||||
(define (matrix-num-rows a)
|
||||
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
|
||||
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
|
||||
|
||||
(: matrix-num-cols (All (A) ((Matrix A) -> Index)))
|
||||
(: matrix-num-cols (All (A) ((Array A) -> Index)))
|
||||
(define (matrix-num-cols a)
|
||||
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
|
||||
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))
|
||||
|
|
|
@ -73,8 +73,6 @@
|
|||
|
||||
) ; module
|
||||
|
||||
(require 'syntax-defs)
|
||||
|
||||
(module untyped-defs typed/racket/base
|
||||
(require math/array
|
||||
(submod ".." syntax-defs)
|
||||
|
@ -102,4 +100,5 @@
|
|||
|
||||
) ; module
|
||||
|
||||
(require 'untyped-defs)
|
||||
(require 'syntax-defs
|
||||
'untyped-defs)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
(define (matrix-shapes name arr . brrs)
|
||||
(define-values (m n) (matrix-shape arr))
|
||||
(unless (andmap (λ: ([brr : (Matrix Any)])
|
||||
(match-define (vector bm bn) (array-shape brr))
|
||||
(define-values (bm bn) (matrix-shape brr))
|
||||
(and (= bm m) (= bn n)))
|
||||
brrs)
|
||||
(error name
|
||||
|
|
|
@ -7,192 +7,673 @@
|
|||
"../private/matrix/matrix-column.rkt"
|
||||
"test-utils.rkt")
|
||||
|
||||
(: random-matrix (Integer Integer Integer -> (Matrix Integer)))
|
||||
;; Generates a random matrix with integer elements < k. Useful to test properties.
|
||||
(define (random-matrix m n k)
|
||||
(: random-matrix (case-> (Integer Integer -> (Matrix Integer))
|
||||
(Integer Integer Integer -> (Matrix Integer))))
|
||||
;; Generates a random matrix with Natural elements < k. Useful to test properties.
|
||||
(define (random-matrix m n [k 100])
|
||||
(array-strict (build-array (vector m n) (λ (_) (random k)))))
|
||||
|
||||
(define nonmatrices
|
||||
(list (make-array #() 0)
|
||||
(make-array #(1) 0)
|
||||
(make-array #(1 0) 0)
|
||||
(make-array #(0 1) 0)
|
||||
(make-array #(0 0) 0)
|
||||
(make-array #(1 1 1) 0)))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Types
|
||||
;; Literal syntax
|
||||
|
||||
(check-equal? (matrix [[1]])
|
||||
(array #[#[1]]))
|
||||
|
||||
(check-equal? (matrix [[1 2 3 4]])
|
||||
(array #[#[1 2 3 4]]))
|
||||
|
||||
(check-equal? (matrix [[1 2] [3 4]])
|
||||
(array #[#[1 2] #[3 4]]))
|
||||
|
||||
(check-equal? (matrix [[1] [2] [3] [4]])
|
||||
(array #[#[1] #[2] #[3] #[4]]))
|
||||
|
||||
(check-equal? (row-matrix [1 2 3 4])
|
||||
(matrix [[1 2 3 4]]))
|
||||
|
||||
(check-equal? (col-matrix [1 2 3 4])
|
||||
(matrix [[1] [2] [3] [4]]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Predicates
|
||||
|
||||
(check-true (matrix? (array #[#[1]])))
|
||||
(check-false (matrix? (array #[1])))
|
||||
(check-false (matrix? (array 1)))
|
||||
(check-false (matrix? (array #[])))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-false (matrix? a)))
|
||||
|
||||
(check-true (square-matrix? (matrix [[1]])))
|
||||
(check-true (square-matrix? (matrix [[1 1] [1 1]])))
|
||||
(check-false (square-matrix? (matrix [[1 2]])))
|
||||
(check-false (square-matrix? (matrix [[1] [2]])))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-false (square-matrix? a)))
|
||||
|
||||
(check-true (row-matrix? (matrix [[1 2 3 4]])))
|
||||
(check-true (row-matrix? (matrix [[1]])))
|
||||
(check-false (row-matrix? (matrix [[1] [2] [3] [4]])))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-false (row-matrix? a)))
|
||||
|
||||
(check-true (col-matrix? (matrix [[1] [2] [3] [4]])))
|
||||
(check-true (col-matrix? (matrix [[1]])))
|
||||
(check-false (col-matrix? (matrix [[1 2 3 4]])))
|
||||
(check-false (col-matrix? (array #[1])))
|
||||
(check-false (col-matrix? (array 1)))
|
||||
(check-false (col-matrix? (array #[])))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-false (col-matrix? a)))
|
||||
|
||||
(check-true (matrix-zero? (make-matrix 4 3 0)))
|
||||
(check-true (matrix-zero? (make-matrix 4 3 0.0)))
|
||||
(check-true (matrix-zero? (make-matrix 4 3 0+0.0i)))
|
||||
(check-false (matrix-zero? (row-matrix [0 0 0 0 1])))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-zero? a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Matrix multiplication
|
||||
;; Accessors
|
||||
|
||||
(check-equal? (matrix* (identity-matrix 2)
|
||||
(matrix [[1 20] [300 4000]]))
|
||||
(matrix [[1 20] [300 4000]]))
|
||||
;; matrix-shape
|
||||
|
||||
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
|
||||
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
|
||||
(check-equal? (let-values ([(m n) (matrix-shape (matrix [[1 2 3] [4 5 6]]))])
|
||||
(list m n))
|
||||
(list 2 3))
|
||||
|
||||
(let ([m0 (random-matrix 4 5 100)]
|
||||
[m1 (random-matrix 5 2 100)]
|
||||
[m2 (random-matrix 2 10 100)])
|
||||
(check-equal? (matrix* (matrix* m0 m1) m2)
|
||||
(matrix* m0 (matrix* m1 m2))))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (let-values ([(m n) (matrix-shape a)])
|
||||
(void)))))
|
||||
|
||||
;; square-matrix-size
|
||||
|
||||
(check-equal? (square-matrix-size (matrix [[1 2] [3 4]]))
|
||||
2)
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1 2]]))))
|
||||
(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1] [2]]))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (square-matrix-size a))))
|
||||
|
||||
;; matrix-num-rows
|
||||
|
||||
(check-equal? (matrix-num-rows (matrix [[1 2 3] [4 5 6]]))
|
||||
2)
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-num-rows a))))
|
||||
|
||||
;; matrix-num-cols
|
||||
|
||||
(check-equal? (matrix-num-cols (matrix [[1 2 3] [4 5 6]]))
|
||||
3)
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-num-cols a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Construction
|
||||
;; Constructors
|
||||
|
||||
;; identity-matrix
|
||||
|
||||
(check-equal? (identity-matrix 1) (matrix [[1]]))
|
||||
(check-equal? (identity-matrix 2) (matrix [[1 0] [0 1]]))
|
||||
(check-equal? (identity-matrix 3) (matrix [[1 0 0] [0 1 0] [0 0 1]]))
|
||||
(check-exn exn:fail:contract? (λ () (identity-matrix 0)))
|
||||
|
||||
;; make-matrix
|
||||
|
||||
(check-equal? (make-matrix 1 1 4) (matrix [[4]]))
|
||||
(check-equal? (make-matrix 2 2 3) (matrix [[3 3] [3 3]]))
|
||||
(check-exn exn:fail:contract? (λ () (make-matrix 1 0 4)))
|
||||
(check-exn exn:fail:contract? (λ () (make-matrix 0 1 4)))
|
||||
|
||||
;; build-matrix
|
||||
|
||||
(check-equal? (build-matrix 4 4 (λ: ([i : Index] [j : Index])
|
||||
(+ i j)))
|
||||
(build-array #(4 4) (λ: ([js : Indexes])
|
||||
(+ (vector-ref js 0) (vector-ref js 1)))))
|
||||
(check-exn exn:fail:contract? (λ () (build-matrix 1 0 (λ: ([i : Index] [j : Index]) (+ i j)))))
|
||||
(check-exn exn:fail:contract? (λ () (build-matrix 0 1 (λ: ([i : Index] [j : Index]) (+ i j)))))
|
||||
|
||||
;; diagonal-matrix
|
||||
|
||||
(check-equal? (diagonal-matrix '(1 2 3 4))
|
||||
(matrix [[1 0 0 0]
|
||||
[0 2 0 0]
|
||||
[0 0 3 0]
|
||||
[0 0 0 4]]))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (diagonal-matrix '())))
|
||||
|
||||
;; block-diagonal-matrix
|
||||
|
||||
(let ([m (random-matrix 4 4 100)])
|
||||
(check-equal? (block-diagonal-matrix (list m))
|
||||
m))
|
||||
|
||||
(check-equal?
|
||||
(block-diagonal-matrix
|
||||
(list
|
||||
(matrix [[1 2] [3 4]])
|
||||
(matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[1] [3] [5]])))
|
||||
(matrix
|
||||
[[1 2 0 0 0 0]
|
||||
[3 4 0 0 0 0]
|
||||
[0 0 1 2 3 0]
|
||||
[0 0 4 5 6 0]
|
||||
[0 0 0 0 0 1]
|
||||
[0 0 0 0 0 3]
|
||||
[0 0 0 0 0 5]]))
|
||||
(list (matrix [[1 2] [3 4]])
|
||||
(matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[1] [3] [5]])
|
||||
(matrix [[2 4 6]])))
|
||||
(matrix [[1 2 0 0 0 0 0 0 0]
|
||||
[3 4 0 0 0 0 0 0 0]
|
||||
[0 0 1 2 3 0 0 0 0]
|
||||
[0 0 4 5 6 0 0 0 0]
|
||||
[0 0 0 0 0 1 0 0 0]
|
||||
[0 0 0 0 0 3 0 0 0]
|
||||
[0 0 0 0 0 5 0 0 0]
|
||||
[0 0 0 0 0 0 2 4 6]]))
|
||||
|
||||
(check-equal?
|
||||
(block-diagonal-matrix (map (λ: ([i : Integer]) (matrix [[i]])) '(1 2 3 4)))
|
||||
(diagonal-matrix '(1 2 3 4)))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (block-diagonal-matrix '())))
|
||||
|
||||
;; Vandermonde matrix
|
||||
|
||||
(check-equal? (vandermonde-matrix '(10) 1)
|
||||
(matrix [[1]]))
|
||||
(check-equal? (vandermonde-matrix '(10) 4)
|
||||
(matrix [[1 10 100 1000]]))
|
||||
(check-equal? (vandermonde-matrix '(1 2 3 4) 3)
|
||||
(matrix [[1 1 1] [1 2 4] [1 3 9] [1 4 16]]))
|
||||
(check-exn exn:fail:contract? (λ () (vandermonde-matrix '() 1)))
|
||||
(check-exn exn:fail:contract? (λ () (vandermonde-matrix '(1) 0)))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Flat conversion
|
||||
|
||||
(begin
|
||||
(begin "matrix-types.rkt"
|
||||
(list
|
||||
'matrix?
|
||||
(matrix? (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
|
||||
(list
|
||||
'square-matrix?
|
||||
(square-matrix? (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? ))))
|
||||
(list
|
||||
'square-matrix-size
|
||||
(= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?))))
|
||||
(list
|
||||
'matrix=
|
||||
(matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
|
||||
#;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
|
||||
(list
|
||||
'matrix-shape
|
||||
(let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))])
|
||||
(equal? (list m n) '(2 3)))))
|
||||
|
||||
(begin "matrix-constructors.rkt"
|
||||
(list
|
||||
'identity-matrix
|
||||
(equal? (array->list* (identity-matrix 1)) '[[1]])
|
||||
(equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]])
|
||||
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]]))
|
||||
(list
|
||||
'const-matrix
|
||||
(equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0)))
|
||||
(equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.))))
|
||||
(list
|
||||
'matrix->list
|
||||
(equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
|
||||
(equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
|
||||
(list
|
||||
'matrix->vector
|
||||
(equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4))))
|
||||
'#(#(1 2) #(3 4)))
|
||||
(equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.))))
|
||||
'#(#(1. 2.) #(3. 4.))))
|
||||
(list
|
||||
'matrix-row
|
||||
(equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]])))
|
||||
(list
|
||||
'matrix-col
|
||||
(equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]]))
|
||||
(equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]]))
|
||||
(equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]])))
|
||||
(list
|
||||
'submatrix
|
||||
(equal? (submatrix (identity-matrix 3)
|
||||
(in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]]))
|
||||
(equal? (submatrix (identity-matrix 3)
|
||||
(in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]]))))
|
||||
|
||||
(begin
|
||||
"matrix-pointwise.rkt"
|
||||
(let ()
|
||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
||||
(define ~A (list*->matrix '[[-1 -2] [-3 -4]]))
|
||||
(define B (list*->matrix '[[5 6] [7 8]]))
|
||||
(define A+B (list*->matrix '[[6 8] [10 12]]))
|
||||
(define A-B (list*->matrix '[[-4 -4] [-4 -4]]))
|
||||
(list 'matrix+ (equal? (matrix+ A B) A+B))
|
||||
(list 'matrix-
|
||||
(equal? (matrix- A B) A-B)
|
||||
(equal? (matrix- A) ~A))))
|
||||
|
||||
(begin
|
||||
"matrix-expt.rkt"
|
||||
(let ()
|
||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
||||
(list
|
||||
'matrix-expt
|
||||
(equal? (matrix-expt A 0) (identity-matrix 2))
|
||||
(equal? (matrix-expt A 1) A)
|
||||
(equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]]))
|
||||
(equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]]))
|
||||
(equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]])))))
|
||||
(check-equal? (list->matrix 1 3 '(1 2 3)) (row-matrix [1 2 3]))
|
||||
(check-equal? (list->matrix 3 1 '(1 2 3)) (col-matrix [1 2 3]))
|
||||
(check-exn exn:fail:contract? (λ () (list->matrix 0 1 '())))
|
||||
(check-exn exn:fail:contract? (λ () (list->matrix 1 0 '())))
|
||||
(check-exn exn:fail:contract? (λ () (list->matrix 1 1 '(1 2))))
|
||||
|
||||
(check-equal? (vector->matrix 1 3 #(1 2 3)) (row-matrix [1 2 3]))
|
||||
(check-equal? (vector->matrix 3 1 #(1 2 3)) (col-matrix [1 2 3]))
|
||||
(check-exn exn:fail:contract? (λ () (vector->matrix 0 1 #())))
|
||||
(check-exn exn:fail:contract? (λ () (vector->matrix 1 0 #())))
|
||||
(check-exn exn:fail:contract? (λ () (vector->matrix 1 1 #(1 2))))
|
||||
|
||||
(check-equal? (->row-matrix '(1 2 3)) (row-matrix [1 2 3]))
|
||||
(check-equal? (->row-matrix #(1 2 3)) (row-matrix [1 2 3]))
|
||||
(check-equal? (->row-matrix (row-matrix [1 2 3])) (row-matrix [1 2 3]))
|
||||
(check-equal? (->row-matrix (col-matrix [1 2 3])) (row-matrix [1 2 3]))
|
||||
(check-equal? (->row-matrix (make-array #() 1)) (row-matrix [1]))
|
||||
(check-equal? (->row-matrix (make-array #(3) 1)) (row-matrix [1 1 1]))
|
||||
(check-equal? (->row-matrix (make-array #(1 3 1) 1)) (row-matrix [1 1 1]))
|
||||
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(2 3 1) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(1 3 2) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(0 3) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(3 0) 1))))
|
||||
|
||||
(check-equal? (->col-matrix '(1 2 3)) (col-matrix [1 2 3]))
|
||||
(check-equal? (->col-matrix #(1 2 3)) (col-matrix [1 2 3]))
|
||||
(check-equal? (->col-matrix (col-matrix [1 2 3])) (col-matrix [1 2 3]))
|
||||
(check-equal? (->col-matrix (row-matrix [1 2 3])) (col-matrix [1 2 3]))
|
||||
(check-equal? (->col-matrix (make-array #() 1)) (col-matrix [1]))
|
||||
(check-equal? (->col-matrix (make-array #(3) 1)) (col-matrix [1 1 1]))
|
||||
(check-equal? (->col-matrix (make-array #(1 3 1) 1)) (col-matrix [1 1 1]))
|
||||
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(2 3 1) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(1 3 2) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(0 3) 1))))
|
||||
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(3 0) 1))))
|
||||
|
||||
(check-equal? (matrix->list (matrix [[1 2 3] [4 5 6]])) '(1 2 3 4 5 6))
|
||||
(check-equal? (matrix->list (row-matrix [1 2 3])) '(1 2 3))
|
||||
(check-equal? (matrix->list (col-matrix [1 2 3])) '(1 2 3))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix->list a))))
|
||||
|
||||
(check-equal? (matrix->vector (matrix [[1 2 3] [4 5 6]])) #(1 2 3 4 5 6))
|
||||
(check-equal? (matrix->vector (row-matrix [1 2 3])) #(1 2 3))
|
||||
(check-equal? (matrix->vector (col-matrix [1 2 3])) #(1 2 3))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix->vector a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Nested conversion
|
||||
|
||||
(check-equal? (list*->matrix '((1 2 3) (4 5 6))) (matrix [[1 2 3] [4 5 6]]))
|
||||
(check-exn exn:fail:contract? (λ () (list*->matrix '((1 2 3) (4 5)))))
|
||||
(check-exn exn:fail:contract? (λ () (list*->matrix '(() () ()))))
|
||||
(check-exn exn:fail:contract? (λ () (list*->matrix '())))
|
||||
|
||||
(check-equal? ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5 6))) (matrix [[1 2 3] [4 5 6]]))
|
||||
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5)))))
|
||||
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#() #() #()))))
|
||||
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #())))
|
||||
|
||||
(check-equal? (matrix->list* (matrix [[1 2 3] [4 5 6]])) '((1 2 3) (4 5 6)))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix->list* a))))
|
||||
|
||||
(check-equal? (matrix->vector* (matrix [[1 2 3] [4 5 6]])) #(#(1 2 3) #(4 5 6)))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix->vector* a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Equality
|
||||
|
||||
(check-true (matrix= (matrix [[1 2 3]
|
||||
[4 5 6]])
|
||||
(matrix [[1.0 2.0 3.0]
|
||||
[4.0 5.0 6.0]])))
|
||||
|
||||
(check-true (matrix= (matrix [[1 2 3]
|
||||
[4 5 6]])
|
||||
(matrix [[1.0 2.0 3.0]
|
||||
[4.0 5.0 6.0]])
|
||||
(matrix [[1.0+0.0i 2.0+0.0i 3.0+0.0i]
|
||||
[4.0+0.0i 5.0+0.0i 6.0+0.0i]])))
|
||||
|
||||
(check-false (matrix= (matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[1 2 3] [4 5 7]])))
|
||||
|
||||
(check-false (matrix= (matrix [[0 2 3] [4 5 6]])
|
||||
(matrix [[1 2 3] [4 5 7]])))
|
||||
|
||||
(check-false (matrix= (matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[1 4] [2 5] [3 6]])))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix= a (matrix [[1]]))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) a)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) (matrix [[1]]) a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Pointwise operations
|
||||
|
||||
(define-syntax-rule (test-matrix-map (matrix-map ...) (array-map ...))
|
||||
(begin
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-map ... a)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-map ... (matrix [[1]]) a))))
|
||||
|
||||
(for*: ([m '(2 3 4)]
|
||||
[n '(2 3 4)])
|
||||
(define a0 (random-matrix m n))
|
||||
(define a1 (random-matrix m n))
|
||||
(define a2 (random-matrix m n))
|
||||
(check-equal? (matrix-map ... a0)
|
||||
(array-map ... a0))
|
||||
(check-equal? (matrix-map ... a0 a1)
|
||||
(array-map ... a0 a1))
|
||||
(check-equal? (matrix-map ... a0 a1 a2)
|
||||
(array-map ... a0 a1 a2))
|
||||
;; Don't know why this (void) is necessary, but TR complains without it
|
||||
(void))))
|
||||
|
||||
(test-matrix-map (matrix-map -) (array-map -))
|
||||
(test-matrix-map ((values matrix-map) -) (array-map -))
|
||||
|
||||
(test-matrix-map (matrix+) (array+))
|
||||
(test-matrix-map ((values matrix+)) (array+))
|
||||
|
||||
(test-matrix-map (matrix-) (array-))
|
||||
(test-matrix-map ((values matrix-)) (array-))
|
||||
|
||||
(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]])))
|
||||
(matrix [[1 2 3] [4 5 6]]))
|
||||
(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[0 1 2] [3 4 5]])))
|
||||
(matrix+ (matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[0 1 2] [3 4 5]])))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-sum '())))
|
||||
|
||||
(check-equal? (matrix-scale (matrix [[1 2 3] [4 5 6]]) 10)
|
||||
(matrix [[10 20 30] [40 50 60]]))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-scale a 0))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Multiplication
|
||||
|
||||
(define-syntax-rule (test-matrix* matrix*)
|
||||
(begin
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix* a (matrix [[1]])))))
|
||||
|
||||
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
|
||||
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
|
||||
|
||||
(check-equal? (matrix* (row-matrix [1 2 3 4])
|
||||
(col-matrix [1 2 3 4]))
|
||||
(matrix [[30]]))
|
||||
|
||||
(check-equal? (matrix* (col-matrix [1 2 3 4])
|
||||
(row-matrix [1 2 3 4]))
|
||||
(matrix [[1 2 3 4]
|
||||
[2 4 6 8]
|
||||
[3 6 9 12]
|
||||
[4 8 12 16]]))
|
||||
|
||||
(check-equal? (matrix* (matrix [[3]]) (matrix [[7]]))
|
||||
(matrix [[21]]))
|
||||
|
||||
;; Left/right identity
|
||||
(let ([m (random-matrix 2 2)])
|
||||
(check-equal? (matrix* (identity-matrix 2) m)
|
||||
m)
|
||||
(check-equal? (matrix* m (identity-matrix 2))
|
||||
m))
|
||||
|
||||
;; Shape
|
||||
(let ([m0 (random-matrix 4 5)]
|
||||
[m1 (random-matrix 5 2)]
|
||||
[m2 (random-matrix 2 10)])
|
||||
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1))])
|
||||
(list m n))
|
||||
(list 4 2))
|
||||
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m1 m2))])
|
||||
(list m n))
|
||||
(list 5 10))
|
||||
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1 m2))])
|
||||
(list m n))
|
||||
(list 4 10)))
|
||||
|
||||
(check-exn exn:fail? (λ () (matrix* (random-matrix 1 2) (random-matrix 3 2))))
|
||||
|
||||
;; Associativity
|
||||
(let ([m0 (random-matrix 4 5)]
|
||||
[m1 (random-matrix 5 2)]
|
||||
[m2 (random-matrix 2 10)])
|
||||
(check-equal? (matrix* m0 m1 m2)
|
||||
(matrix* (matrix* m0 m1) m2))
|
||||
(check-equal? (matrix* (matrix* m0 m1) m2)
|
||||
(matrix* m0 (matrix* m1 m2))))
|
||||
))
|
||||
|
||||
(test-matrix* matrix*)
|
||||
;; `matrix*' is an inlining macro, so we need to check the function version as well
|
||||
(test-matrix* (values matrix*))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Exponentiation
|
||||
|
||||
(let ([A (matrix [[1 2] [3 4]])])
|
||||
(check-equal? (matrix-expt A 0) (identity-matrix 2))
|
||||
(check-equal? (matrix-expt A 1) A)
|
||||
(check-equal? (matrix-expt A 2) (matrix [[7 10] [15 22]]))
|
||||
(check-equal? (matrix-expt A 3) (matrix [[37 54] [81 118]]))
|
||||
(check-equal? (matrix-expt A 8) (matrix [[165751 241570] [362355 528106]])))
|
||||
|
||||
(check-equal? (matrix-expt (matrix [[2]]) 10) (matrix [[(expt 2 10)]]))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (matrix-expt (row-matrix [1 2 3]) 0)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-expt (col-matrix [1 2 3]) 0)))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-expt a 0))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Comprehensions
|
||||
|
||||
;; for/matrix and friends are defined in terms of for/array and friends, so we only need to test that
|
||||
;; it works for one case each, and that they properly raise exceptions when given zero-length axes
|
||||
|
||||
(check-equal?
|
||||
(for/matrix 2 2 ([i (in-range 4)]) i)
|
||||
(matrix [[0 1] [2 3]]))
|
||||
|
||||
#;; TR can't type this, but it's defined using exactly the same wrapper as `for/matrix'
|
||||
(check-equal?
|
||||
(for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
|
||||
(matrix [[0 1] [1 2]]))
|
||||
|
||||
(check-equal?
|
||||
(for/matrix: 2 2 ([i (in-range 4)]) i)
|
||||
(matrix [[0 1] [2 3]]))
|
||||
|
||||
(check-equal?
|
||||
(for*/matrix: 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
|
||||
(matrix [[0 1] [1 2]]))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (for/matrix: 2 0 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for/matrix: 0 2 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for*/matrix: 2 0 () 0)))
|
||||
(check-exn exn:fail:contract? (λ () (for*/matrix: 0 2 () 0)))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Extraction
|
||||
|
||||
;; matrix-ref
|
||||
|
||||
(let ([a (matrix [[10 11] [12 13]])])
|
||||
(check-equal? (matrix-ref a 0 0) 10)
|
||||
(check-equal? (matrix-ref a 0 1) 11)
|
||||
(check-equal? (matrix-ref a 1 0) 12)
|
||||
(check-equal? (matrix-ref a 1 1) 13)
|
||||
(check-exn exn:fail? (λ () (matrix-ref a 2 0)))
|
||||
(check-exn exn:fail? (λ () (matrix-ref a 0 2)))
|
||||
(check-exn exn:fail? (λ () (matrix-ref a -1 0)))
|
||||
(check-exn exn:fail? (λ () (matrix-ref a 0 -1))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-ref a 0 0))))
|
||||
|
||||
;; matrix-diagonal
|
||||
|
||||
(check-equal? (matrix-diagonal (diagonal-matrix '(1 2 3 4)))
|
||||
(array #[1 2 3 4]))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-diagonal a))))
|
||||
|
||||
;; submatrix
|
||||
|
||||
(check-equal? (submatrix (identity-matrix 8) (:: 2 4) (:: 2 4))
|
||||
(identity-matrix 2))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (submatrix a '(0) '(0)))))
|
||||
|
||||
;; matrix-row
|
||||
|
||||
(let ([a (matrix [[1 2 3] [4 5 6]])])
|
||||
(check-equal? (matrix-row a 0) (row-matrix [1 2 3]))
|
||||
(check-equal? (matrix-row a 1) (row-matrix [4 5 6]))
|
||||
(check-exn exn:fail? (λ () (matrix-row a -1)))
|
||||
(check-exn exn:fail? (λ () (matrix-row a 2))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-row a 0))))
|
||||
|
||||
;; matrix-col
|
||||
|
||||
(let ([a (matrix [[1 2 3] [4 5 6]])])
|
||||
(check-equal? (matrix-col a 0) (col-matrix [1 4]))
|
||||
(check-equal? (matrix-col a 1) (col-matrix [2 5]))
|
||||
(check-equal? (matrix-col a 2) (col-matrix [3 6]))
|
||||
(check-exn exn:fail? (λ () (matrix-col a -1)))
|
||||
(check-exn exn:fail? (λ () (matrix-col a 3))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-col a 0))))
|
||||
|
||||
;; matrix-rows
|
||||
|
||||
(check-equal? (matrix-rows (matrix [[1 2 3] [4 5 6]]))
|
||||
(list (row-matrix [1 2 3])
|
||||
(row-matrix [4 5 6])))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-rows a))))
|
||||
|
||||
;; matrix-cols
|
||||
|
||||
(check-equal? (matrix-cols (matrix [[1 2 3] [4 5 6]]))
|
||||
(list (col-matrix [1 4])
|
||||
(col-matrix [2 5])
|
||||
(col-matrix [3 6])))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-cols a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Embiggenment (it's a perfectly cromulent word)
|
||||
|
||||
;; matrix-augment
|
||||
|
||||
(let ([a (random-matrix 3 5)])
|
||||
(check-equal? (matrix-augment (list a)) a)
|
||||
(check-equal? (matrix-augment (matrix-cols a)) a))
|
||||
|
||||
(check-equal? (matrix-augment (list (col-matrix [1 2 3]) (col-matrix [4 5 6])))
|
||||
(matrix [[1 4] [2 5] [3 6]]))
|
||||
|
||||
(check-equal? (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3 6])))
|
||||
(matrix [[1 2 3] [4 5 6]]))
|
||||
|
||||
(check-exn exn:fail? (λ () (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3])))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-augment '())))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-augment (list a))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-augment (list (matrix [[1]]) a)))))
|
||||
|
||||
;; matrix-stack
|
||||
|
||||
(let ([a (random-matrix 5 3)])
|
||||
(check-equal? (matrix-stack (list a)) a)
|
||||
(check-equal? (matrix-stack (matrix-rows a)) a))
|
||||
|
||||
(check-equal? (matrix-stack (list (row-matrix [1 2 3]) (row-matrix [4 5 6])))
|
||||
(matrix [[1 2 3] [4 5 6]]))
|
||||
|
||||
(check-equal? (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8 9])))
|
||||
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||
|
||||
(check-exn exn:fail? (λ () (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8])))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-stack '())))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-stack (list a))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-stack (list (matrix [[1]]) a)))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Inner product space
|
||||
|
||||
;; matrix-norm
|
||||
|
||||
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]))
|
||||
(sqrt (+ (* 1 1) (* 2 2) (* 3 3) (* 4 4) (* 5 5) (* 6 6))))
|
||||
|
||||
;; Default norm is Frobenius norm
|
||||
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]))
|
||||
(matrix-norm (matrix [[1 2 3] [4 5 6]]) 2))
|
||||
|
||||
;; This shouldn't overflow (so we check against `flhypot', which also shouldn't overflow)
|
||||
(check-equal? (matrix-norm (matrix [[1e200 1e199]]))
|
||||
(flhypot 1e200 1e199))
|
||||
|
||||
;; Taxicab (Manhattan) norm
|
||||
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1)
|
||||
(+ 1 2 3 4 5 6))
|
||||
|
||||
;; Infinity (maximum) norm
|
||||
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) +inf.0)
|
||||
(max 1 2 3 4 5 6))
|
||||
|
||||
;; The actual norm is indistinguishable from floating-point 6
|
||||
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1000)
|
||||
6.0)
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-norm a 1)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-norm a)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-norm a 5)))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-norm a +inf.0))))
|
||||
|
||||
(check-equal? (matrix-norm (row-matrix [1+1i]))
|
||||
(sqrt 2))
|
||||
|
||||
(check-equal? (matrix-norm (row-matrix [1+1i 2+2i 3+3i]))
|
||||
(matrix-norm (row-matrix [(magnitude 1+1i) (magnitude 2+2i) (magnitude 3+3i)])))
|
||||
|
||||
;; matrix-dot (induces the Frobenius norm)
|
||||
|
||||
(check-equal? (matrix-dot (matrix [[1 -2 3] [-4 5 -6]])
|
||||
(matrix [[-1 2 -3] [4 -5 6]]))
|
||||
(+ (* 1 -1) (* -2 2) (* 3 -3) (* -4 4) (* 5 -5) (* -6 6)))
|
||||
|
||||
(check-equal? (matrix-dot (row-matrix [1 2 3])
|
||||
(row-matrix [0+4i 0-5i 0+6i]))
|
||||
(+ (* 1 0-4i) (* 2 0+5i) (* 3 0-6i)))
|
||||
|
||||
(check-exn exn:fail? (λ () (matrix-dot (random-matrix 1 3) (random-matrix 3 1))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-dot a (matrix [[1]]))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-dot (matrix [[1]]) a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Simple operators
|
||||
|
||||
;; matrix-transpose
|
||||
|
||||
(check-equal? (matrix-transpose (matrix [[1 2 3] [4 5 6]]))
|
||||
(matrix [[1 4] [2 5] [3 6]]))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-transpose a))))
|
||||
|
||||
;; matrix-conjugate
|
||||
|
||||
(check-equal? (matrix-conjugate (matrix [[1+i 2-i] [3+i 4-i]]))
|
||||
(matrix [[1-i 2+i] [3-i 4+i]]))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-conjugate a))))
|
||||
|
||||
;; matrix-hermitian
|
||||
|
||||
(let ([a (array-make-rectangular (random-matrix 5 6)
|
||||
(random-matrix 5 6))])
|
||||
(check-equal? (matrix-hermitian a)
|
||||
(matrix-conjugate (matrix-transpose a)))
|
||||
(check-equal? (matrix-hermitian a)
|
||||
(matrix-transpose (matrix-conjugate a))))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-hermitian a))))
|
||||
|
||||
;; matrix-trace
|
||||
|
||||
(check-equal? (matrix-trace (matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||
(+ 1 5 9))
|
||||
|
||||
(check-exn exn:fail:contract? (λ () (matrix-trace (row-matrix [1 2 3]))))
|
||||
(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3]))))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Tests not yet converted to rackunit
|
||||
|
||||
(begin
|
||||
|
||||
(begin
|
||||
"matrix-operations.rkt"
|
||||
(list 'vandermonde-matrix
|
||||
(equal? (vandermonde-matrix '(1 2 3) 5)
|
||||
(list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
|
||||
#;
|
||||
(list 'in-column
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)])
|
||||
x)
|
||||
'(1 3))
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)])
|
||||
x)
|
||||
'(2 4))
|
||||
(equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x)
|
||||
'(5 2 3)))
|
||||
#;
|
||||
(list 'in-row
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)])
|
||||
x)
|
||||
'(1 2))
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)])
|
||||
x)
|
||||
'(3 4)))
|
||||
(list 'for/matrix:
|
||||
(equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i)
|
||||
(matrix [[0 1 2 3] [4 5 6 7]]))
|
||||
(equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i)
|
||||
(matrix [[0 2 4 6] [1 3 5 7]]))
|
||||
(equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i)
|
||||
(matrix [[10 11 12] [13 14 15] [16 17 18]])))
|
||||
(list 'for*/matrix:
|
||||
(equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j))
|
||||
(matrix [[0 1 2] [10 11 12] [20 21 22]])))
|
||||
(list 'matrix-block-diagonal
|
||||
(equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]])))
|
||||
(list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
|
||||
(list 'matrix-augment
|
||||
(equal? (matrix-augment (list (col-matrix [1 2 3])
|
||||
(col-matrix [4 5 6])
|
||||
(col-matrix [7 8 9])))
|
||||
(matrix [[1 4 7] [2 5 8] [3 6 9]])))
|
||||
(list 'matrix-stack
|
||||
(equal? (matrix-stack (list (col-matrix [1 2 3])
|
||||
(col-matrix [4 5 6])
|
||||
(col-matrix [7 8 9])))
|
||||
(col-matrix [1 2 3 4 5 6 7 8 9])))
|
||||
#;
|
||||
(list 'column-dimension
|
||||
(= (column-dimension #(1 2 3)) 3)
|
||||
|
@ -206,8 +687,6 @@
|
|||
(+ (* 1 4) (* 2 5) (* 3 6)))
|
||||
(= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i]))
|
||||
25)))
|
||||
(list 'matrix-trace
|
||||
(equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5))
|
||||
(let ([matrix: vector->matrix])
|
||||
(list 'column-norm
|
||||
(= (column-norm (col-matrix [2 4])) (sqrt 20))))
|
||||
|
@ -286,15 +765,6 @@
|
|||
[9 10 -11 12]
|
||||
[13 14 15 16]]))
|
||||
5280))
|
||||
(list 'matrix-scale
|
||||
(equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2)
|
||||
(list*->matrix '[[2 4] [6 8]])))
|
||||
(list 'matrix-transpose
|
||||
(equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]]))
|
||||
(list*->matrix '[[1 3] [2 4]])))
|
||||
(list 'matrix-hermitian
|
||||
(equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]]))
|
||||
(list*->matrix '[[1-i 3-i] [2+i 4+i]])))
|
||||
(let ()
|
||||
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
||||
(define (gauss-eliminate M u? p?)
|
||||
|
@ -366,75 +836,60 @@
|
|||
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
|
||||
#;(let ()
|
||||
(define-values (c1 n1)
|
||||
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
|
||||
(define-values (c2 n2)
|
||||
(matrix-column+null-space (list*->matrix '[[1 2] [2 4]])))
|
||||
(define-values (c3 n3)
|
||||
(matrix-column+null-space (list*atrix '[[1 2] [2 5]])))
|
||||
(list
|
||||
'matrix-column+null-space
|
||||
(equal? c1 '())
|
||||
(equal? n1 (list (list*->matrix '[[0] [0]])
|
||||
(list*->matrix '[[0] [0]])))
|
||||
(equal? c2 (list (list*->matrix '[[1] [2]])))
|
||||
;(equal? n2 '([0 0]))
|
||||
(equal? c3 (list (list*->matrix '[[1] [2]])
|
||||
(list*->matrix '[[2] [5]])))
|
||||
(equal? n3 '()))))
|
||||
#;
|
||||
(let ()
|
||||
(define-values (c1 n1)
|
||||
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
|
||||
(define-values (c2 n2)
|
||||
(matrix-column+null-space (list*->matrix '[[1 2] [2 4]])))
|
||||
(define-values (c3 n3)
|
||||
(matrix-column+null-space (list*atrix '[[1 2] [2 5]])))
|
||||
(list
|
||||
'matrix-column+null-space
|
||||
(equal? c1 '())
|
||||
(equal? n1 (list (list*->matrix '[[0] [0]])
|
||||
(list*->matrix '[[0] [0]])))
|
||||
(equal? c2 (list (list*->matrix '[[1] [2]])))
|
||||
;(equal? n2 '([0 0]))
|
||||
(equal? c3 (list (list*->matrix '[[1] [2]])
|
||||
(list*->matrix '[[2] [5]])))
|
||||
(equal? n3 '()))))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#;(begin
|
||||
"matrix-multiply.rkt"
|
||||
(list 'matrix*
|
||||
(let ()
|
||||
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]]))
|
||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))
|
||||
(let ()
|
||||
(define-values (A B AB) (values '[[1 2] [3 4]]
|
||||
'[[5 6 7] [8 9 10]]
|
||||
'[[21 24 27] [47 54 61]]))
|
||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))))
|
||||
#;(begin
|
||||
"matrix-2d.rkt"
|
||||
(let ()
|
||||
(define e1 (matrix-transpose (vector->matrix #(#( 1 0)))))
|
||||
(define e2 (matrix-transpose (vector->matrix #(#( 0 1)))))
|
||||
(define -e1 (matrix-transpose (vector->matrix #(#(-1 0)))))
|
||||
(define -e2 (matrix-transpose (vector->matrix #(#( 0 -1)))))
|
||||
(define O (matrix-transpose (vector->matrix #(#( 0 0)))))
|
||||
(define 2*e1 (matrix-scale e1 2))
|
||||
(define 4*e1 (matrix-scale e1 4))
|
||||
(define 3*e2 (matrix-scale e2 3))
|
||||
(define 4*e2 (matrix-scale e2 4))
|
||||
(begin
|
||||
(list 'matrix-2d-rotation
|
||||
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0)
|
||||
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0))
|
||||
(list
|
||||
'matrix-2d-scaling
|
||||
(equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2)))
|
||||
(list
|
||||
'matrix-2d-shear-x
|
||||
(equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2)))
|
||||
(list
|
||||
'matrix-2d-shear-y
|
||||
(equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2)))
|
||||
(list
|
||||
'matrix-2d-reflection
|
||||
(equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1)
|
||||
(equal? (matrix* (matrix-2d-reflection 0 1) e2) e2)
|
||||
(equal? (matrix* (matrix-2d-reflection 1 0) e1) e1)
|
||||
(equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2))
|
||||
(list
|
||||
'matrix-2d-orthogonal-projection
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2))))))
|
||||
#;
|
||||
(begin
|
||||
"matrix-2d.rkt"
|
||||
(let ()
|
||||
(define e1 (matrix-transpose (vector->matrix #(#( 1 0)))))
|
||||
(define e2 (matrix-transpose (vector->matrix #(#( 0 1)))))
|
||||
(define -e1 (matrix-transpose (vector->matrix #(#(-1 0)))))
|
||||
(define -e2 (matrix-transpose (vector->matrix #(#( 0 -1)))))
|
||||
(define O (matrix-transpose (vector->matrix #(#( 0 0)))))
|
||||
(define 2*e1 (matrix-scale e1 2))
|
||||
(define 4*e1 (matrix-scale e1 4))
|
||||
(define 3*e2 (matrix-scale e2 3))
|
||||
(define 4*e2 (matrix-scale e2 4))
|
||||
(begin
|
||||
(list 'matrix-2d-rotation
|
||||
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0)
|
||||
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e2) -e1)) epsilon.0))
|
||||
(list
|
||||
'matrix-2d-scaling
|
||||
(equal? (matrix* (matrix-2d-scaling 2 3) (matrix+ e1 e2)) (matrix+ 2*e1 3*e2)))
|
||||
(list
|
||||
'matrix-2d-shear-x
|
||||
(equal? (matrix* (matrix-2d-shear-x 3) (matrix+ e1 e2)) (matrix+ 4*e1 e2)))
|
||||
(list
|
||||
'matrix-2d-shear-y
|
||||
(equal? (matrix* (matrix-2d-shear-y 3) (matrix+ e1 e2)) (matrix+ e1 4*e2)))
|
||||
(list
|
||||
'matrix-2d-reflection
|
||||
(equal? (matrix* (matrix-2d-reflection 0 1) e1) -e1)
|
||||
(equal? (matrix* (matrix-2d-reflection 0 1) e2) e2)
|
||||
(equal? (matrix* (matrix-2d-reflection 1 0) e1) e1)
|
||||
(equal? (matrix* (matrix-2d-reflection 1 0) e2) -e2))
|
||||
(list
|
||||
'matrix-2d-orthogonal-projection
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e1) e1)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2))))))
|
||||
|
|
Loading…
Reference in New Issue
Block a user