Moar `math/matrix' review/refactoring
* Split "matrix-constructors.rkt" into three parts: * "matrix-constructors.rkt" * "matrix-conversion.rkt" * "matrix-syntax.rkt" * Made `matrix-map' automatically inline (it's dirt simple) * Renamed a few things, changed some type signatures * Fixed error in `matrix-dot' caught by testing (it was broadcasting) * Rewrote matrix comprehensions in terms of array comprehensions * Removed `in-column' and `in-row' (can use `in-array', `matrix-col' and `matrix-row') * Tons of new rackunit tests: only "matrix-2d.rkt" and "matrix-operations.rkt" are left (though the latter is large)
This commit is contained in:
parent
155ec7dc41
commit
8d5a069d41
|
@ -2,21 +2,24 @@
|
||||||
|
|
||||||
(require "private/matrix/matrix-arithmetic.rkt"
|
(require "private/matrix/matrix-arithmetic.rkt"
|
||||||
"private/matrix/matrix-constructors.rkt"
|
"private/matrix/matrix-constructors.rkt"
|
||||||
|
"private/matrix/matrix-conversion.rkt"
|
||||||
|
"private/matrix/matrix-syntax.rkt"
|
||||||
"private/matrix/matrix-basic.rkt"
|
"private/matrix/matrix-basic.rkt"
|
||||||
"private/matrix/matrix-operations.rkt"
|
"private/matrix/matrix-operations.rkt"
|
||||||
"private/matrix/matrix-comprehension.rkt"
|
"private/matrix/matrix-comprehension.rkt"
|
||||||
"private/matrix/matrix-sequences.rkt"
|
|
||||||
"private/matrix/matrix-expt.rkt"
|
"private/matrix/matrix-expt.rkt"
|
||||||
"private/matrix/matrix-types.rkt"
|
"private/matrix/matrix-types.rkt"
|
||||||
|
"private/matrix/matrix-2d.rkt"
|
||||||
"private/matrix/utils.rkt")
|
"private/matrix/utils.rkt")
|
||||||
|
|
||||||
(provide (all-from-out
|
(provide (all-from-out
|
||||||
"private/matrix/matrix-arithmetic.rkt"
|
"private/matrix/matrix-arithmetic.rkt"
|
||||||
"private/matrix/matrix-constructors.rkt"
|
"private/matrix/matrix-constructors.rkt"
|
||||||
|
"private/matrix/matrix-conversion.rkt"
|
||||||
|
"private/matrix/matrix-syntax.rkt"
|
||||||
"private/matrix/matrix-basic.rkt"
|
"private/matrix/matrix-basic.rkt"
|
||||||
"private/matrix/matrix-operations.rkt"
|
"private/matrix/matrix-operations.rkt"
|
||||||
"private/matrix/matrix-comprehension.rkt"
|
"private/matrix/matrix-comprehension.rkt"
|
||||||
"private/matrix/matrix-sequences.rkt"
|
|
||||||
"private/matrix/matrix-expt.rkt"
|
"private/matrix/matrix-expt.rkt"
|
||||||
"private/matrix/matrix-types.rkt")
|
"private/matrix/matrix-types.rkt"
|
||||||
matrix?)
|
"private/matrix/matrix-2d.rkt"))
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
(require math/array
|
(require math/array
|
||||||
"matrix-types.rkt"
|
"matrix-types.rkt"
|
||||||
"matrix-constructors.rkt")
|
"matrix-syntax.rkt")
|
||||||
|
|
||||||
(provide matrix-2d-rotation
|
(provide matrix-2d-rotation
|
||||||
matrix-2d-scaling
|
matrix-2d-scaling
|
||||||
|
|
|
@ -19,21 +19,18 @@
|
||||||
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
|
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
|
||||||
[_ (syntax/loc inner-stx typed:fun)])))]))
|
[_ (syntax/loc inner-stx typed:fun)])))]))
|
||||||
|
|
||||||
(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*)
|
(define/inline-macro matrix* (a as ...) inline-matrix* typed:matrix*)
|
||||||
(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+)
|
(define/inline-macro matrix+ (a as ...) inline-matrix+ typed:matrix+)
|
||||||
(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-)
|
(define/inline-macro matrix- (a as ...) inline-matrix- typed:matrix-)
|
||||||
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
|
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
|
||||||
|
|
||||||
|
(define/inline-macro do-matrix-map (f a as ...) inline-matrix-map matrix-map)
|
||||||
|
|
||||||
(provide
|
(provide
|
||||||
;; Equality
|
(rename-out [do-matrix-map matrix-map]
|
||||||
(rename-out [typed:matrix= matrix=])
|
[typed:matrix= matrix=]
|
||||||
;; Mapping
|
[typed:matrix-sum matrix-sum])
|
||||||
inline-matrix-map
|
|
||||||
matrix-map
|
|
||||||
;; Multiplication
|
|
||||||
matrix*
|
matrix*
|
||||||
;; Pointwise operators
|
|
||||||
matrix+
|
matrix+
|
||||||
matrix-
|
matrix-
|
||||||
matrix-scale
|
matrix-scale)
|
||||||
(rename-out [typed:matrix-sum matrix-sum]))
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
math/array
|
math/array
|
||||||
math/flonum
|
math/flonum
|
||||||
"matrix-types.rkt"
|
"matrix-types.rkt"
|
||||||
|
"matrix-arithmetic.rkt"
|
||||||
"utils.rkt"
|
"utils.rkt"
|
||||||
"../unsafe.rkt")
|
"../unsafe.rkt")
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@
|
||||||
matrix-rows
|
matrix-rows
|
||||||
matrix-cols
|
matrix-cols
|
||||||
;; Predicates
|
;; Predicates
|
||||||
zero-matrix?
|
matrix-zero?
|
||||||
;; Embiggenment
|
;; Embiggenment
|
||||||
matrix-augment
|
matrix-augment
|
||||||
matrix-stack
|
matrix-stack
|
||||||
|
@ -73,7 +74,7 @@
|
||||||
(unsafe-vector-set! ij 0 0)
|
(unsafe-vector-set! ij 0 0)
|
||||||
res))]))
|
res))]))
|
||||||
|
|
||||||
(: matrix-col (All (A) (Matrix A) Index -> (Matrix A)))
|
(: matrix-col (All (A) (Matrix A) Integer -> (Matrix A)))
|
||||||
(define (matrix-col a j)
|
(define (matrix-col a j)
|
||||||
(define-values (m n) (matrix-shape a))
|
(define-values (m n) (matrix-shape a))
|
||||||
(cond [(or (j . < . 0) (j . >= . n))
|
(cond [(or (j . < . 0) (j . >= . n))
|
||||||
|
@ -99,9 +100,9 @@
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Predicates
|
;; Predicates
|
||||||
|
|
||||||
(: zero-matrix? ((Array Number) -> Boolean))
|
(: matrix-zero? ((Array Number) -> Boolean))
|
||||||
(define (zero-matrix? a)
|
(define (matrix-zero? a)
|
||||||
(array-all-and (array-map zero? a)))
|
(array-all-and (matrix-map zero? a)))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Embiggenment (this is a perfectly cromulent word)
|
;; Embiggenment (this is a perfectly cromulent word)
|
||||||
|
@ -179,9 +180,14 @@
|
||||||
((Array Number) (Array Number) -> Number)))
|
((Array Number) (Array Number) -> Number)))
|
||||||
;; Computes the Frobenius inner product of two matrices
|
;; Computes the Frobenius inner product of two matrices
|
||||||
(define (matrix-dot a b)
|
(define (matrix-dot a b)
|
||||||
(cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)]
|
(define-values (m n) (matrix-shapes 'matrix-dot a b))
|
||||||
[(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)]
|
(define aproc (unsafe-array-proc a))
|
||||||
[else (array-all-sum (array* a (array-conjugate b)))]))
|
(define bproc (unsafe-array-proc b))
|
||||||
|
(array-all-sum
|
||||||
|
(unsafe-build-array
|
||||||
|
((inst vector Index) m n)
|
||||||
|
(λ: ([js : Indexes])
|
||||||
|
(* (aproc js) (conjugate (bproc js)))))))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Operators
|
;; Operators
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
(require math/array
|
(require math/array
|
||||||
math/base
|
math/base
|
||||||
"matrix-types.rkt"
|
"matrix-types.rkt"
|
||||||
"matrix-constructors.rkt"
|
"matrix-conversion.rkt"
|
||||||
"matrix-arithmetic.rkt"
|
"matrix-arithmetic.rkt"
|
||||||
"../unsafe.rkt")
|
"../unsafe.rkt")
|
||||||
|
|
||||||
|
|
|
@ -1,140 +1,63 @@
|
||||||
#lang racket
|
#lang racket/base
|
||||||
|
|
||||||
(require math/array
|
(require (for-syntax racket/base
|
||||||
typed/racket/base
|
syntax/parse)
|
||||||
"matrix-types.rkt"
|
math/array)
|
||||||
"matrix-constructors.rkt")
|
|
||||||
|
|
||||||
(provide for/matrix
|
(provide for/matrix:
|
||||||
for*/matrix
|
for*/matrix:
|
||||||
for/matrix:
|
for/matrix
|
||||||
for*/matrix:)
|
for*/matrix)
|
||||||
|
|
||||||
;;; COMPREHENSIONS
|
(module typed-defs typed/racket/base
|
||||||
|
(require (for-syntax racket/base
|
||||||
|
syntax/parse)
|
||||||
|
math/array)
|
||||||
|
|
||||||
; (for/matrix m n (clause ...) . defs+exprs)
|
(provide (all-defined-out))
|
||||||
; Return an m x n matrix with elements from the last expr.
|
|
||||||
; The first n values produced becomes the first row.
|
(: ensure-matrix-dims (Symbol Any Any -> (Values Positive-Index Positive-Index)))
|
||||||
; The next n values becomes the second row and so on.
|
(define (ensure-matrix-dims name m n)
|
||||||
; The bindings in clauses run in parallel.
|
(cond [(or (not (index? m)) (zero? m)) (raise-argument-error name "Positive-Index" 0 m n)]
|
||||||
(define-syntax (for/matrix stx)
|
[(or (not (index? n)) (zero? n)) (raise-argument-error name "Positive-Index" 1 m n)]
|
||||||
(syntax-case stx ()
|
[else (values m n)]))
|
||||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
|
||||||
|
(define-syntax (base-for/matrix: stx)
|
||||||
|
(syntax-parse stx #:literals (:)
|
||||||
|
[(_ name:id for/array:id
|
||||||
|
m-expr:expr n-expr:expr
|
||||||
|
(~optional (~seq #:fill fill-expr:expr))
|
||||||
|
(clause ...)
|
||||||
|
(~optional (~seq : A:expr))
|
||||||
|
body:expr ...+)
|
||||||
|
(with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())]
|
||||||
|
[(maybe-type ...) (if (attribute A) #'(: A) #'())])
|
||||||
(syntax/loc stx
|
(syntax/loc stx
|
||||||
(let ([m m-expr] [n n-expr])
|
(let-values ([(m n) (ensure-matrix-dims 'name
|
||||||
(define flat-vector
|
(ann m-expr Integer)
|
||||||
(for/vector #:length (* m n)
|
(ann n-expr Integer))])
|
||||||
(clause ...) . defs+exprs))
|
(for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...) maybe-type ...
|
||||||
(vector->matrix m n flat-vector)))]))
|
body ...))))]))
|
||||||
|
|
||||||
; (for*/matrix m n (clause ...) . defs+exprs)
|
(define-syntax-rule (for/matrix: e ...) (base-for/matrix: for/matrix: for/array: e ...))
|
||||||
; Return an m x n matrix with elements from the last expr.
|
(define-syntax-rule (for*/matrix: e ...) (base-for/matrix: for*/matrix: for*/array: e ...))
|
||||||
; The first n values produced becomes the first row.
|
|
||||||
; The next n values becomes the second row and so on.
|
|
||||||
; The bindings in clauses run nested.
|
|
||||||
; (for*/matrix m n #:column (clause ...) . defs+exprs)
|
|
||||||
; Return an m x n matrix with elements from the last expr.
|
|
||||||
; The first m values produced becomes the first column.
|
|
||||||
; The next m values becomes the second column and so on.
|
|
||||||
; The bindings in clauses run nested.
|
|
||||||
(define-syntax (for*/matrix stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let* ([m m-expr]
|
|
||||||
[n n-expr]
|
|
||||||
[v (make-vector (* m n) 0)]
|
|
||||||
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
|
|
||||||
(for* ([i (in-range m)] [j (in-range n)])
|
|
||||||
(vector-set! v (+ (* i n) j)
|
|
||||||
(vector-ref w (+ (* j m) i))))
|
|
||||||
(vector->matrix m n v)))]
|
|
||||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ([m m-expr] [n n-expr])
|
|
||||||
(vector->matrix
|
|
||||||
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
|
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
(define-syntax (for/column: stx)
|
(require (submod "." typed-defs))
|
||||||
(syntax-case stx ()
|
|
||||||
[(_ : type m-expr (for:-clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ()
|
|
||||||
(define: m : Index m-expr)
|
|
||||||
(define: flat-vector : (Vectorof Number) (make-vector m 0))
|
|
||||||
(for: ([i (in-range m)] for:-clause ...)
|
|
||||||
(define x (let () . defs+exprs))
|
|
||||||
(vector-set! flat-vector i x))
|
|
||||||
(vector->col-matrix flat-vector)))]))
|
|
||||||
|
|
||||||
(define-syntax (for/matrix: stx)
|
(define-syntax (base-for/matrix stx)
|
||||||
(syntax-case stx ()
|
(syntax-parse stx
|
||||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
[(_ name:id for/array:id
|
||||||
|
m-expr:expr n-expr:expr
|
||||||
|
(~optional (~seq #:fill fill-expr:expr))
|
||||||
|
(clause ...)
|
||||||
|
body:expr ...+)
|
||||||
|
(with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())])
|
||||||
(syntax/loc stx
|
(syntax/loc stx
|
||||||
(let ()
|
(let-values ([(m n) (ensure-matrix-dims 'name m-expr n-expr)])
|
||||||
(define: m : Index m-expr)
|
(for/array #:shape (vector m-expr n-expr) maybe-fill ... (clause ...)
|
||||||
(define: n : Index n-expr)
|
body ...))))]))
|
||||||
(define: m*n : Index (assert (* m n) index?))
|
|
||||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
|
||||||
(define: k : Index 0)
|
|
||||||
(for: ([i (in-range m*n)] for:-clause ...)
|
|
||||||
(define x (let () . defs+exprs))
|
|
||||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
|
||||||
(set! k (assert (+ k 1) index?)))
|
|
||||||
(vector->matrix m n v)))]
|
|
||||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ()
|
|
||||||
(define: m : Index m-expr)
|
|
||||||
(define: n : Index n-expr)
|
|
||||||
(define: m*n : Index (assert (* m n) index?))
|
|
||||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
|
||||||
(for: ([i (in-range m*n)] for:-clause ...)
|
|
||||||
(define x (let () . defs+exprs))
|
|
||||||
(vector-set! v i x))
|
|
||||||
(vector->matrix m n v)))]))
|
|
||||||
|
|
||||||
(define-syntax (for*/matrix: stx)
|
(define-syntax-rule (for/matrix e ...) (base-for/matrix for/matrix for/array e ...))
|
||||||
(syntax-case stx ()
|
(define-syntax-rule (for*/matrix e ...) (base-for/matrix for*/matrix for*/array e ...))
|
||||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ()
|
|
||||||
(define: m : Index m-expr)
|
|
||||||
(define: n : Index n-expr)
|
|
||||||
(define: m*n : Index (assert (* m n) index?))
|
|
||||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
|
||||||
(define: k : Index 0)
|
|
||||||
(for*: (for:-clause ...)
|
|
||||||
(define x (let () . defs+exprs))
|
|
||||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
|
||||||
(set! k (assert (+ k 1) index?)))
|
|
||||||
(vector->matrix m n v)))]
|
|
||||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ()
|
|
||||||
(define: m : Index m-expr)
|
|
||||||
(define: n : Index n-expr)
|
|
||||||
(define: m*n : Index (assert (* m n) index?))
|
|
||||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
|
||||||
(define: i : Index 0)
|
|
||||||
(for*: (for:-clause ...)
|
|
||||||
(define x (let () . defs+exprs))
|
|
||||||
(vector-set! v i x)
|
|
||||||
(set! i (assert (+ i 1) index?)))
|
|
||||||
(vector->matrix m n v)))]))
|
|
||||||
#;
|
|
||||||
(module* test #f
|
|
||||||
(require rackunit)
|
|
||||||
; "matrix-sequences.rkt"
|
|
||||||
; These work in racket not in typed racket
|
|
||||||
(check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
|
|
||||||
'[[0 1 2] [1 2 3]])
|
|
||||||
(check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
|
|
||||||
'[[0 2 2] [1 1 3]])
|
|
||||||
(check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i))
|
|
||||||
'[[0 2] [1 3]])
|
|
||||||
(check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i))
|
|
||||||
'[[0 1] [2 3]])
|
|
||||||
(check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
|
|
||||||
'[[6 8 10] [12 14 16]]))
|
|
||||||
|
|
|
@ -1,46 +1,23 @@
|
||||||
#lang racket/base
|
#lang typed/racket/base
|
||||||
|
|
||||||
(provide
|
(require racket/fixnum
|
||||||
;; Constructors
|
racket/list
|
||||||
identity-matrix
|
racket/vector
|
||||||
|
math/array
|
||||||
|
"matrix-types.rkt"
|
||||||
|
"../unsafe.rkt")
|
||||||
|
|
||||||
|
(provide identity-matrix
|
||||||
make-matrix
|
make-matrix
|
||||||
build-matrix
|
build-matrix
|
||||||
diagonal-matrix/zero
|
diagonal-matrix/zero
|
||||||
diagonal-matrix
|
diagonal-matrix
|
||||||
block-diagonal-matrix/zero
|
block-diagonal-matrix/zero
|
||||||
block-diagonal-matrix
|
block-diagonal-matrix
|
||||||
vandermonde-matrix
|
vandermonde-matrix)
|
||||||
;; Basic conversion
|
|
||||||
list->matrix
|
|
||||||
matrix->list
|
|
||||||
vector->matrix
|
|
||||||
matrix->vector
|
|
||||||
->row-matrix
|
|
||||||
->col-matrix
|
|
||||||
;; Nested conversion
|
|
||||||
list*->matrix
|
|
||||||
matrix->list*
|
|
||||||
vector*->matrix
|
|
||||||
matrix->vector*
|
|
||||||
;; Syntax
|
|
||||||
matrix
|
|
||||||
row-matrix
|
|
||||||
col-matrix)
|
|
||||||
|
|
||||||
(module typed-defs typed/racket/base
|
;; ===================================================================================================
|
||||||
(require racket/fixnum
|
;; Basic constructors
|
||||||
racket/list
|
|
||||||
racket/vector
|
|
||||||
math/array
|
|
||||||
"../array/utils.rkt"
|
|
||||||
"matrix-types.rkt"
|
|
||||||
"utils.rkt"
|
|
||||||
"../unsafe.rkt")
|
|
||||||
|
|
||||||
(provide (all-defined-out))
|
|
||||||
|
|
||||||
;; =================================================================================================
|
|
||||||
;; Constructors
|
|
||||||
|
|
||||||
(: identity-matrix (Integer -> (Matrix (U 0 1))))
|
(: identity-matrix (Integer -> (Matrix (U 0 1))))
|
||||||
(define (identity-matrix m) (diagonal-array 2 m 1 0))
|
(define (identity-matrix m) (diagonal-array 2 m 1 0))
|
||||||
|
@ -62,25 +39,30 @@
|
||||||
(proc (unsafe-vector-ref js 0)
|
(proc (unsafe-vector-ref js 0)
|
||||||
(unsafe-vector-ref js 1))))]))
|
(unsafe-vector-ref js 1))))]))
|
||||||
|
|
||||||
(: diagonal-matrix/zero (All (A) (Array A) A -> (Array A)))
|
;; ===================================================================================================
|
||||||
(define (diagonal-matrix/zero a zero)
|
;; Diagonal matrices
|
||||||
(define ds (array-shape a))
|
|
||||||
(cond [(= 1 (vector-length ds))
|
(: diagonal-matrix/zero (All (A) (Listof A) A -> (Array A)))
|
||||||
(define m (unsafe-vector-ref ds 0))
|
(define (diagonal-matrix/zero xs zero)
|
||||||
(define proc (unsafe-array-proc a))
|
(cond [(empty? xs)
|
||||||
|
(raise-argument-error 'diagonal-matrix "nonempty List" xs)]
|
||||||
|
[else
|
||||||
|
(define vs (list->vector xs))
|
||||||
|
(define m (vector-length vs))
|
||||||
(unsafe-build-array
|
(unsafe-build-array
|
||||||
((inst vector Index) m m)
|
((inst vector Index) m m)
|
||||||
(λ: ([js : Indexes])
|
(λ: ([js : Indexes])
|
||||||
(define i (unsafe-vector-ref js 0))
|
(define i (unsafe-vector-ref js 0))
|
||||||
(cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))]
|
(cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)]
|
||||||
[else zero])))]
|
[else zero])))]))
|
||||||
[else
|
|
||||||
(raise-argument-error 'diagonal-matrix "Array with one dimension" a)]))
|
|
||||||
|
|
||||||
(: diagonal-matrix (case-> ((Array Real) -> (Array Real))
|
(: diagonal-matrix (case-> ((Listof Real) -> (Array Real))
|
||||||
((Array Number) -> (Array Number))))
|
((Listof Number) -> (Array Number))))
|
||||||
(define (diagonal-matrix a)
|
(define (diagonal-matrix xs)
|
||||||
(diagonal-matrix/zero a 0))
|
(diagonal-matrix/zero xs 0))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Block diagonal matrices
|
||||||
|
|
||||||
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
|
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
|
||||||
(define (block-diagonal-matrix/zero* as zero)
|
(define (block-diagonal-matrix/zero* as zero)
|
||||||
|
@ -148,6 +130,9 @@
|
||||||
(define (block-diagonal-matrix as)
|
(define (block-diagonal-matrix as)
|
||||||
(block-diagonal-matrix/zero as 0))
|
(block-diagonal-matrix/zero as 0))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Special matrices
|
||||||
|
|
||||||
(: expt-hack (case-> (Real Integer -> Real)
|
(: expt-hack (case-> (Real Integer -> Real)
|
||||||
(Number Integer -> Number)))
|
(Number Integer -> Number)))
|
||||||
;; Stop using this when TR correctly derives expt : Real Integer -> Real
|
;; Stop using this when TR correctly derives expt : Real Integer -> Real
|
||||||
|
@ -164,213 +149,3 @@
|
||||||
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
|
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
|
||||||
[else
|
[else
|
||||||
(array-axis-expand (list->array xs) 1 n expt-hack)]))
|
(array-axis-expand (list->array xs) 1 n expt-hack)]))
|
||||||
|
|
||||||
;; =================================================================================================
|
|
||||||
;; Flat conversion
|
|
||||||
|
|
||||||
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
|
|
||||||
(define (list->matrix m n xs)
|
|
||||||
(cond [(or (not (index? m)) (= m 0))
|
|
||||||
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
|
|
||||||
[(or (not (index? n)) (= n 0))
|
|
||||||
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
|
|
||||||
[else (list->array (vector m n) xs)]))
|
|
||||||
|
|
||||||
(: matrix->list (All (A) ((Array A) -> (Listof A))))
|
|
||||||
(define (matrix->list a)
|
|
||||||
(array->list (ensure-matrix 'matrix->list a)))
|
|
||||||
|
|
||||||
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
|
|
||||||
(define (vector->matrix m n v)
|
|
||||||
(cond [(or (not (index? m)) (= m 0))
|
|
||||||
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
|
|
||||||
[(or (not (index? n)) (= n 0))
|
|
||||||
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
|
|
||||||
[else (vector->array (vector m n) v)]))
|
|
||||||
|
|
||||||
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
|
|
||||||
(define (matrix->vector a)
|
|
||||||
(array->vector (ensure-matrix 'matrix->vector a)))
|
|
||||||
|
|
||||||
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
|
|
||||||
(define (list->row-matrix xs)
|
|
||||||
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
|
|
||||||
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
|
|
||||||
|
|
||||||
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
|
|
||||||
(define (list->col-matrix xs)
|
|
||||||
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
|
|
||||||
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
|
|
||||||
|
|
||||||
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
|
||||||
(define (vector->row-matrix xs)
|
|
||||||
(define n (vector-length xs))
|
|
||||||
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
|
|
||||||
[else (vector->array ((inst vector Index) 1 n) xs)]))
|
|
||||||
|
|
||||||
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
|
||||||
(define (vector->col-matrix xs)
|
|
||||||
(define n (vector-length xs))
|
|
||||||
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
|
|
||||||
[else (vector->array ((inst vector Index) n 1) xs)]))
|
|
||||||
|
|
||||||
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
|
|
||||||
(define (find-nontrivial-axis ds)
|
|
||||||
(define dims (vector-length ds))
|
|
||||||
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
|
|
||||||
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
|
|
||||||
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
|
|
||||||
[else (values 0 0)])))
|
|
||||||
|
|
||||||
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
|
|
||||||
(define (array->row-matrix arr)
|
|
||||||
(define (fail)
|
|
||||||
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
|
|
||||||
(define ds (array-shape arr))
|
|
||||||
(define dims (vector-length ds))
|
|
||||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
|
||||||
(cond [(zero? (array-size arr)) (fail)]
|
|
||||||
[(row-matrix? arr) arr]
|
|
||||||
[(= num-ones dims)
|
|
||||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
|
||||||
(define proc (unsafe-array-proc arr))
|
|
||||||
(unsafe-build-array ((inst vector Index) 1 1)
|
|
||||||
(λ: ([ij : Indexes]) (proc js)))]
|
|
||||||
[(= num-ones (- dims 1))
|
|
||||||
(define-values (k n) (find-nontrivial-axis ds))
|
|
||||||
(define js (make-thread-local-indexes dims))
|
|
||||||
(define proc (unsafe-array-proc arr))
|
|
||||||
(unsafe-build-array ((inst vector Index) 1 n)
|
|
||||||
(λ: ([ij : Indexes])
|
|
||||||
(let ([js (js)])
|
|
||||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
|
|
||||||
(proc js))))]
|
|
||||||
[else (fail)]))
|
|
||||||
|
|
||||||
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
|
|
||||||
(define (array->col-matrix arr)
|
|
||||||
(define (fail)
|
|
||||||
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
|
|
||||||
(define ds (array-shape arr))
|
|
||||||
(define dims (vector-length ds))
|
|
||||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
|
||||||
(cond [(zero? (array-size arr)) (fail)]
|
|
||||||
[(col-matrix? arr) arr]
|
|
||||||
[(= num-ones dims)
|
|
||||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
|
||||||
(define proc (unsafe-array-proc arr))
|
|
||||||
(unsafe-build-array ((inst vector Index) 1 1)
|
|
||||||
(λ: ([ij : Indexes]) (proc js)))]
|
|
||||||
[(= num-ones (- dims 1))
|
|
||||||
(define-values (k m) (find-nontrivial-axis ds))
|
|
||||||
(define js (make-thread-local-indexes dims))
|
|
||||||
(define proc (unsafe-array-proc arr))
|
|
||||||
(unsafe-build-array ((inst vector Index) m 1)
|
|
||||||
(λ: ([ij : Indexes])
|
|
||||||
(let ([js (js)])
|
|
||||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
|
|
||||||
(proc js))))]
|
|
||||||
[else (fail)]))
|
|
||||||
|
|
||||||
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
|
||||||
(define (->row-matrix xs)
|
|
||||||
(cond [(list? xs) (list->row-matrix xs)]
|
|
||||||
[(array? xs) (array->row-matrix xs)]
|
|
||||||
[else (vector->row-matrix xs)]))
|
|
||||||
|
|
||||||
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
|
||||||
(define (->col-matrix xs)
|
|
||||||
(cond [(list? xs) (list->col-matrix xs)]
|
|
||||||
[(array? xs) (array->col-matrix xs)]
|
|
||||||
[else (vector->col-matrix xs)]))
|
|
||||||
|
|
||||||
;; =================================================================================================
|
|
||||||
;; Nested conversion
|
|
||||||
|
|
||||||
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
|
|
||||||
(define (list*-shape xss fail)
|
|
||||||
(define m (length xss))
|
|
||||||
(cond [(m . > . 0)
|
|
||||||
(define n (length (first xss)))
|
|
||||||
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
|
|
||||||
(values m n)]
|
|
||||||
[else (fail)])]
|
|
||||||
[else (fail)]))
|
|
||||||
|
|
||||||
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
|
|
||||||
-> (Values Positive-Index Positive-Index)))
|
|
||||||
(define (vector*-shape xss fail)
|
|
||||||
(define m (vector-length xss))
|
|
||||||
(cond [(m . > . 0)
|
|
||||||
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
|
|
||||||
(define n (vector-length (unsafe-vector-ref xss 0)))
|
|
||||||
(cond [(and (n . > . 0)
|
|
||||||
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
|
|
||||||
(cond [(i . fx< . m)
|
|
||||||
(if (= n (vector-length (unsafe-vector-ref xss i)))
|
|
||||||
(loop (fx+ i 1))
|
|
||||||
#f)]
|
|
||||||
[else #t])))
|
|
||||||
(values m n)]
|
|
||||||
[else (fail)])]
|
|
||||||
[else (fail)]))
|
|
||||||
|
|
||||||
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
|
|
||||||
(define (list*->matrix xss)
|
|
||||||
(define (fail)
|
|
||||||
(raise-argument-error 'list*->matrix
|
|
||||||
"nested lists with rectangular shape and at least one matrix element"
|
|
||||||
xss))
|
|
||||||
(define-values (m n) (list*-shape xss fail))
|
|
||||||
(list->array ((inst vector Index) m n) (apply append xss)))
|
|
||||||
|
|
||||||
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
|
|
||||||
(define (matrix->list* a)
|
|
||||||
(cond [(matrix? a) (array->list (array->list-array a 1))]
|
|
||||||
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
|
|
||||||
|
|
||||||
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
|
|
||||||
(define (vector*->matrix xss)
|
|
||||||
(define (fail)
|
|
||||||
(raise-argument-error 'vector*->matrix
|
|
||||||
"nested vectors with rectangular shape and at least one matrix element"
|
|
||||||
xss))
|
|
||||||
(define-values (m n) (vector*-shape xss fail))
|
|
||||||
(vector->matrix m n (apply vector-append (vector->list xss))))
|
|
||||||
|
|
||||||
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
|
|
||||||
(define (matrix->vector* a)
|
|
||||||
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
|
|
||||||
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
|
|
||||||
) ; module
|
|
||||||
|
|
||||||
(require (for-syntax racket/base
|
|
||||||
syntax/parse)
|
|
||||||
(only-in typed/racket/base :)
|
|
||||||
math/array
|
|
||||||
(submod "." typed-defs))
|
|
||||||
|
|
||||||
(define-syntax (matrix stx)
|
|
||||||
(syntax-parse stx #:literals (:)
|
|
||||||
[(_ [[x0 xs0 ...] [x xs ...] ...])
|
|
||||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
|
|
||||||
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
|
|
||||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
|
|
||||||
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
|
|
||||||
(raise-syntax-error 'matrix "given empty row" stx #'r)]
|
|
||||||
[(_ (~and [] c) (~optional (~seq : T)))
|
|
||||||
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]))
|
|
||||||
|
|
||||||
(define-syntax (row-matrix stx)
|
|
||||||
(syntax-parse stx #:literals (:)
|
|
||||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
|
|
||||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
|
|
||||||
[(_ (~and [] r) (~optional (~seq : T)))
|
|
||||||
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
|
|
||||||
|
|
||||||
(define-syntax (col-matrix stx)
|
|
||||||
(syntax-parse stx #:literals (:)
|
|
||||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
|
|
||||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
|
|
||||||
[(_ (~and [] c) (~optional (~seq : T)))
|
|
||||||
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))
|
|
||||||
|
|
202
collects/math/private/matrix/matrix-conversion.rkt
Normal file
202
collects/math/private/matrix/matrix-conversion.rkt
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
#lang typed/racket/base
|
||||||
|
|
||||||
|
(require racket/fixnum
|
||||||
|
racket/list
|
||||||
|
racket/vector
|
||||||
|
math/array
|
||||||
|
"matrix-types.rkt"
|
||||||
|
"utils.rkt"
|
||||||
|
"../array/utils.rkt"
|
||||||
|
"../unsafe.rkt")
|
||||||
|
|
||||||
|
(provide
|
||||||
|
;; Flat conversion
|
||||||
|
list->matrix
|
||||||
|
matrix->list
|
||||||
|
vector->matrix
|
||||||
|
matrix->vector
|
||||||
|
->row-matrix
|
||||||
|
->col-matrix
|
||||||
|
;; Nested conversion
|
||||||
|
list*->matrix
|
||||||
|
matrix->list*
|
||||||
|
vector*->matrix
|
||||||
|
matrix->vector*)
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Flat conversion
|
||||||
|
|
||||||
|
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
|
||||||
|
(define (list->matrix m n xs)
|
||||||
|
(cond [(or (not (index? m)) (= m 0))
|
||||||
|
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
|
||||||
|
[(or (not (index? n)) (= n 0))
|
||||||
|
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
|
||||||
|
[else (list->array (vector m n) xs)]))
|
||||||
|
|
||||||
|
(: matrix->list (All (A) ((Array A) -> (Listof A))))
|
||||||
|
(define (matrix->list a)
|
||||||
|
(array->list (ensure-matrix 'matrix->list a)))
|
||||||
|
|
||||||
|
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
|
||||||
|
(define (vector->matrix m n v)
|
||||||
|
(cond [(or (not (index? m)) (= m 0))
|
||||||
|
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
|
||||||
|
[(or (not (index? n)) (= n 0))
|
||||||
|
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
|
||||||
|
[else (vector->array (vector m n) v)]))
|
||||||
|
|
||||||
|
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
|
||||||
|
(define (matrix->vector a)
|
||||||
|
(array->vector (ensure-matrix 'matrix->vector a)))
|
||||||
|
|
||||||
|
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
|
||||||
|
(define (list->row-matrix xs)
|
||||||
|
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
|
||||||
|
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
|
||||||
|
|
||||||
|
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
|
||||||
|
(define (list->col-matrix xs)
|
||||||
|
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
|
||||||
|
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
|
||||||
|
|
||||||
|
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||||
|
(define (vector->row-matrix xs)
|
||||||
|
(define n (vector-length xs))
|
||||||
|
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
|
||||||
|
[else (vector->array ((inst vector Index) 1 n) xs)]))
|
||||||
|
|
||||||
|
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||||
|
(define (vector->col-matrix xs)
|
||||||
|
(define n (vector-length xs))
|
||||||
|
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
|
||||||
|
[else (vector->array ((inst vector Index) n 1) xs)]))
|
||||||
|
|
||||||
|
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
|
||||||
|
(define (find-nontrivial-axis ds)
|
||||||
|
(define dims (vector-length ds))
|
||||||
|
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
|
||||||
|
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
|
||||||
|
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
|
||||||
|
[else (values 0 0)])))
|
||||||
|
|
||||||
|
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
|
||||||
|
(define (array->row-matrix arr)
|
||||||
|
(define (fail)
|
||||||
|
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||||
|
(define ds (array-shape arr))
|
||||||
|
(define dims (vector-length ds))
|
||||||
|
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||||
|
(cond [(zero? (array-size arr)) (fail)]
|
||||||
|
[(row-matrix? arr) arr]
|
||||||
|
[(= num-ones dims)
|
||||||
|
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||||
|
(define proc (unsafe-array-proc arr))
|
||||||
|
(unsafe-build-array ((inst vector Index) 1 1)
|
||||||
|
(λ: ([ij : Indexes]) (proc js)))]
|
||||||
|
[(= num-ones (- dims 1))
|
||||||
|
(define-values (k n) (find-nontrivial-axis ds))
|
||||||
|
(define js (make-thread-local-indexes dims))
|
||||||
|
(define proc (unsafe-array-proc arr))
|
||||||
|
(unsafe-build-array ((inst vector Index) 1 n)
|
||||||
|
(λ: ([ij : Indexes])
|
||||||
|
(let ([js (js)])
|
||||||
|
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
|
||||||
|
(proc js))))]
|
||||||
|
[else (fail)]))
|
||||||
|
|
||||||
|
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
|
||||||
|
(define (array->col-matrix arr)
|
||||||
|
(define (fail)
|
||||||
|
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||||
|
(define ds (array-shape arr))
|
||||||
|
(define dims (vector-length ds))
|
||||||
|
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||||
|
(cond [(zero? (array-size arr)) (fail)]
|
||||||
|
[(col-matrix? arr) arr]
|
||||||
|
[(= num-ones dims)
|
||||||
|
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||||
|
(define proc (unsafe-array-proc arr))
|
||||||
|
(unsafe-build-array ((inst vector Index) 1 1)
|
||||||
|
(λ: ([ij : Indexes]) (proc js)))]
|
||||||
|
[(= num-ones (- dims 1))
|
||||||
|
(define-values (k m) (find-nontrivial-axis ds))
|
||||||
|
(define js (make-thread-local-indexes dims))
|
||||||
|
(define proc (unsafe-array-proc arr))
|
||||||
|
(unsafe-build-array ((inst vector Index) m 1)
|
||||||
|
(λ: ([ij : Indexes])
|
||||||
|
(let ([js (js)])
|
||||||
|
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
|
||||||
|
(proc js))))]
|
||||||
|
[else (fail)]))
|
||||||
|
|
||||||
|
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||||
|
(define (->row-matrix xs)
|
||||||
|
(cond [(list? xs) (list->row-matrix xs)]
|
||||||
|
[(array? xs) (array->row-matrix xs)]
|
||||||
|
[else (vector->row-matrix xs)]))
|
||||||
|
|
||||||
|
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||||
|
(define (->col-matrix xs)
|
||||||
|
(cond [(list? xs) (list->col-matrix xs)]
|
||||||
|
[(array? xs) (array->col-matrix xs)]
|
||||||
|
[else (vector->col-matrix xs)]))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Nested conversion
|
||||||
|
|
||||||
|
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
|
||||||
|
(define (list*-shape xss fail)
|
||||||
|
(define m (length xss))
|
||||||
|
(cond [(m . > . 0)
|
||||||
|
(define n (length (first xss)))
|
||||||
|
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
|
||||||
|
(values m n)]
|
||||||
|
[else (fail)])]
|
||||||
|
[else (fail)]))
|
||||||
|
|
||||||
|
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
|
||||||
|
-> (Values Positive-Index Positive-Index)))
|
||||||
|
(define (vector*-shape xss fail)
|
||||||
|
(define m (vector-length xss))
|
||||||
|
(cond [(m . > . 0)
|
||||||
|
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
|
||||||
|
(define n (vector-length (unsafe-vector-ref xss 0)))
|
||||||
|
(cond [(and (n . > . 0)
|
||||||
|
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
|
||||||
|
(cond [(i . fx< . m)
|
||||||
|
(if (= n (vector-length (unsafe-vector-ref xss i)))
|
||||||
|
(loop (fx+ i 1))
|
||||||
|
#f)]
|
||||||
|
[else #t])))
|
||||||
|
(values m n)]
|
||||||
|
[else (fail)])]
|
||||||
|
[else (fail)]))
|
||||||
|
|
||||||
|
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
|
||||||
|
(define (list*->matrix xss)
|
||||||
|
(define (fail)
|
||||||
|
(raise-argument-error 'list*->matrix
|
||||||
|
"nested lists with rectangular shape and at least one matrix element"
|
||||||
|
xss))
|
||||||
|
(define-values (m n) (list*-shape xss fail))
|
||||||
|
(list->array ((inst vector Index) m n) (apply append xss)))
|
||||||
|
|
||||||
|
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
|
||||||
|
(define (matrix->list* a)
|
||||||
|
(cond [(matrix? a) (array->list (array->list-array a 1))]
|
||||||
|
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
|
||||||
|
|
||||||
|
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
|
||||||
|
(define (vector*->matrix xss)
|
||||||
|
(define (fail)
|
||||||
|
(raise-argument-error 'vector*->matrix
|
||||||
|
"nested vectors with rectangular shape and at least one matrix element"
|
||||||
|
xss))
|
||||||
|
(define-values (m n) (vector*-shape xss fail))
|
||||||
|
(vector->matrix m n (apply vector-append (vector->list xss))))
|
||||||
|
|
||||||
|
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
|
||||||
|
(define (matrix->vector* a)
|
||||||
|
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
|
||||||
|
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
|
|
@ -6,6 +6,7 @@
|
||||||
"../unsafe.rkt"
|
"../unsafe.rkt"
|
||||||
"matrix-types.rkt"
|
"matrix-types.rkt"
|
||||||
"matrix-constructors.rkt"
|
"matrix-constructors.rkt"
|
||||||
|
"matrix-conversion.rkt"
|
||||||
"matrix-arithmetic.rkt"
|
"matrix-arithmetic.rkt"
|
||||||
"matrix-basic.rkt"
|
"matrix-basic.rkt"
|
||||||
"matrix-column.rkt"
|
"matrix-column.rkt"
|
||||||
|
@ -19,13 +20,6 @@
|
||||||
; 4. Pseudo inverse
|
; 4. Pseudo inverse
|
||||||
; 5. Eigenvalues and eigenvectors
|
; 5. Eigenvalues and eigenvectors
|
||||||
|
|
||||||
; 6. "Bug"
|
|
||||||
; (for*/matrix : Number 2 3 ([i (in-naturals)]) i)
|
|
||||||
; ought to generate a matrix with numbers from 0 to 5.
|
|
||||||
; Problem: In expansion of for/matrix an extra [i (in-range (* m n))]
|
|
||||||
; is added to make sure the comprehension stops.
|
|
||||||
; But TR has problems with #:when so what is the proper expansion ?
|
|
||||||
|
|
||||||
(provide
|
(provide
|
||||||
matrix-inverse
|
matrix-inverse
|
||||||
; row and column
|
; row and column
|
||||||
|
@ -582,7 +576,7 @@
|
||||||
; Note: We project onto vs (not on the original ws)
|
; Note: We project onto vs (not on the original ws)
|
||||||
; in order to get numerical stability.
|
; in order to get numerical stability.
|
||||||
(let ([w-minus-proj (array-strict (array- w w-proj))])
|
(let ([w-minus-proj (array-strict (array- w w-proj))])
|
||||||
(if (zero-matrix? w-minus-proj)
|
(if (matrix-zero? w-minus-proj)
|
||||||
(loop vs (cdr ws)) ; w in span{vs} => omit it
|
(loop vs (cdr ws)) ; w in span{vs} => omit it
|
||||||
(loop (cons w-minus-proj vs) (cdr ws)))))]))
|
(loop (cons w-minus-proj vs) (cdr ws)))))]))
|
||||||
(reverse (loop (list (car ws)) (cdr ws)))]))
|
(reverse (loop (list (car ws)) (cdr ws)))]))
|
||||||
|
|
|
@ -1,340 +0,0 @@
|
||||||
#lang racket
|
|
||||||
|
|
||||||
(provide in-row
|
|
||||||
in-column)
|
|
||||||
|
|
||||||
(require math/array
|
|
||||||
"matrix-types.rkt"
|
|
||||||
"matrix-basic.rkt"
|
|
||||||
"matrix-constructors.rkt"
|
|
||||||
)
|
|
||||||
|
|
||||||
(define (in-row/proc M r)
|
|
||||||
(define-values (m n) (matrix-shape M))
|
|
||||||
(make-do-sequence
|
|
||||||
(λ ()
|
|
||||||
(values
|
|
||||||
; pos->element
|
|
||||||
(λ (j) (matrix-ref M r j))
|
|
||||||
; next-pos
|
|
||||||
(λ (j) (+ j 1))
|
|
||||||
; initial-pos
|
|
||||||
0
|
|
||||||
; continue-with-pos?
|
|
||||||
(λ (j) (< j n))
|
|
||||||
#f #f ))))
|
|
||||||
|
|
||||||
; (in-row M i]
|
|
||||||
; Returns a sequence of all elements of row i,
|
|
||||||
; that is xi0, xi1, xi2, ...
|
|
||||||
(define-sequence-syntax in-row
|
|
||||||
(λ () #'in-row/proc)
|
|
||||||
(λ (stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
[[(x) (_ M-expr r-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M r n d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 r-expr rd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? r)
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
|
||||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
||||||
([j 0])
|
|
||||||
(< j n)
|
|
||||||
([(x) (vector-ref d (+ (* r n) j))])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[(i x) (_ M-expr r-expr)]
|
|
||||||
#'((i x)
|
|
||||||
(:do-in
|
|
||||||
([(M r n d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 r-expr rd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? r)
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
||||||
([j 0])
|
|
||||||
(< j n)
|
|
||||||
([(x) (vector-ref d (+ (* r n) j))]
|
|
||||||
[(i) j])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[_ clause] (raise-syntax-error
|
|
||||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
|
||||||
|
|
||||||
; (in-column M j]
|
|
||||||
; Returns a sequence of all elements of column j,
|
|
||||||
; that is x0j, x1j, x2j, ...
|
|
||||||
|
|
||||||
|
|
||||||
(define (in-column/proc M s)
|
|
||||||
(define-values (m n) (matrix-shape M))
|
|
||||||
(make-do-sequence
|
|
||||||
(λ ()
|
|
||||||
(values
|
|
||||||
; pos->element
|
|
||||||
(λ (i) (matrix-ref M i s))
|
|
||||||
; next-pos
|
|
||||||
(λ (i) (+ i 1))
|
|
||||||
; initial-pos
|
|
||||||
0
|
|
||||||
; continue-with-pos?
|
|
||||||
(λ (i) (< i m))
|
|
||||||
#f #f ))))
|
|
||||||
|
|
||||||
(define-sequence-syntax in-column
|
|
||||||
(λ () #'in-column/proc)
|
|
||||||
(λ (stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
[[(x) (_ M-expr s-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M s n m d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 s-expr rd cd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? s)
|
|
||||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
|
||||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
||||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
|
||||||
([j 0])
|
|
||||||
(< j m)
|
|
||||||
([(x) (vector-ref d (+ (* j n) s))])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[(i x) (_ M-expr s-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M s n m d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 s-expr rd cd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? s)
|
|
||||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
|
||||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
||||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
|
||||||
([j 0])
|
|
||||||
(< j m)
|
|
||||||
([(x) (vector-ref d (+ (* j n) s))]
|
|
||||||
[(i) j])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[_ clause] (raise-syntax-error
|
|
||||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
|
||||||
|
|
||||||
(define-syntax (for/matrix-sum: stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
[(_ : type (for:-clause ...) . defs+exprs)
|
|
||||||
(syntax/loc stx
|
|
||||||
(let ()
|
|
||||||
(define: sum : (U False (Matrix Number)) #f)
|
|
||||||
(for: (for:-clause ...)
|
|
||||||
(define a (let () . defs+exprs))
|
|
||||||
(set! sum (if sum (array+ (assert sum) a) a)))
|
|
||||||
(assert sum)))]))
|
|
||||||
#|
|
|
||||||
;;;
|
|
||||||
;;; SEQUENCES
|
|
||||||
;;;
|
|
||||||
|
|
||||||
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
|
||||||
(define (in-row/proc M r)
|
|
||||||
(define-values (m n) (matrix-shape M))
|
|
||||||
(make-do-sequence
|
|
||||||
(λ ()
|
|
||||||
(values
|
|
||||||
; pos->element
|
|
||||||
(λ: ([j : Index]) (matrix-ref M r j))
|
|
||||||
; next-pos
|
|
||||||
(λ: ([j : Index]) (assert (+ j 1) index?))
|
|
||||||
; initial-pos
|
|
||||||
0
|
|
||||||
; continue-with-pos?
|
|
||||||
(λ: ([j : Index ]) (< j n))
|
|
||||||
#f #f))))
|
|
||||||
|
|
||||||
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
|
||||||
(define (in-column/proc M s)
|
|
||||||
(define-values (m n) (matrix-shape M))
|
|
||||||
(make-do-sequence
|
|
||||||
(λ ()
|
|
||||||
(values
|
|
||||||
; pos->element
|
|
||||||
(λ: ([i : Index]) (matrix-ref M i s))
|
|
||||||
; next-pos
|
|
||||||
(λ: ([i : Index]) (assert (+ i 1) index?))
|
|
||||||
; initial-pos
|
|
||||||
0
|
|
||||||
; continue-with-pos?
|
|
||||||
(λ: ([i : Index]) (< i m))
|
|
||||||
#f #f))))
|
|
||||||
|
|
||||||
; (in-row M i]
|
|
||||||
; Returns a sequence of all elements of row i,
|
|
||||||
; that is xi0, xi1, xi2, ...
|
|
||||||
(define-sequence-syntax in-row
|
|
||||||
(λ () #'in-row/proc)
|
|
||||||
(λ (stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
[[(x) (_ M-expr r-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M r n d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 r-expr rd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? r)
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
|
||||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
||||||
([j 0])
|
|
||||||
(< j n)
|
|
||||||
([(x) (vector-ref d (+ (* r n) j))])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[(i x) (_ M-expr r-expr)]
|
|
||||||
#'((i x)
|
|
||||||
(:do-in
|
|
||||||
([(M r n d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 r-expr rd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? r)
|
|
||||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
||||||
([j 0])
|
|
||||||
(< j n)
|
|
||||||
([(x) (vector-ref d (+ (* r n) j))]
|
|
||||||
[(i) j])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[_ clause] (raise-syntax-error
|
|
||||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
|
||||||
|
|
||||||
; (in-column M j]
|
|
||||||
; Returns a sequence of all elements of column j,
|
|
||||||
; that is x0j, x1j, x2j, ...
|
|
||||||
|
|
||||||
(define-sequence-syntax in-column
|
|
||||||
(λ () #'in-column/proc)
|
|
||||||
(λ (stx)
|
|
||||||
(syntax-case stx ()
|
|
||||||
; M-expr evaluates to column
|
|
||||||
[[(x) (_ M-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M n m d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 rd cd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
([j 0])
|
|
||||||
(< j n)
|
|
||||||
([(x) (vector-ref d j)])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
; M-expr evaluats to matrix, s-expr to the column index
|
|
||||||
[[(x) (_ M-expr s-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M s n m d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 s-expr rd cd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? s)
|
|
||||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
|
||||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
||||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
|
||||||
([j 0])
|
|
||||||
(< j m)
|
|
||||||
([(x) (vector-ref d (+ (* j n) s))])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[(i x) (_ M-expr s-expr)]
|
|
||||||
#'((x)
|
|
||||||
(:do-in
|
|
||||||
([(M s n m d)
|
|
||||||
(let ([M1 M-expr])
|
|
||||||
(define-values (rd cd) (matrix-shape M1))
|
|
||||||
(values M1 s-expr rd cd
|
|
||||||
(mutable-array-data
|
|
||||||
(array->mutable-array M1))))])
|
|
||||||
(begin
|
|
||||||
(unless (matrix? M)
|
|
||||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
|
||||||
(unless (integer? s)
|
|
||||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
|
||||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
||||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
|
||||||
([j 0])
|
|
||||||
(< j m)
|
|
||||||
([(x) (vector-ref d (+ (* j n) s))]
|
|
||||||
[(i) j])
|
|
||||||
#true
|
|
||||||
#true
|
|
||||||
[(+ j 1)]))]
|
|
||||||
[[_ clause] (raise-syntax-error
|
|
||||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
|
||||||
|#
|
|
||||||
#;
|
|
||||||
(module* test #f
|
|
||||||
(require rackunit)
|
|
||||||
; "matrix-sequences.rkt"
|
|
||||||
(check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
|
||||||
'(3 4))
|
|
||||||
(check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
|
||||||
(list i x))
|
|
||||||
'((0 3) (1 4)))
|
|
||||||
(check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
|
||||||
'(2 4))
|
|
||||||
(check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
|
||||||
(list i x))
|
|
||||||
'((0 2) (1 4))))
|
|
35
collects/math/private/matrix/matrix-syntax.rkt
Normal file
35
collects/math/private/matrix/matrix-syntax.rkt
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
#lang racket/base
|
||||||
|
|
||||||
|
(require (for-syntax racket/base
|
||||||
|
syntax/parse)
|
||||||
|
(only-in typed/racket/base :)
|
||||||
|
math/array)
|
||||||
|
|
||||||
|
(provide matrix row-matrix col-matrix)
|
||||||
|
|
||||||
|
(define-syntax (matrix stx)
|
||||||
|
(syntax-parse stx #:literals (:)
|
||||||
|
[(_ [[x0 xs0 ...] [x xs ...] ...])
|
||||||
|
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
|
||||||
|
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
|
||||||
|
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
|
||||||
|
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
|
||||||
|
(raise-syntax-error 'matrix "given empty row" stx #'r)]
|
||||||
|
[(_ (~and [] c) (~optional (~seq : T)))
|
||||||
|
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]
|
||||||
|
[(_ x (~optional (~seq : T)))
|
||||||
|
(raise-syntax-error 'matrix "expected two-dimensional data" stx)]))
|
||||||
|
|
||||||
|
(define-syntax (row-matrix stx)
|
||||||
|
(syntax-parse stx #:literals (:)
|
||||||
|
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
|
||||||
|
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
|
||||||
|
[(_ (~and [] r) (~optional (~seq : T)))
|
||||||
|
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
|
||||||
|
|
||||||
|
(define-syntax (col-matrix stx)
|
||||||
|
(syntax-parse stx #:literals (:)
|
||||||
|
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
|
||||||
|
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
|
||||||
|
[(_ (~and [] c) (~optional (~seq : T)))
|
||||||
|
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))
|
|
@ -27,44 +27,53 @@
|
||||||
(define-type (Column-Matrix A) (Matrix A))
|
(define-type (Column-Matrix A) (Matrix A))
|
||||||
; a column vector represented as a matrix
|
; a column vector represented as a matrix
|
||||||
|
|
||||||
(define matrix?
|
(: matrix? (All (A) ((Array A) -> Boolean)))
|
||||||
(plambda: (A) ([arr : (Array A)])
|
(define (matrix? arr)
|
||||||
(and (> (array-size arr) 0)
|
(and (> (array-size arr) 0)
|
||||||
(= (array-dims arr) 2))))
|
(= (array-dims arr) 2)))
|
||||||
|
|
||||||
(define square-matrix?
|
(: square-matrix? (All (A) ((Array A) -> Boolean)))
|
||||||
(plambda: (A) ([arr : (Array A)])
|
(define (square-matrix? arr)
|
||||||
(and (matrix? arr)
|
(define ds (array-shape arr))
|
||||||
(let ([sh (array-shape arr)])
|
(and (= (vector-length ds) 2)
|
||||||
(= (vector-ref sh 0) (vector-ref sh 1))))))
|
(let ([d0 (unsafe-vector-ref ds 0)]
|
||||||
|
[d1 (unsafe-vector-ref ds 1)])
|
||||||
|
(and (> d0 0) (> d1 0) (= d0 d1)))))
|
||||||
|
|
||||||
(define row-matrix?
|
(: row-matrix? (All (A) ((Array A) -> Boolean)))
|
||||||
(plambda: (A) ([arr : (Array A)])
|
(define (row-matrix? arr)
|
||||||
(and (matrix? arr)
|
(define ds (array-shape arr))
|
||||||
(= (vector-ref (array-shape arr) 0) 1))))
|
(and (= (vector-length ds) 2)
|
||||||
|
(= (unsafe-vector-ref ds 0) 1)
|
||||||
|
(> (unsafe-vector-ref ds 1) 0)))
|
||||||
|
|
||||||
(define col-matrix?
|
(: col-matrix? (All (A) ((Array A) -> Boolean)))
|
||||||
(plambda: (A) ([arr : (Array A)])
|
(define (col-matrix? arr)
|
||||||
(and (matrix? arr)
|
(define ds (array-shape arr))
|
||||||
(= (vector-ref (array-shape arr) 1) 1))))
|
(and (= (vector-length ds) 2)
|
||||||
|
(> (unsafe-vector-ref ds 0) 0)
|
||||||
|
(= (unsafe-vector-ref ds 1) 1)))
|
||||||
|
|
||||||
(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index)))
|
(: matrix-shape (All (A) ((Array A) -> (Values Index Index))))
|
||||||
(define (matrix-shape a)
|
(define (matrix-shape a)
|
||||||
(cond [(matrix? a) (define sh (array-shape a))
|
(define ds (array-shape a))
|
||||||
(values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))]
|
(if (and (> (array-size a) 0)
|
||||||
[else (raise-argument-error 'matrix-shape "matrix?" a)]))
|
(= (vector-length ds) 2))
|
||||||
|
(values (unsafe-vector-ref ds 0)
|
||||||
|
(unsafe-vector-ref ds 1))
|
||||||
|
(raise-argument-error 'matrix-shape "matrix?" a)))
|
||||||
|
|
||||||
(: square-matrix-size (All (A) ((Matrix A) -> Index)))
|
(: square-matrix-size (All (A) ((Array A) -> Index)))
|
||||||
(define (square-matrix-size arr)
|
(define (square-matrix-size arr)
|
||||||
(cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)]
|
(cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)]
|
||||||
[else (raise-argument-error 'square-matrix-size "square-matrix?" arr)]))
|
[else (raise-argument-error 'square-matrix-size "square-matrix?" arr)]))
|
||||||
|
|
||||||
(: matrix-num-rows (All (A) ((Matrix A) -> Index)))
|
(: matrix-num-rows (All (A) ((Array A) -> Index)))
|
||||||
(define (matrix-num-rows a)
|
(define (matrix-num-rows a)
|
||||||
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
|
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
|
||||||
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
|
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
|
||||||
|
|
||||||
(: matrix-num-cols (All (A) ((Matrix A) -> Index)))
|
(: matrix-num-cols (All (A) ((Array A) -> Index)))
|
||||||
(define (matrix-num-cols a)
|
(define (matrix-num-cols a)
|
||||||
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
|
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
|
||||||
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))
|
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))
|
||||||
|
|
|
@ -73,8 +73,6 @@
|
||||||
|
|
||||||
) ; module
|
) ; module
|
||||||
|
|
||||||
(require 'syntax-defs)
|
|
||||||
|
|
||||||
(module untyped-defs typed/racket/base
|
(module untyped-defs typed/racket/base
|
||||||
(require math/array
|
(require math/array
|
||||||
(submod ".." syntax-defs)
|
(submod ".." syntax-defs)
|
||||||
|
@ -102,4 +100,5 @@
|
||||||
|
|
||||||
) ; module
|
) ; module
|
||||||
|
|
||||||
(require 'untyped-defs)
|
(require 'syntax-defs
|
||||||
|
'untyped-defs)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
(define (matrix-shapes name arr . brrs)
|
(define (matrix-shapes name arr . brrs)
|
||||||
(define-values (m n) (matrix-shape arr))
|
(define-values (m n) (matrix-shape arr))
|
||||||
(unless (andmap (λ: ([brr : (Matrix Any)])
|
(unless (andmap (λ: ([brr : (Matrix Any)])
|
||||||
(match-define (vector bm bn) (array-shape brr))
|
(define-values (bm bn) (matrix-shape brr))
|
||||||
(and (= bm m) (= bn n)))
|
(and (= bm m) (= bn n)))
|
||||||
brrs)
|
brrs)
|
||||||
(error name
|
(error name
|
||||||
|
|
|
@ -7,192 +7,673 @@
|
||||||
"../private/matrix/matrix-column.rkt"
|
"../private/matrix/matrix-column.rkt"
|
||||||
"test-utils.rkt")
|
"test-utils.rkt")
|
||||||
|
|
||||||
(: random-matrix (Integer Integer Integer -> (Matrix Integer)))
|
(: random-matrix (case-> (Integer Integer -> (Matrix Integer))
|
||||||
;; Generates a random matrix with integer elements < k. Useful to test properties.
|
(Integer Integer Integer -> (Matrix Integer))))
|
||||||
(define (random-matrix m n k)
|
;; Generates a random matrix with Natural elements < k. Useful to test properties.
|
||||||
|
(define (random-matrix m n [k 100])
|
||||||
(array-strict (build-array (vector m n) (λ (_) (random k)))))
|
(array-strict (build-array (vector m n) (λ (_) (random k)))))
|
||||||
|
|
||||||
|
(define nonmatrices
|
||||||
|
(list (make-array #() 0)
|
||||||
|
(make-array #(1) 0)
|
||||||
|
(make-array #(1 0) 0)
|
||||||
|
(make-array #(0 1) 0)
|
||||||
|
(make-array #(0 0) 0)
|
||||||
|
(make-array #(1 1 1) 0)))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Types
|
;; Literal syntax
|
||||||
|
|
||||||
|
(check-equal? (matrix [[1]])
|
||||||
|
(array #[#[1]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix [[1 2 3 4]])
|
||||||
|
(array #[#[1 2 3 4]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix [[1 2] [3 4]])
|
||||||
|
(array #[#[1 2] #[3 4]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix [[1] [2] [3] [4]])
|
||||||
|
(array #[#[1] #[2] #[3] #[4]]))
|
||||||
|
|
||||||
|
(check-equal? (row-matrix [1 2 3 4])
|
||||||
|
(matrix [[1 2 3 4]]))
|
||||||
|
|
||||||
|
(check-equal? (col-matrix [1 2 3 4])
|
||||||
|
(matrix [[1] [2] [3] [4]]))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Predicates
|
||||||
|
|
||||||
(check-true (matrix? (array #[#[1]])))
|
(check-true (matrix? (array #[#[1]])))
|
||||||
(check-false (matrix? (array #[1])))
|
(check-false (matrix? (array #[1])))
|
||||||
(check-false (matrix? (array 1)))
|
(check-false (matrix? (array 1)))
|
||||||
(check-false (matrix? (array #[])))
|
(check-false (matrix? (array #[])))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-false (matrix? a)))
|
||||||
|
|
||||||
|
(check-true (square-matrix? (matrix [[1]])))
|
||||||
|
(check-true (square-matrix? (matrix [[1 1] [1 1]])))
|
||||||
|
(check-false (square-matrix? (matrix [[1 2]])))
|
||||||
|
(check-false (square-matrix? (matrix [[1] [2]])))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-false (square-matrix? a)))
|
||||||
|
|
||||||
(check-true (row-matrix? (matrix [[1 2 3 4]])))
|
(check-true (row-matrix? (matrix [[1 2 3 4]])))
|
||||||
|
(check-true (row-matrix? (matrix [[1]])))
|
||||||
(check-false (row-matrix? (matrix [[1] [2] [3] [4]])))
|
(check-false (row-matrix? (matrix [[1] [2] [3] [4]])))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-false (row-matrix? a)))
|
||||||
|
|
||||||
(check-true (col-matrix? (matrix [[1] [2] [3] [4]])))
|
(check-true (col-matrix? (matrix [[1] [2] [3] [4]])))
|
||||||
|
(check-true (col-matrix? (matrix [[1]])))
|
||||||
(check-false (col-matrix? (matrix [[1 2 3 4]])))
|
(check-false (col-matrix? (matrix [[1 2 3 4]])))
|
||||||
|
(check-false (col-matrix? (array #[1])))
|
||||||
|
(check-false (col-matrix? (array 1)))
|
||||||
|
(check-false (col-matrix? (array #[])))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-false (col-matrix? a)))
|
||||||
|
|
||||||
|
(check-true (matrix-zero? (make-matrix 4 3 0)))
|
||||||
|
(check-true (matrix-zero? (make-matrix 4 3 0.0)))
|
||||||
|
(check-true (matrix-zero? (make-matrix 4 3 0+0.0i)))
|
||||||
|
(check-false (matrix-zero? (row-matrix [0 0 0 0 1])))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-zero? a))))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Matrix multiplication
|
;; Accessors
|
||||||
|
|
||||||
(check-equal? (matrix* (identity-matrix 2)
|
;; matrix-shape
|
||||||
(matrix [[1 20] [300 4000]]))
|
|
||||||
(matrix [[1 20] [300 4000]]))
|
(check-equal? (let-values ([(m n) (matrix-shape (matrix [[1 2 3] [4 5 6]]))])
|
||||||
|
(list m n))
|
||||||
|
(list 2 3))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (let-values ([(m n) (matrix-shape a)])
|
||||||
|
(void)))))
|
||||||
|
|
||||||
|
;; square-matrix-size
|
||||||
|
|
||||||
|
(check-equal? (square-matrix-size (matrix [[1 2] [3 4]]))
|
||||||
|
2)
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1 2]]))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (square-matrix-size (matrix [[1] [2]]))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (square-matrix-size a))))
|
||||||
|
|
||||||
|
;; matrix-num-rows
|
||||||
|
|
||||||
|
(check-equal? (matrix-num-rows (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
2)
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-num-rows a))))
|
||||||
|
|
||||||
|
;; matrix-num-cols
|
||||||
|
|
||||||
|
(check-equal? (matrix-num-cols (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
3)
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-num-cols a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Constructors
|
||||||
|
|
||||||
|
;; identity-matrix
|
||||||
|
|
||||||
|
(check-equal? (identity-matrix 1) (matrix [[1]]))
|
||||||
|
(check-equal? (identity-matrix 2) (matrix [[1 0] [0 1]]))
|
||||||
|
(check-equal? (identity-matrix 3) (matrix [[1 0 0] [0 1 0] [0 0 1]]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (identity-matrix 0)))
|
||||||
|
|
||||||
|
;; make-matrix
|
||||||
|
|
||||||
|
(check-equal? (make-matrix 1 1 4) (matrix [[4]]))
|
||||||
|
(check-equal? (make-matrix 2 2 3) (matrix [[3 3] [3 3]]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (make-matrix 1 0 4)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (make-matrix 0 1 4)))
|
||||||
|
|
||||||
|
;; build-matrix
|
||||||
|
|
||||||
|
(check-equal? (build-matrix 4 4 (λ: ([i : Index] [j : Index])
|
||||||
|
(+ i j)))
|
||||||
|
(build-array #(4 4) (λ: ([js : Indexes])
|
||||||
|
(+ (vector-ref js 0) (vector-ref js 1)))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (build-matrix 1 0 (λ: ([i : Index] [j : Index]) (+ i j)))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (build-matrix 0 1 (λ: ([i : Index] [j : Index]) (+ i j)))))
|
||||||
|
|
||||||
|
;; diagonal-matrix
|
||||||
|
|
||||||
|
(check-equal? (diagonal-matrix '(1 2 3 4))
|
||||||
|
(matrix [[1 0 0 0]
|
||||||
|
[0 2 0 0]
|
||||||
|
[0 0 3 0]
|
||||||
|
[0 0 0 4]]))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (diagonal-matrix '())))
|
||||||
|
|
||||||
|
;; block-diagonal-matrix
|
||||||
|
|
||||||
|
(let ([m (random-matrix 4 4 100)])
|
||||||
|
(check-equal? (block-diagonal-matrix (list m))
|
||||||
|
m))
|
||||||
|
|
||||||
|
(check-equal?
|
||||||
|
(block-diagonal-matrix
|
||||||
|
(list (matrix [[1 2] [3 4]])
|
||||||
|
(matrix [[1 2 3] [4 5 6]])
|
||||||
|
(matrix [[1] [3] [5]])
|
||||||
|
(matrix [[2 4 6]])))
|
||||||
|
(matrix [[1 2 0 0 0 0 0 0 0]
|
||||||
|
[3 4 0 0 0 0 0 0 0]
|
||||||
|
[0 0 1 2 3 0 0 0 0]
|
||||||
|
[0 0 4 5 6 0 0 0 0]
|
||||||
|
[0 0 0 0 0 1 0 0 0]
|
||||||
|
[0 0 0 0 0 3 0 0 0]
|
||||||
|
[0 0 0 0 0 5 0 0 0]
|
||||||
|
[0 0 0 0 0 0 2 4 6]]))
|
||||||
|
|
||||||
|
(check-equal?
|
||||||
|
(block-diagonal-matrix (map (λ: ([i : Integer]) (matrix [[i]])) '(1 2 3 4)))
|
||||||
|
(diagonal-matrix '(1 2 3 4)))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (block-diagonal-matrix '())))
|
||||||
|
|
||||||
|
;; Vandermonde matrix
|
||||||
|
|
||||||
|
(check-equal? (vandermonde-matrix '(10) 1)
|
||||||
|
(matrix [[1]]))
|
||||||
|
(check-equal? (vandermonde-matrix '(10) 4)
|
||||||
|
(matrix [[1 10 100 1000]]))
|
||||||
|
(check-equal? (vandermonde-matrix '(1 2 3 4) 3)
|
||||||
|
(matrix [[1 1 1] [1 2 4] [1 3 9] [1 4 16]]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (vandermonde-matrix '() 1)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (vandermonde-matrix '(1) 0)))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Flat conversion
|
||||||
|
|
||||||
|
(check-equal? (list->matrix 1 3 '(1 2 3)) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (list->matrix 3 1 '(1 2 3)) (col-matrix [1 2 3]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list->matrix 0 1 '())))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list->matrix 1 0 '())))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list->matrix 1 1 '(1 2))))
|
||||||
|
|
||||||
|
(check-equal? (vector->matrix 1 3 #(1 2 3)) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (vector->matrix 3 1 #(1 2 3)) (col-matrix [1 2 3]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (vector->matrix 0 1 #())))
|
||||||
|
(check-exn exn:fail:contract? (λ () (vector->matrix 1 0 #())))
|
||||||
|
(check-exn exn:fail:contract? (λ () (vector->matrix 1 1 #(1 2))))
|
||||||
|
|
||||||
|
(check-equal? (->row-matrix '(1 2 3)) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (->row-matrix #(1 2 3)) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (->row-matrix (row-matrix [1 2 3])) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (->row-matrix (col-matrix [1 2 3])) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (->row-matrix (make-array #() 1)) (row-matrix [1]))
|
||||||
|
(check-equal? (->row-matrix (make-array #(3) 1)) (row-matrix [1 1 1]))
|
||||||
|
(check-equal? (->row-matrix (make-array #(1 3 1) 1)) (row-matrix [1 1 1]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(2 3 1) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(1 3 2) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(0 3) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->row-matrix (make-array #(3 0) 1))))
|
||||||
|
|
||||||
|
(check-equal? (->col-matrix '(1 2 3)) (col-matrix [1 2 3]))
|
||||||
|
(check-equal? (->col-matrix #(1 2 3)) (col-matrix [1 2 3]))
|
||||||
|
(check-equal? (->col-matrix (col-matrix [1 2 3])) (col-matrix [1 2 3]))
|
||||||
|
(check-equal? (->col-matrix (row-matrix [1 2 3])) (col-matrix [1 2 3]))
|
||||||
|
(check-equal? (->col-matrix (make-array #() 1)) (col-matrix [1]))
|
||||||
|
(check-equal? (->col-matrix (make-array #(3) 1)) (col-matrix [1 1 1]))
|
||||||
|
(check-equal? (->col-matrix (make-array #(1 3 1) 1)) (col-matrix [1 1 1]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(2 3 1) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(1 3 2) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(0 3) 1))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (->col-matrix (make-array #(3 0) 1))))
|
||||||
|
|
||||||
|
(check-equal? (matrix->list (matrix [[1 2 3] [4 5 6]])) '(1 2 3 4 5 6))
|
||||||
|
(check-equal? (matrix->list (row-matrix [1 2 3])) '(1 2 3))
|
||||||
|
(check-equal? (matrix->list (col-matrix [1 2 3])) '(1 2 3))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix->list a))))
|
||||||
|
|
||||||
|
(check-equal? (matrix->vector (matrix [[1 2 3] [4 5 6]])) #(1 2 3 4 5 6))
|
||||||
|
(check-equal? (matrix->vector (row-matrix [1 2 3])) #(1 2 3))
|
||||||
|
(check-equal? (matrix->vector (col-matrix [1 2 3])) #(1 2 3))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix->vector a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Nested conversion
|
||||||
|
|
||||||
|
(check-equal? (list*->matrix '((1 2 3) (4 5 6))) (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list*->matrix '((1 2 3) (4 5)))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list*->matrix '(() () ()))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (list*->matrix '())))
|
||||||
|
|
||||||
|
(check-equal? ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5 6))) (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#(1 2 3) #(4 5)))))
|
||||||
|
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #(#() #() #()))))
|
||||||
|
(check-exn exn:fail:contract? (λ () ((inst vector*->matrix Integer) #())))
|
||||||
|
|
||||||
|
(check-equal? (matrix->list* (matrix [[1 2 3] [4 5 6]])) '((1 2 3) (4 5 6)))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix->list* a))))
|
||||||
|
|
||||||
|
(check-equal? (matrix->vector* (matrix [[1 2 3] [4 5 6]])) #(#(1 2 3) #(4 5 6)))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix->vector* a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Equality
|
||||||
|
|
||||||
|
(check-true (matrix= (matrix [[1 2 3]
|
||||||
|
[4 5 6]])
|
||||||
|
(matrix [[1.0 2.0 3.0]
|
||||||
|
[4.0 5.0 6.0]])))
|
||||||
|
|
||||||
|
(check-true (matrix= (matrix [[1 2 3]
|
||||||
|
[4 5 6]])
|
||||||
|
(matrix [[1.0 2.0 3.0]
|
||||||
|
[4.0 5.0 6.0]])
|
||||||
|
(matrix [[1.0+0.0i 2.0+0.0i 3.0+0.0i]
|
||||||
|
[4.0+0.0i 5.0+0.0i 6.0+0.0i]])))
|
||||||
|
|
||||||
|
(check-false (matrix= (matrix [[1 2 3] [4 5 6]])
|
||||||
|
(matrix [[1 2 3] [4 5 7]])))
|
||||||
|
|
||||||
|
(check-false (matrix= (matrix [[0 2 3] [4 5 6]])
|
||||||
|
(matrix [[1 2 3] [4 5 7]])))
|
||||||
|
|
||||||
|
(check-false (matrix= (matrix [[1 2 3] [4 5 6]])
|
||||||
|
(matrix [[1 4] [2 5] [3 6]])))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix= a (matrix [[1]]))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) a)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix= (matrix [[1]]) (matrix [[1]]) a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Pointwise operations
|
||||||
|
|
||||||
|
(define-syntax-rule (test-matrix-map (matrix-map ...) (array-map ...))
|
||||||
|
(begin
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-map ... a)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-map ... (matrix [[1]]) a))))
|
||||||
|
|
||||||
|
(for*: ([m '(2 3 4)]
|
||||||
|
[n '(2 3 4)])
|
||||||
|
(define a0 (random-matrix m n))
|
||||||
|
(define a1 (random-matrix m n))
|
||||||
|
(define a2 (random-matrix m n))
|
||||||
|
(check-equal? (matrix-map ... a0)
|
||||||
|
(array-map ... a0))
|
||||||
|
(check-equal? (matrix-map ... a0 a1)
|
||||||
|
(array-map ... a0 a1))
|
||||||
|
(check-equal? (matrix-map ... a0 a1 a2)
|
||||||
|
(array-map ... a0 a1 a2))
|
||||||
|
;; Don't know why this (void) is necessary, but TR complains without it
|
||||||
|
(void))))
|
||||||
|
|
||||||
|
(test-matrix-map (matrix-map -) (array-map -))
|
||||||
|
(test-matrix-map ((values matrix-map) -) (array-map -))
|
||||||
|
|
||||||
|
(test-matrix-map (matrix+) (array+))
|
||||||
|
(test-matrix-map ((values matrix+)) (array+))
|
||||||
|
|
||||||
|
(test-matrix-map (matrix-) (array-))
|
||||||
|
(test-matrix-map ((values matrix-)) (array-))
|
||||||
|
|
||||||
|
(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]])))
|
||||||
|
(matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(check-equal? (matrix-sum (list (matrix [[1 2 3] [4 5 6]])
|
||||||
|
(matrix [[0 1 2] [3 4 5]])))
|
||||||
|
(matrix+ (matrix [[1 2 3] [4 5 6]])
|
||||||
|
(matrix [[0 1 2] [3 4 5]])))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-sum '())))
|
||||||
|
|
||||||
|
(check-equal? (matrix-scale (matrix [[1 2 3] [4 5 6]]) 10)
|
||||||
|
(matrix [[10 20 30] [40 50 60]]))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-scale a 0))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Multiplication
|
||||||
|
|
||||||
|
(define-syntax-rule (test-matrix* matrix*)
|
||||||
|
(begin
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix* a (matrix [[1]])))))
|
||||||
|
|
||||||
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
|
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
|
||||||
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||||
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
|
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
|
||||||
|
|
||||||
(let ([m0 (random-matrix 4 5 100)]
|
(check-equal? (matrix* (row-matrix [1 2 3 4])
|
||||||
[m1 (random-matrix 5 2 100)]
|
(col-matrix [1 2 3 4]))
|
||||||
[m2 (random-matrix 2 10 100)])
|
(matrix [[30]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix* (col-matrix [1 2 3 4])
|
||||||
|
(row-matrix [1 2 3 4]))
|
||||||
|
(matrix [[1 2 3 4]
|
||||||
|
[2 4 6 8]
|
||||||
|
[3 6 9 12]
|
||||||
|
[4 8 12 16]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix* (matrix [[3]]) (matrix [[7]]))
|
||||||
|
(matrix [[21]]))
|
||||||
|
|
||||||
|
;; Left/right identity
|
||||||
|
(let ([m (random-matrix 2 2)])
|
||||||
|
(check-equal? (matrix* (identity-matrix 2) m)
|
||||||
|
m)
|
||||||
|
(check-equal? (matrix* m (identity-matrix 2))
|
||||||
|
m))
|
||||||
|
|
||||||
|
;; Shape
|
||||||
|
(let ([m0 (random-matrix 4 5)]
|
||||||
|
[m1 (random-matrix 5 2)]
|
||||||
|
[m2 (random-matrix 2 10)])
|
||||||
|
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1))])
|
||||||
|
(list m n))
|
||||||
|
(list 4 2))
|
||||||
|
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m1 m2))])
|
||||||
|
(list m n))
|
||||||
|
(list 5 10))
|
||||||
|
(check-equal? (let-values ([(m n) (matrix-shape (matrix* m0 m1 m2))])
|
||||||
|
(list m n))
|
||||||
|
(list 4 10)))
|
||||||
|
|
||||||
|
(check-exn exn:fail? (λ () (matrix* (random-matrix 1 2) (random-matrix 3 2))))
|
||||||
|
|
||||||
|
;; Associativity
|
||||||
|
(let ([m0 (random-matrix 4 5)]
|
||||||
|
[m1 (random-matrix 5 2)]
|
||||||
|
[m2 (random-matrix 2 10)])
|
||||||
|
(check-equal? (matrix* m0 m1 m2)
|
||||||
|
(matrix* (matrix* m0 m1) m2))
|
||||||
(check-equal? (matrix* (matrix* m0 m1) m2)
|
(check-equal? (matrix* (matrix* m0 m1) m2)
|
||||||
(matrix* m0 (matrix* m1 m2))))
|
(matrix* m0 (matrix* m1 m2))))
|
||||||
|
))
|
||||||
|
|
||||||
|
(test-matrix* matrix*)
|
||||||
|
;; `matrix*' is an inlining macro, so we need to check the function version as well
|
||||||
|
(test-matrix* (values matrix*))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
;; Construction
|
;; Exponentiation
|
||||||
|
|
||||||
|
(let ([A (matrix [[1 2] [3 4]])])
|
||||||
|
(check-equal? (matrix-expt A 0) (identity-matrix 2))
|
||||||
|
(check-equal? (matrix-expt A 1) A)
|
||||||
|
(check-equal? (matrix-expt A 2) (matrix [[7 10] [15 22]]))
|
||||||
|
(check-equal? (matrix-expt A 3) (matrix [[37 54] [81 118]]))
|
||||||
|
(check-equal? (matrix-expt A 8) (matrix [[165751 241570] [362355 528106]])))
|
||||||
|
|
||||||
|
(check-equal? (matrix-expt (matrix [[2]]) 10) (matrix [[(expt 2 10)]]))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-expt (row-matrix [1 2 3]) 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-expt (col-matrix [1 2 3]) 0)))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-expt a 0))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Comprehensions
|
||||||
|
|
||||||
|
;; for/matrix and friends are defined in terms of for/array and friends, so we only need to test that
|
||||||
|
;; it works for one case each, and that they properly raise exceptions when given zero-length axes
|
||||||
|
|
||||||
(check-equal?
|
(check-equal?
|
||||||
(block-diagonal-matrix
|
(for/matrix 2 2 ([i (in-range 4)]) i)
|
||||||
(list
|
(matrix [[0 1] [2 3]]))
|
||||||
(matrix [[1 2] [3 4]])
|
|
||||||
(matrix [[1 2 3] [4 5 6]])
|
#;; TR can't type this, but it's defined using exactly the same wrapper as `for/matrix'
|
||||||
(matrix [[1] [3] [5]])))
|
(check-equal?
|
||||||
(matrix
|
(for*/matrix 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
|
||||||
[[1 2 0 0 0 0]
|
(matrix [[0 1] [1 2]]))
|
||||||
[3 4 0 0 0 0]
|
|
||||||
[0 0 1 2 3 0]
|
(check-equal?
|
||||||
[0 0 4 5 6 0]
|
(for/matrix: 2 2 ([i (in-range 4)]) i)
|
||||||
[0 0 0 0 0 1]
|
(matrix [[0 1] [2 3]]))
|
||||||
[0 0 0 0 0 3]
|
|
||||||
[0 0 0 0 0 5]]))
|
(check-equal?
|
||||||
|
(for*/matrix: 2 2 ([i (in-range 2)] [j (in-range 2)]) (+ i j))
|
||||||
|
(matrix [[0 1] [1 2]]))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (for/matrix 2 0 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for/matrix 0 2 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for*/matrix 2 0 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for*/matrix 0 2 () 0)))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (for/matrix: 2 0 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for/matrix: 0 2 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for*/matrix: 2 0 () 0)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (for*/matrix: 0 2 () 0)))
|
||||||
|
|
||||||
;; ===================================================================================================
|
;; ===================================================================================================
|
||||||
|
;; Extraction
|
||||||
|
|
||||||
|
;; matrix-ref
|
||||||
|
|
||||||
|
(let ([a (matrix [[10 11] [12 13]])])
|
||||||
|
(check-equal? (matrix-ref a 0 0) 10)
|
||||||
|
(check-equal? (matrix-ref a 0 1) 11)
|
||||||
|
(check-equal? (matrix-ref a 1 0) 12)
|
||||||
|
(check-equal? (matrix-ref a 1 1) 13)
|
||||||
|
(check-exn exn:fail? (λ () (matrix-ref a 2 0)))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-ref a 0 2)))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-ref a -1 0)))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-ref a 0 -1))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-ref a 0 0))))
|
||||||
|
|
||||||
|
;; matrix-diagonal
|
||||||
|
|
||||||
|
(check-equal? (matrix-diagonal (diagonal-matrix '(1 2 3 4)))
|
||||||
|
(array #[1 2 3 4]))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-diagonal a))))
|
||||||
|
|
||||||
|
;; submatrix
|
||||||
|
|
||||||
|
(check-equal? (submatrix (identity-matrix 8) (:: 2 4) (:: 2 4))
|
||||||
|
(identity-matrix 2))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (submatrix a '(0) '(0)))))
|
||||||
|
|
||||||
|
;; matrix-row
|
||||||
|
|
||||||
|
(let ([a (matrix [[1 2 3] [4 5 6]])])
|
||||||
|
(check-equal? (matrix-row a 0) (row-matrix [1 2 3]))
|
||||||
|
(check-equal? (matrix-row a 1) (row-matrix [4 5 6]))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-row a -1)))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-row a 2))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-row a 0))))
|
||||||
|
|
||||||
|
;; matrix-col
|
||||||
|
|
||||||
|
(let ([a (matrix [[1 2 3] [4 5 6]])])
|
||||||
|
(check-equal? (matrix-col a 0) (col-matrix [1 4]))
|
||||||
|
(check-equal? (matrix-col a 1) (col-matrix [2 5]))
|
||||||
|
(check-equal? (matrix-col a 2) (col-matrix [3 6]))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-col a -1)))
|
||||||
|
(check-exn exn:fail? (λ () (matrix-col a 3))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-col a 0))))
|
||||||
|
|
||||||
|
;; matrix-rows
|
||||||
|
|
||||||
|
(check-equal? (matrix-rows (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(list (row-matrix [1 2 3])
|
||||||
|
(row-matrix [4 5 6])))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-rows a))))
|
||||||
|
|
||||||
|
;; matrix-cols
|
||||||
|
|
||||||
|
(check-equal? (matrix-cols (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(list (col-matrix [1 4])
|
||||||
|
(col-matrix [2 5])
|
||||||
|
(col-matrix [3 6])))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-cols a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Embiggenment (it's a perfectly cromulent word)
|
||||||
|
|
||||||
|
;; matrix-augment
|
||||||
|
|
||||||
|
(let ([a (random-matrix 3 5)])
|
||||||
|
(check-equal? (matrix-augment (list a)) a)
|
||||||
|
(check-equal? (matrix-augment (matrix-cols a)) a))
|
||||||
|
|
||||||
|
(check-equal? (matrix-augment (list (col-matrix [1 2 3]) (col-matrix [4 5 6])))
|
||||||
|
(matrix [[1 4] [2 5] [3 6]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3 6])))
|
||||||
|
(matrix [[1 2 3] [4 5 6]]))
|
||||||
|
|
||||||
|
(check-exn exn:fail? (λ () (matrix-augment (list (matrix [[1 2] [4 5]]) (col-matrix [3])))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-augment '())))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-augment (list a))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-augment (list (matrix [[1]]) a)))))
|
||||||
|
|
||||||
|
;; matrix-stack
|
||||||
|
|
||||||
|
(let ([a (random-matrix 5 3)])
|
||||||
|
(check-equal? (matrix-stack (list a)) a)
|
||||||
|
(check-equal? (matrix-stack (matrix-rows a)) a))
|
||||||
|
|
||||||
|
(check-equal? (matrix-stack (list (row-matrix [1 2 3]) (row-matrix [4 5 6])))
|
||||||
|
(matrix [[1 2 3] [4 5 6]]))
|
||||||
|
|
||||||
|
(check-equal? (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8 9])))
|
||||||
|
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||||
|
|
||||||
|
(check-exn exn:fail? (λ () (matrix-stack (list (matrix [[1 2 3] [4 5 6]]) (row-matrix [7 8])))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-stack '())))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-stack (list a))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-stack (list (matrix [[1]]) a)))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Inner product space
|
||||||
|
|
||||||
|
;; matrix-norm
|
||||||
|
|
||||||
|
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(sqrt (+ (* 1 1) (* 2 2) (* 3 3) (* 4 4) (* 5 5) (* 6 6))))
|
||||||
|
|
||||||
|
;; Default norm is Frobenius norm
|
||||||
|
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(matrix-norm (matrix [[1 2 3] [4 5 6]]) 2))
|
||||||
|
|
||||||
|
;; This shouldn't overflow (so we check against `flhypot', which also shouldn't overflow)
|
||||||
|
(check-equal? (matrix-norm (matrix [[1e200 1e199]]))
|
||||||
|
(flhypot 1e200 1e199))
|
||||||
|
|
||||||
|
;; Taxicab (Manhattan) norm
|
||||||
|
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1)
|
||||||
|
(+ 1 2 3 4 5 6))
|
||||||
|
|
||||||
|
;; Infinity (maximum) norm
|
||||||
|
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) +inf.0)
|
||||||
|
(max 1 2 3 4 5 6))
|
||||||
|
|
||||||
|
;; The actual norm is indistinguishable from floating-point 6
|
||||||
|
(check-equal? (matrix-norm (matrix [[1 2 3] [4 5 6]]) 1000)
|
||||||
|
6.0)
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-norm a 1)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-norm a)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-norm a 5)))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-norm a +inf.0))))
|
||||||
|
|
||||||
|
(check-equal? (matrix-norm (row-matrix [1+1i]))
|
||||||
|
(sqrt 2))
|
||||||
|
|
||||||
|
(check-equal? (matrix-norm (row-matrix [1+1i 2+2i 3+3i]))
|
||||||
|
(matrix-norm (row-matrix [(magnitude 1+1i) (magnitude 2+2i) (magnitude 3+3i)])))
|
||||||
|
|
||||||
|
;; matrix-dot (induces the Frobenius norm)
|
||||||
|
|
||||||
|
(check-equal? (matrix-dot (matrix [[1 -2 3] [-4 5 -6]])
|
||||||
|
(matrix [[-1 2 -3] [4 -5 6]]))
|
||||||
|
(+ (* 1 -1) (* -2 2) (* 3 -3) (* -4 4) (* 5 -5) (* -6 6)))
|
||||||
|
|
||||||
|
(check-equal? (matrix-dot (row-matrix [1 2 3])
|
||||||
|
(row-matrix [0+4i 0-5i 0+6i]))
|
||||||
|
(+ (* 1 0-4i) (* 2 0+5i) (* 3 0-6i)))
|
||||||
|
|
||||||
|
(check-exn exn:fail? (λ () (matrix-dot (random-matrix 1 3) (random-matrix 3 1))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-dot a (matrix [[1]]))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-dot (matrix [[1]]) a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Simple operators
|
||||||
|
|
||||||
|
;; matrix-transpose
|
||||||
|
|
||||||
|
(check-equal? (matrix-transpose (matrix [[1 2 3] [4 5 6]]))
|
||||||
|
(matrix [[1 4] [2 5] [3 6]]))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-transpose a))))
|
||||||
|
|
||||||
|
;; matrix-conjugate
|
||||||
|
|
||||||
|
(check-equal? (matrix-conjugate (matrix [[1+i 2-i] [3+i 4-i]]))
|
||||||
|
(matrix [[1-i 2+i] [3-i 4+i]]))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-conjugate a))))
|
||||||
|
|
||||||
|
;; matrix-hermitian
|
||||||
|
|
||||||
|
(let ([a (array-make-rectangular (random-matrix 5 6)
|
||||||
|
(random-matrix 5 6))])
|
||||||
|
(check-equal? (matrix-hermitian a)
|
||||||
|
(matrix-conjugate (matrix-transpose a)))
|
||||||
|
(check-equal? (matrix-hermitian a)
|
||||||
|
(matrix-transpose (matrix-conjugate a))))
|
||||||
|
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-hermitian a))))
|
||||||
|
|
||||||
|
;; matrix-trace
|
||||||
|
|
||||||
|
(check-equal? (matrix-trace (matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||||
|
(+ 1 5 9))
|
||||||
|
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-trace (row-matrix [1 2 3]))))
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3]))))
|
||||||
|
(for: ([a (in-list nonmatrices)])
|
||||||
|
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
|
||||||
|
|
||||||
|
;; ===================================================================================================
|
||||||
|
;; Tests not yet converted to rackunit
|
||||||
|
|
||||||
(begin
|
(begin
|
||||||
(begin "matrix-types.rkt"
|
|
||||||
(list
|
|
||||||
'matrix?
|
|
||||||
(matrix? (list*->array '[[1 2] [3 4]] real? ))
|
|
||||||
(not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
|
|
||||||
(list
|
|
||||||
'square-matrix?
|
|
||||||
(square-matrix? (list*->array '[[1 2] [3 4]] real? ))
|
|
||||||
(not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? ))))
|
|
||||||
(list
|
|
||||||
'square-matrix-size
|
|
||||||
(= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?))))
|
|
||||||
(list
|
|
||||||
'matrix=
|
|
||||||
(matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
|
|
||||||
#;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
|
|
||||||
(list
|
|
||||||
'matrix-shape
|
|
||||||
(let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))])
|
|
||||||
(equal? (list m n) '(2 3)))))
|
|
||||||
|
|
||||||
(begin "matrix-constructors.rkt"
|
|
||||||
(list
|
|
||||||
'identity-matrix
|
|
||||||
(equal? (array->list* (identity-matrix 1)) '[[1]])
|
|
||||||
(equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]])
|
|
||||||
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]]))
|
|
||||||
(list
|
|
||||||
'const-matrix
|
|
||||||
(equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0)))
|
|
||||||
(equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.))))
|
|
||||||
(list
|
|
||||||
'matrix->list
|
|
||||||
(equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
|
|
||||||
(equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
|
|
||||||
(list
|
|
||||||
'matrix->vector
|
|
||||||
(equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4))))
|
|
||||||
'#(#(1 2) #(3 4)))
|
|
||||||
(equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.))))
|
|
||||||
'#(#(1. 2.) #(3. 4.))))
|
|
||||||
(list
|
|
||||||
'matrix-row
|
|
||||||
(equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]]))
|
|
||||||
(equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]]))
|
|
||||||
(equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]])))
|
|
||||||
(list
|
|
||||||
'matrix-col
|
|
||||||
(equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]]))
|
|
||||||
(equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]]))
|
|
||||||
(equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]])))
|
|
||||||
(list
|
|
||||||
'submatrix
|
|
||||||
(equal? (submatrix (identity-matrix 3)
|
|
||||||
(in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]]))
|
|
||||||
(equal? (submatrix (identity-matrix 3)
|
|
||||||
(in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]]))))
|
|
||||||
|
|
||||||
(begin
|
|
||||||
"matrix-pointwise.rkt"
|
|
||||||
(let ()
|
|
||||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
|
||||||
(define ~A (list*->matrix '[[-1 -2] [-3 -4]]))
|
|
||||||
(define B (list*->matrix '[[5 6] [7 8]]))
|
|
||||||
(define A+B (list*->matrix '[[6 8] [10 12]]))
|
|
||||||
(define A-B (list*->matrix '[[-4 -4] [-4 -4]]))
|
|
||||||
(list 'matrix+ (equal? (matrix+ A B) A+B))
|
|
||||||
(list 'matrix-
|
|
||||||
(equal? (matrix- A B) A-B)
|
|
||||||
(equal? (matrix- A) ~A))))
|
|
||||||
|
|
||||||
(begin
|
|
||||||
"matrix-expt.rkt"
|
|
||||||
(let ()
|
|
||||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
|
||||||
(list
|
|
||||||
'matrix-expt
|
|
||||||
(equal? (matrix-expt A 0) (identity-matrix 2))
|
|
||||||
(equal? (matrix-expt A 1) A)
|
|
||||||
(equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]]))
|
|
||||||
(equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]]))
|
|
||||||
(equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]])))))
|
|
||||||
|
|
||||||
(begin
|
(begin
|
||||||
"matrix-operations.rkt"
|
"matrix-operations.rkt"
|
||||||
(list 'vandermonde-matrix
|
|
||||||
(equal? (vandermonde-matrix '(1 2 3) 5)
|
|
||||||
(list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
|
|
||||||
#;
|
|
||||||
(list 'in-column
|
|
||||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)])
|
|
||||||
x)
|
|
||||||
'(1 3))
|
|
||||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)])
|
|
||||||
x)
|
|
||||||
'(2 4))
|
|
||||||
(equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x)
|
|
||||||
'(5 2 3)))
|
|
||||||
#;
|
|
||||||
(list 'in-row
|
|
||||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)])
|
|
||||||
x)
|
|
||||||
'(1 2))
|
|
||||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)])
|
|
||||||
x)
|
|
||||||
'(3 4)))
|
|
||||||
(list 'for/matrix:
|
|
||||||
(equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i)
|
|
||||||
(matrix [[0 1 2 3] [4 5 6 7]]))
|
|
||||||
(equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i)
|
|
||||||
(matrix [[0 2 4 6] [1 3 5 7]]))
|
|
||||||
(equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i)
|
|
||||||
(matrix [[10 11 12] [13 14 15] [16 17 18]])))
|
|
||||||
(list 'for*/matrix:
|
|
||||||
(equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j))
|
|
||||||
(matrix [[0 1 2] [10 11 12] [20 21 22]])))
|
|
||||||
(list 'matrix-block-diagonal
|
|
||||||
(equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]])))
|
|
||||||
(list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
|
|
||||||
(list 'matrix-augment
|
|
||||||
(equal? (matrix-augment (list (col-matrix [1 2 3])
|
|
||||||
(col-matrix [4 5 6])
|
|
||||||
(col-matrix [7 8 9])))
|
|
||||||
(matrix [[1 4 7] [2 5 8] [3 6 9]])))
|
|
||||||
(list 'matrix-stack
|
|
||||||
(equal? (matrix-stack (list (col-matrix [1 2 3])
|
|
||||||
(col-matrix [4 5 6])
|
|
||||||
(col-matrix [7 8 9])))
|
|
||||||
(col-matrix [1 2 3 4 5 6 7 8 9])))
|
|
||||||
#;
|
#;
|
||||||
(list 'column-dimension
|
(list 'column-dimension
|
||||||
(= (column-dimension #(1 2 3)) 3)
|
(= (column-dimension #(1 2 3)) 3)
|
||||||
|
@ -206,8 +687,6 @@
|
||||||
(+ (* 1 4) (* 2 5) (* 3 6)))
|
(+ (* 1 4) (* 2 5) (* 3 6)))
|
||||||
(= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i]))
|
(= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i]))
|
||||||
25)))
|
25)))
|
||||||
(list 'matrix-trace
|
|
||||||
(equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5))
|
|
||||||
(let ([matrix: vector->matrix])
|
(let ([matrix: vector->matrix])
|
||||||
(list 'column-norm
|
(list 'column-norm
|
||||||
(= (column-norm (col-matrix [2 4])) (sqrt 20))))
|
(= (column-norm (col-matrix [2 4])) (sqrt 20))))
|
||||||
|
@ -286,15 +765,6 @@
|
||||||
[9 10 -11 12]
|
[9 10 -11 12]
|
||||||
[13 14 15 16]]))
|
[13 14 15 16]]))
|
||||||
5280))
|
5280))
|
||||||
(list 'matrix-scale
|
|
||||||
(equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2)
|
|
||||||
(list*->matrix '[[2 4] [6 8]])))
|
|
||||||
(list 'matrix-transpose
|
|
||||||
(equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]]))
|
|
||||||
(list*->matrix '[[1 3] [2 4]])))
|
|
||||||
(list 'matrix-hermitian
|
|
||||||
(equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]]))
|
|
||||||
(list*->matrix '[[1-i 3-i] [2+i 4+i]])))
|
|
||||||
(let ()
|
(let ()
|
||||||
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
||||||
(define (gauss-eliminate M u? p?)
|
(define (gauss-eliminate M u? p?)
|
||||||
|
@ -366,7 +836,8 @@
|
||||||
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
|
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
|
||||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
|
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
|
||||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
|
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
|
||||||
#;(let ()
|
#;
|
||||||
|
(let ()
|
||||||
(define-values (c1 n1)
|
(define-values (c1 n1)
|
||||||
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
|
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
|
||||||
(define-values (c2 n2)
|
(define-values (c2 n2)
|
||||||
|
@ -384,24 +855,8 @@
|
||||||
(list*->matrix '[[2] [5]])))
|
(list*->matrix '[[2] [5]])))
|
||||||
(equal? n3 '()))))
|
(equal? n3 '()))))
|
||||||
|
|
||||||
|
#;
|
||||||
|
(begin
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#;(begin
|
|
||||||
"matrix-multiply.rkt"
|
|
||||||
(list 'matrix*
|
|
||||||
(let ()
|
|
||||||
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]]))
|
|
||||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))
|
|
||||||
(let ()
|
|
||||||
(define-values (A B AB) (values '[[1 2] [3 4]]
|
|
||||||
'[[5 6 7] [8 9 10]]
|
|
||||||
'[[21 24 27] [47 54 61]]))
|
|
||||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))))
|
|
||||||
#;(begin
|
|
||||||
"matrix-2d.rkt"
|
"matrix-2d.rkt"
|
||||||
(let ()
|
(let ()
|
||||||
(define e1 (matrix-transpose (vector->matrix #(#( 1 0)))))
|
(define e1 (matrix-transpose (vector->matrix #(#( 1 0)))))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user