Reviewing and refactoring `math/matrix', part 1

* Finally added `array-axis-expand' as a dual for `array-axis-reduce'
  in order to implement `vandermonde-matrix' elegantly

* Better, shorter matrix multiply; reworked all matrix arithmetic

* Split "matrix-operations.rkt" into at least 5 parts:
 * "matrix-operations.rkt"
 * "matrix-basic.rkt"
 * "matrix-comprehension.rkt"
 * "matrix-sequences.rkt"
 * "matrix-column.rkt"

Added "matrix-constructors.rkt"

Added `matrix', `row-matrix', and `col-matrix' macros

A lot of other little changes

Currently, `in-row' and `in-column' are broken. I intend to implement
them in a way that makes them work in untyped and Typed Racket.
This commit is contained in:
Neil Toronto 2012-12-20 12:16:48 -07:00
parent c2468f1f9a
commit 155ec7dc41
25 changed files with 1850 additions and 1245 deletions

View File

@ -9,6 +9,7 @@
"private/array/array-convert.rkt"
"private/array/array-fold.rkt"
"private/array/array-special-folds.rkt"
"private/array/array-unfold.rkt"
"private/array/array-print.rkt"
"private/array/array-fft.rkt"
"private/array/array-syntax.rkt"
@ -36,6 +37,7 @@
"private/array/array-convert.rkt"
"private/array/array-fold.rkt"
"private/array/array-special-folds.rkt"
"private/array/array-unfold.rkt"
"private/array/array-print.rkt"
"private/array/array-syntax.rkt"
"private/array/array-fft.rkt"

View File

@ -9,7 +9,7 @@
"number-theory.rkt"
"vector.rkt"
"array.rkt"
;"matrix.rkt"
"matrix.rkt"
"utils.rkt")
(provide (all-from-out
@ -22,5 +22,5 @@
"number-theory.rkt"
"vector.rkt"
"array.rkt"
;"matrix.rkt"
"matrix.rkt"
"utils.rkt"))

View File

@ -1,21 +1,22 @@
#lang typed/racket/base
(require "private/matrix/matrix-pointwise.rkt"
"private/matrix/matrix-multiply.rkt"
(require "private/matrix/matrix-arithmetic.rkt"
"private/matrix/matrix-constructors.rkt"
"private/matrix/matrix-basic.rkt"
"private/matrix/matrix-operations.rkt"
"private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-sequences.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt"
"private/matrix/utils.rkt")
(provide (all-from-out "private/matrix/matrix-pointwise.rkt"
"private/matrix/matrix-multiply.rkt"
(provide (all-from-out
"private/matrix/matrix-arithmetic.rkt"
"private/matrix/matrix-constructors.rkt"
"private/matrix/matrix-basic.rkt"
"private/matrix/matrix-operations.rkt"
"private/matrix/matrix-comprehension.rkt"
"private/matrix/matrix-sequences.rkt"
"private/matrix/matrix-expt.rkt"
"private/matrix/matrix-types.rkt")
;; From "utils.rkt"
array-matrix?
;; would also like matrix? : (Any -> Boolean : (Array Any)), but we can't have one until we
;; can define array? : (Any -> Boolean : (Array Any)), and there's been trouble with that
)
matrix?)

View File

@ -0,0 +1,51 @@
#lang typed/racket/base
(require racket/fixnum
"array-struct.rkt"
"array-pointwise.rkt"
"array-fold.rkt"
"utils.rkt"
"../unsafe.rkt")
(provide unsafe-array-axis-expand
array-axis-expand
list-array->array)
(: check-array-axis (All (A) (Symbol (Array A) Integer -> Index)))
(define (check-array-axis name arr k)
(define dims (array-dims arr))
(cond [(fx= dims 0) (raise-argument-error name "Array with at least one axis" 0 arr k)]
[(or (k . < . 0) (k . > . dims))
(raise-argument-error name (format "Index <= ~a" dims) 1 arr k)]
[else k]))
(: unsafe-array-axis-expand (All (A B) ((Array A) Index Index (A Index -> B) -> (Array B))))
(define (unsafe-array-axis-expand arr k dk f)
(define ds (array-shape arr))
(define new-ds (unsafe-vector-insert ds k dk))
(define proc (unsafe-array-proc arr))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define jk (unsafe-vector-ref js k))
(f (proc (unsafe-vector-remove js k)) jk))))
(: array-axis-expand (All (A B) ((Array A) Integer Integer (A Index -> B) -> (Array B))))
(define (array-axis-expand arr k dk f)
(let ([k (check-array-axis 'array-axis-expand arr k)])
(cond [(not (index? dk)) (raise-argument-error 'array-axis-expand "Index" 2 arr k dk f)]
[else (unsafe-array-axis-expand arr k dk f)])))
;; ===================================================================================================
;; Specific unfolds/expansions
(: list-array->array (All (A) (case-> ((Array (Listof A)) -> (Array A))
((Array (Listof A)) Integer -> (Array A)))))
(define (list-array->array arr [k 0])
(define dims (array-dims arr))
(cond [(and (k . >= . 0) (k . <= . dims))
(let ([arr (array-strict (array-map (inst list->vector A) arr))])
;(define dks (remove-duplicates (array->list (array-map vector-length arr))))
(define dk (array-all-min (array-map vector-length arr)))
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A)))]
[else
(raise-argument-error 'list-array->array (format "Index <= ~a" dims) 1 arr k)]))

View File

@ -23,8 +23,7 @@
mutable-array-copy
mutable-array
;; Conversion
array->mutable-array
flat-vector->matrix)
array->mutable-array)
(define-syntax (mutable-array stx)
(syntax-parse stx #:literals (:)

View File

@ -21,23 +21,6 @@
(raise-argument-error name (format "Index < ~a" dims) 1 arr k)]
[else k]))
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
(define (array-axis-reduce arr k f)
(let ([k (check-array-axis 'array-axis-reduce arr k)])
(define ds (array-shape arr))
(define dk (unsafe-vector-ref ds k))
(define new-ds (unsafe-vector-remove ds k))
(define proc (unsafe-array-proc arr))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define old-js (unsafe-vector-insert js k 0))
(f dk (λ: ([jk : Integer])
(cond [(or (jk . < . 0) (jk . >= . dk))
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
[else
(unsafe-vector-set! old-js k jk)
(proc (vector-copy-all old-js))])))))))
(: unsafe-array-axis-reduce (All (A B) ((Array A) Index (Index (Index -> A) -> B) -> (Array B))))
(begin-encourage-inline
(define (unsafe-array-axis-reduce arr k f)
@ -52,6 +35,19 @@
(unsafe-vector-set! old-js k jk)
(proc old-js)))))))
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
(define (array-axis-reduce arr k f)
(let ([k (check-array-axis 'array-axis-reduce arr k)])
(unsafe-array-axis-reduce
arr k
(λ: ([dk : Index] [proc : (Index -> A)])
(: safe-proc (Integer -> A))
(define (safe-proc jk)
(cond [(or (jk . < . 0) (jk . >= . dk))
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
[else (proc jk)]))
(f dk safe-proc)))))
(: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B))))
(define (array-axis-fold/init arr k f init)
(let ([k (check-array-axis 'array-axis-fold arr k)])

View File

@ -41,10 +41,6 @@
(define (mutable-array-copy arr)
(unsafe-vector->array (array-shape arr) (vector-copy-all (mutable-array-data arr))))
(: flat-vector->matrix : (All (A) (Index Index (Vectorof A) -> (Array A))))
(define (flat-vector->matrix m n v)
(vector->array (vector m n) v))
;; ===================================================================================================
;; Conversions

View File

@ -1,5 +1,8 @@
#lang typed/racket/base
(require math/matrix)
(require math/array
"matrix-types.rkt"
"matrix-constructors.rkt")
(provide matrix-2d-rotation
matrix-2d-scaling
@ -11,34 +14,30 @@
; Transformations from:
; http://en.wikipedia.org/wiki/Transformation_matrix
(: matrix-2d-rotation : Real -> (Matrix Number))
(: matrix-2d-rotation : Real -> (Matrix Real))
; matrix representing rotation θ radians counter clockwise
(define (matrix-2d-rotation θ)
(define cosθ (cos θ))
(define sinθ (sin θ))
(matrix/dim 2 2
cosθ (- sinθ)
sinθ cosθ))
(matrix [[cosθ (- sinθ)]
[sinθ cosθ]]))
(: matrix-2d-scaling : Real Real -> (Matrix Number))
(: matrix-2d-scaling : Real Real -> (Matrix Real))
(define (matrix-2d-scaling sx sy)
(matrix/dim 2 2
sx 0
0 sy))
(matrix [[sx 0]
[0 sy]]))
(: matrix-2d-shear-x : Real -> (Matrix Number))
(: matrix-2d-shear-x : Real -> (Matrix Real))
(define (matrix-2d-shear-x k)
(matrix/dim 2 2
1 k
0 1))
(matrix [[1 k]
[0 1]]))
(: matrix-2d-shear-y : Real -> (Matrix Number))
(: matrix-2d-shear-y : Real -> (Matrix Real))
(define (matrix-2d-shear-y k)
(matrix/dim 2 2
1 0
k 1))
(matrix [[1 0]
[k 1]]))
(: matrix-2d-reflection : Real Real -> (Matrix Number))
(: matrix-2d-reflection : Real Real -> (Matrix Real))
(define (matrix-2d-reflection a b)
; reflection about the line through (0,0) and (a,b)
(define a2 (* a a))
@ -46,11 +45,10 @@
(define 2ab (* 2 a b))
(define norm2 (+ a2 b2))
(define 2ab/norm2 (/ 2ab norm2))
(matrix/dim 2 2
(/ (- a2 b2) norm2) 2ab/norm2
2ab/norm2 (/ (- b2 a2) norm2)))
(matrix [[(/ (- a2 b2) norm2) 2ab/norm2]
[2ab/norm2 (/ (- b2 a2) norm2)]]))
(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Number))
(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Real))
; orthogonal projection onto the line through (0,0) and (a,b)
(define (matrix-2d-orthogonal-projection a b)
(define a2 (* a a))
@ -58,6 +56,5 @@
(define ab (* a b))
(define norm2 (+ a2 b2))
(define ab/norm2 (/ ab norm2))
(matrix/dim 2 2
(/ a2 norm2) ab/norm2
ab/norm2 (/ b2 norm2)))
(matrix [[(/ a2 norm2) ab/norm2]
[ab/norm2 (/ b2 norm2)]]))

View File

@ -0,0 +1,39 @@
#lang racket/base
(require (for-syntax racket/base)
typed/untyped-utils
(prefix-in typed: "typed-matrix-arithmetic.rkt")
(rename-in "untyped-matrix-arithmetic.rkt"
[matrix-map untyped:matrix-map]))
(define-typed/untyped-identifier matrix-map
typed:matrix-map untyped:matrix-map)
(define-syntax (define/inline-macro stx)
(syntax-case stx ()
[(_ name pat inline-fun typed:fun)
(syntax/loc stx
(define-syntax (name inner-stx)
(syntax-case inner-stx ()
[(_ . pat) (syntax/loc inner-stx (inline-fun . pat))]
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
[_ (syntax/loc inner-stx typed:fun)])))]))
(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*)
(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+)
(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-)
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
(provide
;; Equality
(rename-out [typed:matrix= matrix=])
;; Mapping
inline-matrix-map
matrix-map
;; Multiplication
matrix*
;; Pointwise operators
matrix+
matrix-
matrix-scale
(rename-out [typed:matrix-sum matrix-sum]))

View File

@ -0,0 +1,206 @@
#lang typed/racket
(require racket/list
racket/fixnum
math/array
math/flonum
"matrix-types.rkt"
"utils.rkt"
"../unsafe.rkt")
(provide
;; Extraction
matrix-ref
matrix-diagonal
submatrix
matrix-row
matrix-col
matrix-rows
matrix-cols
;; Predicates
zero-matrix?
;; Embiggenment
matrix-augment
matrix-stack
;; Norm and inner product
matrix-norm
matrix-dot
;; Simple operators
matrix-transpose
matrix-conjugate
matrix-hermitian
matrix-trace)
;; ===================================================================================================
;; Extraction
(: matrix-ref (All (A) (Array A) Integer Integer -> A))
(define (matrix-ref a i j)
(define-values (m n) (matrix-shape a))
(cond [(or (i . < . 0) (i . >= . m))
(raise-argument-error 'matrix-ref (format "Index < ~a" m) 1 a i j)]
[(or (j . < . 0) (j . >= . n))
(raise-argument-error 'matrix-ref (format "Index < ~a" n) 2 a i j)]
[else
(unsafe-array-ref a ((inst vector Index) i j))]))
(: matrix-diagonal (All (A) ((Array A) -> (Array A))))
(define (matrix-diagonal a)
(define m (square-matrix-size a))
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) m)
(λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0))
(proc ((inst vector Index) i i)))))
(: submatrix (All (A) (Matrix A) Slice-Spec Slice-Spec -> (Matrix A)))
(define (submatrix a row-range col-range)
(array-slice-ref (ensure-matrix 'submatrix a) (list row-range col-range)))
(: matrix-row (All (A) (Matrix A) Integer -> (Matrix A)))
(define (matrix-row a i)
(define-values (m n) (matrix-shape a))
(cond [(or (i . < . 0) (i . >= . m))
(raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)]
[else
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) 1 n)
(λ: ([ij : Indexes])
(unsafe-vector-set! ij 0 i)
(define res (proc ij))
(unsafe-vector-set! ij 0 0)
res))]))
(: matrix-col (All (A) (Matrix A) Index -> (Matrix A)))
(define (matrix-col a j)
(define-values (m n) (matrix-shape a))
(cond [(or (j . < . 0) (j . >= . n))
(raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)]
[else
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) m 1)
(λ: ([ij : Indexes])
(unsafe-vector-set! ij 1 j)
(define res (proc ij))
(unsafe-vector-set! ij 1 0)
res))]))
(: matrix-rows (All (A) (Array A) -> (Listof (Array A))))
(define (matrix-rows a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0))
(: matrix-cols (All (A) (Array A) -> (Listof (Array A))))
(define (matrix-cols a)
(array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1))
;; ===================================================================================================
;; Predicates
(: zero-matrix? ((Array Number) -> Boolean))
(define (zero-matrix? a)
(array-all-and (array-map zero? a)))
;; ===================================================================================================
;; Embiggenment (this is a perfectly cromulent word)
(: matrix-augment (All (A) (Listof (Array A)) -> (Array A)))
(define (matrix-augment as)
(cond [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)]
[else
(define m (matrix-num-rows (first as)))
(cond [(andmap (λ: ([a : (Array A)]) (= m (matrix-num-rows a))) (rest as))
(array-append* as 1)]
[else
(error 'matrix-augment
"matrices must have the same number of rows; given ~a"
(format-matrices/error as))])]))
(: matrix-stack (All (A) (Listof (Array A)) -> (Array A)))
(define (matrix-stack as)
(cond [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)]
[else
(define n (matrix-num-cols (first as)))
(cond [(andmap (λ: ([a : (Array A)]) (= n (matrix-num-cols a))) (rest as))
(array-append* as 0)]
[else
(error 'matrix-stack
"matrices must have the same number of columns; given ~a"
(format-matrices/error as))])]))
;; ===================================================================================================
;; Matrix norms and Frobenius inner product
(: maximum-norm ((Array Number) -> Real))
(define (maximum-norm a)
(array-all-max (array-magnitude a)))
(: taxicab-norm ((Array Number) -> Real))
(define (taxicab-norm a)
(array-all-sum (array-magnitude a)))
(: frobenius-norm ((Array Number) -> Real))
(define (frobenius-norm a)
(let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a))
(cond [(and (rational? mx) (positive? mx))
(assert
(* mx (sqrt (array-all-sum (inline-array-map (λ: ([x : Number]) (sqr (/ x mx))) a))))
real?)]
[else mx])))
(: p-norm ((Array Number) Positive-Real -> Real))
(define (p-norm a p)
(let ([a (array-strict (array-magnitude a))])
;; Compute this divided by the maximum to avoid underflow and overflow
(define mx (array-all-max a))
(cond [(and (rational? mx) (positive? mx))
(assert
(* mx (expt (array-all-sum (inline-array-map (λ: ([x : Real]) (expt (/ x mx) p)) a))
(/ p)))
real?)]
[else mx])))
(: matrix-norm (case-> ((Array Number) -> Real)
((Array Number) Real -> Real)))
;; Computes the p norm of a matrix
(define (matrix-norm a [p 2])
(cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)]
[(p . = . 2) (frobenius-norm a)]
[(p . = . +inf.0) (maximum-norm a)]
[(p . = . 1) (taxicab-norm a)]
[(p . > . 1) (p-norm a p)]
[else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)]))
(: matrix-dot (case-> ((Array Real) (Array Real) -> Real)
((Array Number) (Array Number) -> Number)))
;; Computes the Frobenius inner product of two matrices
(define (matrix-dot a b)
(cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)]
[(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)]
[else (array-all-sum (array* a (array-conjugate b)))]))
;; ===================================================================================================
;; Operators
(: matrix-transpose (All (A) (Array A) -> (Array A)))
(define (matrix-transpose a)
(array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1))
(: matrix-conjugate (case-> ((Array Real) -> (Array Real))
((Array Number) -> (Array Number))))
(define (matrix-conjugate a)
(array-conjugate (ensure-matrix 'matrix-conjugate a)))
(: matrix-hermitian (case-> ((Array Real) -> (Array Real))
((Array Number) -> (Array Number))))
(define (matrix-hermitian a)
(array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1))
(: matrix-trace (case-> ((Array Real) -> Real)
((Array Number) -> Number)))
(define (matrix-trace a)
(array-all-sum (matrix-diagonal a)))

View File

@ -0,0 +1,126 @@
#lang typed/racket/base
(require math/array
math/base
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-arithmetic.rkt"
"../unsafe.rkt")
(provide unit-column
column-height
unsafe-column->vector
column-scale
column+
column-dot
column-norm
column-project
column-project/unit
column-normalize)
(: unit-column : Integer Integer -> (Result-Column Number))
(define (unit-column m i)
(cond
[(and (index? m) (index? i))
(define v (make-vector m 0))
(if (< i m)
(vector-set! v i 1)
(error 'unit-vector "dimension must be largest"))
(vector->matrix m 1 v)]
[else
(error 'unit-vector "expected two indices")]))
(: column-height : (Column Number) -> Index)
(define (column-height v)
(if (vector? v)
(vector-length v)
(matrix-num-rows v)))
(: unsafe-column->vector : (Column Number) -> (Vectorof Number))
(define (unsafe-column->vector v)
(if (vector? v) v
(let ()
(define-values (m n) (matrix-shape v))
(if (= n 1)
(mutable-array-data (array->mutable-array v))
(error 'unsafe-column->vector
"expected a column (vector or mx1 matrix), got ~a" v)))))
(: column-scale : (Column Number) Number -> (Result-Column Number))
(define (column-scale a s)
(if (vector? a)
(let*: ([n (vector-length a)]
[v : (Vectorof Number) (make-vector n 0)])
(for: ([i (in-range 0 n)]
[x : Number (in-vector a)])
(vector-set! v i (* s x)))
(->col-matrix v))
(matrix-scale a s)))
(: column+ : (Column Number) (Column Number) -> (Result-Column Number))
(define (column+ v w)
(cond [(and (vector? v) (vector? w))
(let ([n (vector-length v)]
[m (vector-length w)])
(unless (= m n)
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(define: v+w : (Vectorof Number) (make-vector n 0))
(for: ([i (in-range 0 n)]
[x : Number (in-vector v)]
[y : Number (in-vector w)])
(vector-set! v+w i (+ x y)))
(->col-matrix v+w))]
[else
(unless (= (column-height v) (column-height w))
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(array+ (->col-matrix v) (->col-matrix w))]))
(: column-dot : (Column Number) (Column Number) -> Number)
(define (column-dot c d)
(define v (unsafe-column->vector c))
(define w (unsafe-column->vector d))
(define m (column-height v))
(define s (column-height w))
(cond
[(not (= m s)) (error 'column-dot
"expected two mx1 matrices with same number of rows, got ~a and ~a"
c d)]
[else
(for/sum: : Number ([i (in-range 0 m)])
(assert i index?)
; Note: If d is a vector of reals,
; then the conjugate is a no-op
(* (unsafe-vector-ref v i)
(conjugate (unsafe-vector-ref w i))))]))
(: column-norm : (Column Number) -> Real)
(define (column-norm v)
(define norm (sqrt (column-dot v v)))
(assert norm real?))
(: column-project : (Column Number) (Column Number) -> (Result-Column Number))
; (column-project v w)
; Return the projection og vector v on vector w.
(define (column-project v w)
(let ([w.w (column-dot w w)])
(if (zero? w.w)
(error 'column-project "projection on the zero vector not defined")
(matrix-scale (->col-matrix w) (/ (column-dot v w) w.w)))))
(: column-project/unit : (Column Number) (Column Number) -> (Result-Column Number))
; (column-project-on-unit v w)
; Return the projection og vector v on a unit vector w.
(define (column-project/unit v w)
(matrix-scale (->col-matrix w) (column-dot v w)))
(: column-normalize : (Column Number) -> (Result-Column Number))
; (column-vector-normalize v)
; Return unit vector with same direction as v.
; If v is the zero vector, the zero vector is returned.
(define (column-normalize w)
(let ([norm (column-norm w)]
[w (->col-matrix w)])
(cond [(zero? norm) w]
[else (matrix-scale w (/ norm))])))

View File

@ -0,0 +1,140 @@
#lang racket
(require math/array
typed/racket/base
"matrix-types.rkt"
"matrix-constructors.rkt")
(provide for/matrix
for*/matrix
for/matrix:
for*/matrix:)
;;; COMPREHENSIONS
; (for/matrix m n (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first n values produced becomes the first row.
; The next n values becomes the second row and so on.
; The bindings in clauses run in parallel.
(define-syntax (for/matrix stx)
(syntax-case stx ()
[(_ m-expr n-expr (clause ...) . defs+exprs)
(syntax/loc stx
(let ([m m-expr] [n n-expr])
(define flat-vector
(for/vector #:length (* m n)
(clause ...) . defs+exprs))
(vector->matrix m n flat-vector)))]))
; (for*/matrix m n (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first n values produced becomes the first row.
; The next n values becomes the second row and so on.
; The bindings in clauses run nested.
; (for*/matrix m n #:column (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first m values produced becomes the first column.
; The next m values becomes the second column and so on.
; The bindings in clauses run nested.
(define-syntax (for*/matrix stx)
(syntax-case stx ()
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
(syntax/loc stx
(let* ([m m-expr]
[n n-expr]
[v (make-vector (* m n) 0)]
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
(for* ([i (in-range m)] [j (in-range n)])
(vector-set! v (+ (* i n) j)
(vector-ref w (+ (* j m) i))))
(vector->matrix m n v)))]
[(_ m-expr n-expr (clause ...) . defs+exprs)
(syntax/loc stx
(let ([m m-expr] [n n-expr])
(vector->matrix
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
(define-syntax (for/column: stx)
(syntax-case stx ()
[(_ : type m-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: flat-vector : (Vectorof Number) (make-vector m 0))
(for: ([i (in-range m)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! flat-vector i x))
(vector->col-matrix flat-vector)))]))
(define-syntax (for/matrix: stx)
(syntax-case stx ()
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: k : Index 0)
(for: ([i (in-range m*n)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
(set! k (assert (+ k 1) index?)))
(vector->matrix m n v)))]
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(for: ([i (in-range m*n)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v i x))
(vector->matrix m n v)))]))
(define-syntax (for*/matrix: stx)
(syntax-case stx ()
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: k : Index 0)
(for*: (for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
(set! k (assert (+ k 1) index?)))
(vector->matrix m n v)))]
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: i : Index 0)
(for*: (for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v i x)
(set! i (assert (+ i 1) index?)))
(vector->matrix m n v)))]))
#;
(module* test #f
(require rackunit)
; "matrix-sequences.rkt"
; These work in racket not in typed racket
(check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
'[[0 1 2] [1 2 3]])
(check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
'[[0 2 2] [1 1 3]])
(check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i))
'[[0 2] [1 3]])
(check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i))
'[[0 1] [2 3]])
(check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
'[[6 8 10] [12 14 16]]))

View File

@ -1,62 +1,376 @@
#lang typed/racket
#lang racket/base
(require math/array
"../unsafe.rkt"
"matrix-types.rkt")
(provide identity-matrix flidentity-matrix
matrix->list list->matrix fllist->matrix
matrix->vector vector->matrix flvector->matrix
flat-vector->matrix
(provide
;; Constructors
identity-matrix
make-matrix
matrix-row
matrix-column
submatrix)
build-matrix
diagonal-matrix/zero
diagonal-matrix
block-diagonal-matrix/zero
block-diagonal-matrix
vandermonde-matrix
;; Basic conversion
list->matrix
matrix->list
vector->matrix
matrix->vector
->row-matrix
->col-matrix
;; Nested conversion
list*->matrix
matrix->list*
vector*->matrix
matrix->vector*
;; Syntax
matrix
row-matrix
col-matrix)
(: identity-matrix (Integer -> (Matrix Real)))
(module typed-defs typed/racket/base
(require racket/fixnum
racket/list
racket/vector
math/array
"../array/utils.rkt"
"matrix-types.rkt"
"utils.rkt"
"../unsafe.rkt")
(provide (all-defined-out))
;; =================================================================================================
;; Constructors
(: identity-matrix (Integer -> (Matrix (U 0 1))))
(define (identity-matrix m) (diagonal-array 2 m 1 0))
(: flidentity-matrix (Integer -> (Matrix Float)))
(define (flidentity-matrix m) (diagonal-array 2 m 1.0 0.0))
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
(define (make-matrix m n x)
(make-array (vector m n) x))
(: list->matrix : (Listof* Number) -> (Matrix Number))
(define (list->matrix rows)
(list*->array rows number?))
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
(define (build-matrix m n proc)
(cond [(or (not (index? m)) (= m 0))
(raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
[(or (not (index? n)) (= n 0))
(raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
[else
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes])
(proc (unsafe-vector-ref js 0)
(unsafe-vector-ref js 1))))]))
(: fllist->matrix : (Listof* Flonum) -> (Matrix Flonum))
(define (fllist->matrix rows)
(list*->array rows flonum? ))
(: diagonal-matrix/zero (All (A) (Array A) A -> (Array A)))
(define (diagonal-matrix/zero a zero)
(define ds (array-shape a))
(cond [(= 1 (vector-length ds))
(define m (unsafe-vector-ref ds 0))
(define proc (unsafe-array-proc a))
(unsafe-build-array
((inst vector Index) m m)
(λ: ([js : Indexes])
(define i (unsafe-vector-ref js 0))
(cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))]
[else zero])))]
[else
(raise-argument-error 'diagonal-matrix "Array with one dimension" a)]))
(: matrix->list : (All (A) (Matrix A) -> (Listof* A)))
(: diagonal-matrix (case-> ((Array Real) -> (Array Real))
((Array Number) -> (Array Number))))
(define (diagonal-matrix a)
(diagonal-matrix/zero a 0))
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
(define (block-diagonal-matrix/zero* as zero)
(define num (vector-length as))
(define-values (ms ns)
(let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty]
[ns : (Listof Index) empty]
) ([a (in-vector as)])
(define-values (m n) (matrix-shape a))
(values (cons m ms) (cons n ns)))])
(values (reverse ms) (reverse ns))))
(define res-m (assert (apply + ms) index?))
(define res-n (assert (apply + ns) index?))
(define vs ((inst make-vector Index) res-m 0))
(define hs ((inst make-vector Index) res-n 0))
(define is ((inst make-vector Index) res-m 0))
(define js ((inst make-vector Index) res-n 0))
(define-values (_res-i _res-j)
(for/fold: ([res-i : Nonnegative-Fixnum 0]
[res-j : Nonnegative-Fixnum 0]
) ([m (in-list ms)]
[n (in-list ns)]
[k : Nonnegative-Fixnum (in-range num)])
(let ([k (assert k index?)])
(for: ([i : Nonnegative-Fixnum (in-range m)])
(vector-set! vs (unsafe-fx+ res-i i) k)
(vector-set! is (unsafe-fx+ res-i i) (assert i index?)))
(for: ([j : Nonnegative-Fixnum (in-range n)])
(vector-set! hs (unsafe-fx+ res-j j) k)
(vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
(define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
(unsafe-build-array
((inst vector Index) res-m res-n)
(λ: ([ij : Indexes])
(define i (unsafe-vector-ref ij 0))
(define j (unsafe-vector-ref ij 1))
(define v (unsafe-vector-ref vs i))
(cond [(fx= v (vector-ref hs j))
(define proc (unsafe-vector-ref procs v))
(define iv (unsafe-vector-ref is i))
(define jv (unsafe-vector-ref js j))
(unsafe-vector-set! ij 0 iv)
(unsafe-vector-set! ij 1 jv)
(define res (proc ij))
(unsafe-vector-set! ij 0 i)
(unsafe-vector-set! ij 1 j)
res]
[else
zero]))))
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
(define (block-diagonal-matrix/zero as zero)
(let ([as (list->vector as)])
(define num (vector-length as))
(cond [(= num 0)
(raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
[(= num 1)
(unsafe-vector-ref as 0)]
[else
(block-diagonal-matrix/zero* as zero)])))
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
((Listof (Array Number)) -> (Array Number))))
(define (block-diagonal-matrix as)
(block-diagonal-matrix/zero as 0))
(: expt-hack (case-> (Real Integer -> Real)
(Number Integer -> Number)))
;; Stop using this when TR correctly derives expt : Real Integer -> Real
(define (expt-hack x n)
(cond [(real? x) (assert (expt x n) real?)]
[else (expt x n)]))
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
((Listof Number) Integer -> (Array Number))))
(define (vandermonde-matrix xs n)
(cond [(empty? xs)
(raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
[(or (not (index? n)) (zero? n))
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
[else
(array-axis-expand (list->array xs) 1 n expt-hack)]))
;; =================================================================================================
;; Flat conversion
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
(define (list->matrix m n xs)
(cond [(or (not (index? m)) (= m 0))
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
[(or (not (index? n)) (= n 0))
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
[else (list->array (vector m n) xs)]))
(: matrix->list (All (A) ((Array A) -> (Listof A))))
(define (matrix->list a)
(array->list* a))
(array->list (ensure-matrix 'matrix->list a)))
(: vector->matrix : (Vectorof* Number) -> (Matrix Number))
(define (vector->matrix rows)
(vector*->array rows number? ))
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
(define (vector->matrix m n v)
(cond [(or (not (index? m)) (= m 0))
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
[(or (not (index? n)) (= n 0))
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
[else (vector->array (vector m n) v)]))
(: flvector->matrix : (Vectorof* Flonum) -> (Matrix Flonum))
(define (flvector->matrix rows)
(vector*->array rows flonum? ))
(: matrix->vector : (All (A) (Matrix A) -> (Vectorof* A)))
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
(define (matrix->vector a)
(array->vector* a))
(array->vector (ensure-matrix 'matrix->vector a)))
(: submatrix : (Matrix Number) (Sequenceof Index) (Sequenceof Index) -> (Matrix Number))
(define (submatrix a row-range col-range)
(array-slice-ref a (list row-range col-range)))
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
(define (list->row-matrix xs)
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
(: matrix-row : (Matrix Number) Index -> (Matrix Number))
(define (matrix-row a i)
(define-values (m n) (matrix-dimensions a))
(array-slice-ref a (list (in-range i (add1 i)) (in-range n))))
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
(define (list->col-matrix xs)
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
(: matrix-column : (Matrix Number) Index -> (Matrix Number))
(define (matrix-column a j)
(define-values (m n)(matrix-dimensions a))
(array-slice-ref a (list (in-range m) (in-range j (add1 j)))))
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->row-matrix xs)
(define n (vector-length xs))
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
[else (vector->array ((inst vector Index) 1 n) xs)]))
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->col-matrix xs)
(define n (vector-length xs))
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
[else (vector->array ((inst vector Index) n 1) xs)]))
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
(define (find-nontrivial-axis ds)
(define dims (vector-length ds))
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
[else (values 0 0)])))
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
(define (array->row-matrix arr)
(define (fail)
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
(define ds (array-shape arr))
(define dims (vector-length ds))
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
(cond [(zero? (array-size arr)) (fail)]
[(row-matrix? arr) arr]
[(= num-ones dims)
(define: js : (Vectorof Index) (make-vector dims 0))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) 1 1)
(λ: ([ij : Indexes]) (proc js)))]
[(= num-ones (- dims 1))
(define-values (k n) (find-nontrivial-axis ds))
(define js (make-thread-local-indexes dims))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) 1 n)
(λ: ([ij : Indexes])
(let ([js (js)])
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
(proc js))))]
[else (fail)]))
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
(define (array->col-matrix arr)
(define (fail)
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
(define ds (array-shape arr))
(define dims (vector-length ds))
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
(cond [(zero? (array-size arr)) (fail)]
[(col-matrix? arr) arr]
[(= num-ones dims)
(define: js : (Vectorof Index) (make-vector dims 0))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) 1 1)
(λ: ([ij : Indexes]) (proc js)))]
[(= num-ones (- dims 1))
(define-values (k m) (find-nontrivial-axis ds))
(define js (make-thread-local-indexes dims))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ((inst vector Index) m 1)
(λ: ([ij : Indexes])
(let ([js (js)])
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
(proc js))))]
[else (fail)]))
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(define (->row-matrix xs)
(cond [(list? xs) (list->row-matrix xs)]
[(array? xs) (array->row-matrix xs)]
[else (vector->row-matrix xs)]))
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(define (->col-matrix xs)
(cond [(list? xs) (list->col-matrix xs)]
[(array? xs) (array->col-matrix xs)]
[else (vector->col-matrix xs)]))
;; =================================================================================================
;; Nested conversion
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
(define (list*-shape xss fail)
(define m (length xss))
(cond [(m . > . 0)
(define n (length (first xss)))
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
(values m n)]
[else (fail)])]
[else (fail)]))
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
-> (Values Positive-Index Positive-Index)))
(define (vector*-shape xss fail)
(define m (vector-length xss))
(cond [(m . > . 0)
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
(define n (vector-length (unsafe-vector-ref xss 0)))
(cond [(and (n . > . 0)
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
(cond [(i . fx< . m)
(if (= n (vector-length (unsafe-vector-ref xss i)))
(loop (fx+ i 1))
#f)]
[else #t])))
(values m n)]
[else (fail)])]
[else (fail)]))
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
(define (list*->matrix xss)
(define (fail)
(raise-argument-error 'list*->matrix
"nested lists with rectangular shape and at least one matrix element"
xss))
(define-values (m n) (list*-shape xss fail))
(list->array ((inst vector Index) m n) (apply append xss)))
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
(define (matrix->list* a)
(cond [(matrix? a) (array->list (array->list-array a 1))]
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
(define (vector*->matrix xss)
(define (fail)
(raise-argument-error 'vector*->matrix
"nested vectors with rectangular shape and at least one matrix element"
xss))
(define-values (m n) (vector*-shape xss fail))
(vector->matrix m n (apply vector-append (vector->list xss))))
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
(define (matrix->vector* a)
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
) ; module
(require (for-syntax racket/base
syntax/parse)
(only-in typed/racket/base :)
math/array
(submod "." typed-defs))
(define-syntax (matrix stx)
(syntax-parse stx #:literals (:)
[(_ [[x0 xs0 ...] [x xs ...] ...])
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
(raise-syntax-error 'matrix "given empty row" stx #'r)]
[(_ (~and [] c) (~optional (~seq : T)))
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]))
(define-syntax (row-matrix stx)
(syntax-parse stx #:literals (:)
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
[(_ (~and [] r) (~optional (~seq : T)))
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
(define-syntax (col-matrix stx)
(syntax-parse stx #:literals (:)
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
[(_ (~and [] c) (~optional (~seq : T)))
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))

View File

@ -1,22 +1,21 @@
#lang typed/racket
(require "../../array.rkt"
(require math/array
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-multiply.rkt"
"matrix-types.rkt")
"matrix-arithmetic.rkt")
(provide matrix-expt)
(: matrix-expt : (Matrix Number) Integer -> (Matrix Number))
(define (matrix-expt a n)
(unless (array-matrix? a)
(raise-type-error 'matrix-expt "(Matrix Number)" a))
(unless (square-matrix? a)
(error 'matrix-expt "Square matrix expected, got ~a" a))
(cond
[(= n 0) (identity-matrix (square-matrix-size a))]
[(= n 1) a]
(cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)]
[(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)]
[(zero? n) (identity-matrix (square-matrix-size a))]
[else
(let: loop : (Matrix Number) ([n : Positive-Integer n])
(cond [(= n 1) a]
[(= n 2) (matrix* a a)]
[(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))])
(matrix* a^n/2 a^n/2))]
[else (matrix* a (matrix-expt a (sub1 n)))]))
[else (matrix* a (matrix-expt a (sub1 n)))]))]))

View File

@ -1,60 +0,0 @@
#lang typed/racket
(require "../unsafe.rkt"
"../../array.rkt"
"matrix-types.rkt")
(provide matrix*)
;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an
;; explanation.
#;(: make-matrix-multiply (All (A) (Symbol
((Array A) Integer -> (Array A))
((Array A) (Array A) -> (Array A))
-> ((Array A) (Array A) -> (Array A)))))
(define-syntax-rule (make-matrix-multiply name array-axis-sum array*)
(λ (arr brr)
(unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr))
(unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr))
(match-define (vector ad0 ad1) (array-shape arr))
(match-define (vector bd0 bd1) (array-shape brr))
(unless (= ad1 bd0)
(error name
"1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
arr brr))
;; Get strict versions of both because each element in both is evaluated multiple times
(let ([arr (array->mutable-array arr)]
[brr (array->mutable-array brr)])
;; This next part could be done with array-permute, but it's much slower that way
(define avs (mutable-array-data arr))
(define bvs (mutable-array-data brr))
;; Extend arr in the center dimension
(define: ds-ext : (Vectorof Index) (vector ad0 bd1 ad1))
(define arr-ext
(unsafe-build-array
ds-ext (λ: ([js : (Vectorof Index)])
(define j0 (unsafe-vector-ref js 0))
(define j1 (unsafe-vector-ref js 2))
;(unsafe-array-ref* arr j0 j1) [twice as slow]
(unsafe-vector-ref avs (unsafe-fx+ j1 (unsafe-fx* j0 ad1))))))
;; Transpose brr and extend in the leftmost dimension
;; Note that ds-ext = (vector ad0 bd1 bd0) because bd0 = ad1
(define brr-ext
(unsafe-build-array
ds-ext (λ: ([js : (Vectorof Index)])
(define j0 (unsafe-vector-ref js 2))
(define j1 (unsafe-vector-ref js 1))
;(unsafe-array-ref* brr j0 j1) [twice as slow]
(unsafe-vector-ref bvs (unsafe-fx+ j1 (unsafe-fx* j0 bd1))))))
(array-axis-sum (array* arr-ext brr-ext) 2))))
;; ---------------------------------------------------------------------------------------------------
(: matrix* (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Number) (Matrix Number) -> (Matrix Number))))
(define matrix* (make-matrix-multiply 'matrix* array-axis-sum array*))
;(: matrix-fl* ((Array Float) (Array Float) -> (Array Float)))
;(define matrix-fl* (make-matrix-multiply 'matrix-fl* array-axis-flsum array-fl*))

View File

@ -1,14 +1,16 @@
#lang typed/racket/base
(require math/array
(require racket/list
math/array
(only-in typed/racket conjugate)
"../unsafe.rkt"
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-pointwise.rkt"
"matrix-arithmetic.rkt"
"matrix-basic.rkt"
"matrix-column.rkt"
(for-syntax racket))
; TODO:
; 1. compute null space from QR factorization
; (better numerical stability than from Gauss elimnation)
@ -25,21 +27,6 @@
; But TR has problems with #:when so what is the proper expansion ?
(provide
; basic
matrix-ref
matrix-scale
matrix-row-vector?
matrix-column-vector?
matrix/dim ; construct
matrix-augment ; horizontally
matrix-stack ; vertically
matrix-block-diagonal
; norms
matrix-norm
; operators
matrix-transpose
matrix-conjugate
matrix-hermitian
matrix-inverse
; row and column
matrix-scale-row
@ -56,7 +43,6 @@
matrix-rank
matrix-nullity
matrix-determinant
matrix-trace
; spaces
;matrix-column+null-space
; solvers
@ -64,17 +50,6 @@
matrix-solve-many
; spaces
matrix-column-space
; column vectors
column ; construct
unit-column
result-column ; convert to lazy
column-dimension
column-dot
column-norm
column-projection
column-normalize
scale-column
column+
; projection
projection-on-orthogonal-basis
projection-on-orthonormal-basis
@ -84,70 +59,8 @@
; factorization
matrix-lu
matrix-qr
; comprehensions
for/matrix:
for*/matrix:
for/matrix-sum:
; sequences
in-row
in-column
; special matrices
vandermonde-matrix
)
;;;
;;; Basic
;;;
(: matrix-ref : (Matrix Number) Integer Integer -> Number)
(define (matrix-ref M i j)
((inst array-ref Number) M (vector i j)))
(: matrix-scale : Number (Matrix Number) -> (Matrix Number))
(define (matrix-scale s a)
(array-scale a s))
(: matrix-row-vector? : (Matrix Number) -> Boolean)
(define (matrix-row-vector? a)
(= (matrix-row-dimension a) 1))
(: matrix-column-vector? : (Matrix Number) -> Boolean)
(define (matrix-column-vector? a)
(= (matrix-column-dimension a) 1))
;;;
;;; Norms
;;;
(: matrix-norm : (Matrix Number) -> Real)
(define (matrix-norm a)
(define n
(sqrt
(array-ref
(array-axis-sum
(array-axis-sum
(matrix.sqr (matrix.magnitude a)) 0) 0)
'#())))
(assert n real?))
;;;
;;; Operators
;;;
(: matrix-transpose : (Matrix Number) -> (Matrix Number))
(define (matrix-transpose a)
(array-axis-swap a 0 1))
(: matrix-conjugate : (Matrix Number) -> (Matrix Number))
(define (matrix-conjugate a)
(array-conjugate a))
(: matrix-hermitian : (Matrix Number) -> (Matrix Number))
(define (matrix-hermitian a)
(matrix-transpose
(array-conjugate a)))
;;;
;;; Row and column
;;;
@ -296,7 +209,7 @@
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
(define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
-> (Values (Matrix Number) (Listof Integer))))
(define (loop i j ; i from 0 to m
@ -362,20 +275,20 @@
; TODO: Use QR or SVD instead for inexact matrices
; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd
; rank = dimension of column space = dimension of row space
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
(- n (length cols-without-pivot)))
(: matrix-nullity : (Matrix Number) -> Integer)
(define (matrix-nullity M)
; nullity = dimension of null space
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
(length cols-without-pivot))
(: matrix-determinant : (Matrix Number) -> Number)
(define (matrix-determinant M)
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(cond
[(= m 1) (matrix-ref M 0 0)]
[(= m 2) (let ([a (matrix-ref M 0 0)]
@ -406,18 +319,12 @@
(set! product (* product (matrix-ref M i i))))
product))]))
(: matrix-trace : (Matrix Number) -> Number)
(define (matrix-trace M)
(define-values (m n) (matrix-dimensions M))
(for/sum: : Number ([i (in-range 0 m)])
(matrix-ref M i i)))
(: matrix-column-space : (Matrix Number) -> (Listof (Matrix Number)))
; Returns
; 1) a list of column vectors spanning the column space
; 2) a list of column vectors spanning the null space
(define (matrix-column-space M)
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(: M1 (Matrix Number))
(: cols-without-pivot (Listof Integer))
(define-values (M1 cols-without-pivot) (matrix-gauss-eliminate M #t))
@ -426,7 +333,7 @@
(for/list:
([i : Index n]
#:when (not (member i cols-without-pivot)))
(matrix-column M1 i)))
(matrix-col M1 i)))
column-space)
(: matrix-row-echelon-form :
@ -442,7 +349,7 @@
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
(define (matrix-gauss-jordan-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
-> (Values (Matrix Number) (Listof Integer))))
(define (loop i j ; i from 0 to m
@ -512,23 +419,11 @@
(let-values ([(M wp) (matrix-gauss-jordan-eliminate M unitize-pivot-row?)])
M))
(: matrix-augment : (Matrix Number) (Matrix Number) * -> (Matrix Number))
(define (matrix-augment a . as)
(array-append* (cons a as) 1))
(: matrix-stack : (Matrix Number) * -> (Matrix Number))
(define (matrix-stack . as)
(if (null? as)
(error 'matrix-stack
"expected non-empty list of matrices")
(array-append* as 0)))
(: matrix-inverse : (Matrix Number) -> (Matrix Number))
(define (matrix-inverse M)
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(unless (= m n) (error 'matrix-inverse "matrix not square"))
(let ([MI (matrix-augment M (identity-matrix m))])
(let ([MI (matrix-augment (list M (identity-matrix m)))])
(define 2m (* 2 m))
(if (index? 2m)
(submatrix (matrix-reduced-row-echelon-form MI #t)
@ -539,8 +434,8 @@
; Return a column-vector x such that Mx = b.
; If no such vector exists return #f.
(define (matrix-solve M b)
(define-values (m n) (matrix-dimensions M))
(define-values (s t) (matrix-dimensions b))
(define-values (m n) (matrix-shape M))
(define-values (s t) (matrix-shape b))
(define m+1 (+ m 1))
(cond
[(not (= t 1)) (error 'matrix-solve "expected column vector (i.e. r x 1 - matrix), got: ~a " b)]
@ -548,18 +443,14 @@
[(index? m+1)
(submatrix
(matrix-reduced-row-echelon-form
(matrix-augment M b) #t)
(matrix-augment (list M b)) #t)
(in-range 0 m) (in-range m m+1))]
[else (error 'matrix-solve "internatl error")]))
(: matrix-solve-many : (Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))
(define (matrix-solve-many M bs)
; TODO: Rewrite matrix-augment* to use array-append when it is ready
(: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number))
(define (matrix-augment* vs)
(foldl matrix-augment (car vs) (cdr vs)))
(define-values (m n) (matrix-dimensions M))
(define-values (s t) (matrix-dimensions (car bs)))
(define-values (m n) (matrix-shape M))
(define-values (s t) (matrix-shape (car bs)))
(define k (length bs))
(define m+1 (+ m 1))
(define m+k (+ m k))
@ -567,8 +458,8 @@
[(not (= t 1)) (error 'matrix-solve-many "expected column vector (i.e. r x 1 - matrix), got: ~a " (car bs))]
[(not (= m s)) (error 'matrix-solve-many "expected column vectors with same number of rows as the matrix")]
[(and (index? m+1) (index? m+k))
(define bs-as-matrix (matrix-augment* bs))
(define MB (matrix-augment M bs-as-matrix))
(define bs-as-matrix (matrix-augment bs))
(define MB (matrix-augment (list M bs-as-matrix)))
(define reduced-MB (matrix-reduced-row-echelon-form MB #t))
(submatrix reduced-MB
(in-range 0 m+k)
@ -584,7 +475,7 @@
(: matrix-lu :
(Matrix Number) -> (U False (List (Matrix Number) (Matrix Number))))
(define (matrix-lu M)
(define-values (m _) (matrix-dimensions M))
(define-values (m _) (matrix-shape M))
(define: ms : (Listof Number) '())
(define V
(let/ec: return : (U False (Matrix Number))
@ -639,111 +530,6 @@
(list L V))))
(: column-dimension : (Column Number) -> Index)
(define (column-dimension v)
(if (vector? v)
(unsafe-vector-length v)
(matrix-row-dimension v)))
(: unsafe-column->vector : (Column Number) -> (Vectorof Number))
(define (unsafe-column->vector v)
(if (vector? v) v
(let ()
(define-values (m n) (matrix-dimensions v))
(if (= n 1)
(mutable-array-data (array->mutable-array v))
(error 'unsafe-column->vector
"expected a column (vector or mx1 matrix), got ~a" v)))))
(: vector->column : (Vectorof Number) -> (Result-Column Number))
(define (vector->column v)
(define m (vector-length v))
(flat-vector->matrix m 1 v))
(: column : Number * -> (Result-Column Number))
(define (column . xs)
(vector->column
(list->vector xs)))
(: result-column : (Column Number) -> (Result-Column Number))
(define (result-column c)
(if (vector? c)
(vector->column c)
c))
(: scale-column : Number (Column Number) -> (Result-Column Number))
(define (scale-column s a)
(if (vector? a)
(let*: ([n (vector-length a)]
[v : (Vectorof Number) (make-vector n 0)])
(for: ([i (in-range 0 n)]
[x : Number (in-vector a)])
(vector-set! v i (* s x)))
(vector->column v))
(matrix-scale s a)))
(: column+ : (Column Number) (Column Number) -> (Result-Column Number))
(define (column+ v w)
(cond [(and (vector? v) (vector? w))
(let ([n (vector-length v)]
[m (vector-length w)])
(unless (= m n)
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(define: v+w : (Vectorof Number) (make-vector n 0))
(for: ([i (in-range 0 n)]
[x : Number (in-vector v)]
[y : Number (in-vector w)])
(vector-set! v+w i (+ x y)))
(result-column v+w))]
[else
(unless (= (column-dimension v) (column-dimension w))
(error 'column+
"expected two column vectors of the same length, got ~a and ~a" v w))
(matrix+ (result-column v) (result-column w))]))
(: column-dot : (Column Number) (Column Number) -> Number)
(define (column-dot c d)
(define v (unsafe-column->vector c))
(define w (unsafe-column->vector d))
(define m (column-dimension v))
(define s (column-dimension w))
(cond
[(not (= m s)) (error 'column-dot
"expected two mx1 matrices with same number of rows, got ~a and ~a"
c d)]
[else
(for/sum: : Number ([i (in-range 0 m)])
(assert i index?)
; Note: If d is a vector of reals,
; then the conjugate is a no-op
(* (unsafe-vector-ref v i)
(conjugate (unsafe-vector-ref w i))))]))
(: column-norm : (Column Number) -> Real)
(define (column-norm v)
(define norm (sqrt (column-dot v v)))
(assert norm real?))
(: column-projection : (Column Number) (Column Number) -> (Result-Column Number))
; (column-projection v w)
; Return the projection og vector v on vector w.
(define (column-projection v w)
(let ([w.w (column-dot w w)])
(if (zero? w.w)
(error 'column-projection "projection on the zero vector not defined")
(matrix-scale (/ (column-dot v w) w.w) (result-column w)))))
(: column-projection-on-unit : (Column Number) (Column Number) -> (Result-Column Number))
; (column-projection-on-unit v w)
; Return the projection og vector v on a unit vector w.
(define (column-projection-on-unit v w)
(matrix-scale (column-dot v w) (result-column w)))
(: projection-on-orthogonal-basis :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
; (projection-on-orthogonal-basis v bs)
@ -754,20 +540,9 @@
(if (null? bs)
(error 'projection-on-orthogonal-basis
"received empty list of basis vectors")
(for/matrix-sum: : Number ([b (in-list bs)])
(column-projection v (result-column b)))))
; #;(for/matrix-sum ([b bs])
; (matrix-scale (column-dot v b) b))
; (define: sum : (U False (Result-Column Number)) #f)
; (for ([b1 (in-list bs)])
; (define: b : (Result-Column Number) (result-column b1))
; (cond [(not sum) (set! sum (column-projection v b))]
; [else (set! sum (matrix+ (assert sum) (column-projection v b)))]))
; (cond [sum (assert sum)]
; [else (error 'projection-on-orthogonal-basis
; "received empty list of basis vectors")])
(matrix-sum (map (λ: ([b : (Column Number)])
(column-project v (->col-matrix b)))
bs))))
; (projection-on-orthonormal-basis v bs)
; Project the vector v on the orthonormal basis vectors in bs.
@ -776,35 +551,29 @@
(: projection-on-orthonormal-basis :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
(define (projection-on-orthonormal-basis v bs)
#;(for/matrix-sum ([b bs]) (matrix-scale (column-dot v b) b))
#;(for/matrix-sum ([b bs]) (matrix-scale b (column-dot v b)))
(define: sum : (U False (Result-Column Number)) #f)
(for ([b1 (in-list bs)])
(define: b : (Result-Column Number) (result-column b1))
(cond [(not sum) (set! sum (column-projection-on-unit v b))]
[else (set! sum (matrix+ (assert sum) (column-projection-on-unit v b)))]))
(define: b : (Result-Column Number) (->col-matrix b1))
(cond [(not sum) (set! sum (column-project/unit v b))]
[else (set! sum (array+ (assert sum) (column-project/unit v b)))]))
(cond [sum (assert sum)]
[else (error 'projection-on-orthogonal-basis
[else (error 'projection-on-orthonormal-basis
"received empty list of basis vectors")]))
(: zero-column-vector? : (Matrix Number) -> Boolean)
(define (zero-column-vector? v)
(define-values (m n) (matrix-dimensions v))
(for/and: ([i (in-range 0 m)])
(zero? (matrix-ref v i 0))))
(: gram-schmidt-orthogonal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
; (gram-schmidt-orthogonal ws)
; Given a list ws of column vectors, produce
; an orthogonal basis for the span of the
; vectors in ws.
(define (gram-schmidt-orthogonal ws1)
(define ws (map result-column ws1))
(define ws (map (λ: ([w : (Column Number)]) (->col-matrix w)) ws1))
(cond
[(null? ws) '()]
[(null? (cdr ws)) (list (car ws))]
[else
(: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number)) -> (Listof (Result-Column Number)))
(: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number))
-> (Listof (Result-Column Number)))
(define (loop vs ws)
(cond [(null? ws) vs]
[else
@ -812,84 +581,32 @@
(let ([w-proj (projection-on-orthogonal-basis w vs)])
; Note: We project onto vs (not on the original ws)
; in order to get numerical stability.
(let ([w-minus-proj (matrix- w w-proj)])
(if (zero-column-vector? w-minus-proj)
(let ([w-minus-proj (array-strict (array- w w-proj))])
(if (zero-matrix? w-minus-proj)
(loop vs (cdr ws)) ; w in span{vs} => omit it
(loop (cons (matrix- w w-proj) vs) (cdr ws)))))]))
(loop (cons w-minus-proj vs) (cdr ws)))))]))
(reverse (loop (list (car ws)) (cdr ws)))]))
(: column-normalize : (Column Number) -> (Result-Column Number))
; (column-vector-normalize v)
; Return unit vector with same direction as v.
; If v is the zero vector, the zero vector is returned.
(define (column-normalize w)
(let ([norm (column-norm w)]
[w (result-column w)])
(cond [(zero? norm) w]
[else (matrix-scale (/ norm) w)])))
(: gram-schmidt-orthonormal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
; (gram-schmidt-orthonormal ws)
; Given a list ws of column vectors, produce
; an orthonormal basis for the span of the
; vectors in ws.
(define (gram-schmidt-orthonormal ws)
(map column-normalize
(gram-schmidt-orthogonal ws)))
(map column-normalize (gram-schmidt-orthogonal ws)))
(: projection-on-subspace :
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
; (projection-on-subspace v ws)
; Returns the projection of v on span{w_i}, w_i in ws.
(define (projection-on-subspace v ws)
(projection-on-orthogonal-basis
v (gram-schmidt-orthogonal ws)))
(: unit-column : Integer Integer -> (Result-Column Number))
(define (unit-column m i)
(cond
[(and (index? m) (index? i))
(define v (make-vector m 0))
(if (< i m)
(vector-set! v i 1)
(error 'unit-vector "dimension must be largest"))
(flat-vector->matrix m 1 v)]
[else
(error 'unit-vector "expected two indices")]))
(: take : (All (A) ((Listof A) Index -> (Listof A))))
(define (take xs n)
(if (= n 0)
'()
(let ([n-1 (- n 1)])
(if (index? n-1)
(cons (car xs) (take (cdr xs) n-1))
(error 'take "can not take more elements than the length of the list")))))
; (list 'take (equal? (take (list 0 1 2 3 4) 2) '(0 1)))
(: matrix->columns : (Matrix Number) -> (Listof (Matrix Number)))
(define (matrix->columns M)
(define-values (m n) (matrix-dimensions M))
(for/list: : (Listof (Matrix Number))
([j (in-range 0 n)])
(matrix-column M (assert j index?))))
(: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number))
(define (matrix-augment* Ms)
(define MM (car Ms))
(for: ([M (in-list (cdr Ms))])
(set! MM (matrix-augment MM M)))
MM)
(projection-on-orthogonal-basis v (gram-schmidt-orthogonal ws)))
(: extend-span-to-basis :
(Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
; Extend the basis in vs to with rdimensional basis
(define (extend-span-to-basis vs r)
(define-values (m n) (matrix-dimensions (car vs)))
(define-values (m n) (matrix-shape (car vs)))
(: loop : (Listof (Matrix Number)) (Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
(define (loop vs ws i)
(if (>= i m)
@ -897,15 +614,15 @@
(let ()
(define ei (unit-column m i))
(define pi (projection-on-subspace ei vs))
(if (matrix-all= ei pi)
(if (matrix= ei pi)
(loop vs ws (+ i 1))
(let ([w (matrix- ei pi)])
(let ([w (array- ei pi)])
(loop (cons w vs) (cons w ws) (+ i 1)))))))
(: norm> : (Matrix Number) (Matrix Number) -> Boolean)
(define (norm> v w)
(> (column-norm v) (column-norm w)))
(if (index? r)
((inst take (Matrix Number)) (sort (loop vs '() 0) norm>) r)
(take (sort (loop vs '() 0) norm>) r)
(error 'extend-span-to-basis "expected index as second argument, got ~a" r)))
(: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
@ -915,13 +632,13 @@
; 2) columns of Q is are orthonormal
; 3) R is upper-triangular
; Note: columnspace(A)=columnspace(Q) !
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(let* ([basis-for-column-space
(gram-schmidt-orthonormal (matrix->columns M))]
(gram-schmidt-orthonormal (matrix-cols M))]
[extension
(extend-span-to-basis
basis-for-column-space (- n (length basis-for-column-space)))]
[Q (matrix-augment*
[Q (matrix-augment
(append basis-for-column-space
(map column-normalize
extension)))]
@ -938,305 +655,5 @@
(set! sum (+ sum (* (matrix-ref Q k i)
(matrix-ref M k j)))))
(vector-set! v (+ (* i n) j) sum))))
(flat-vector->matrix n n v))])
(vector->matrix n n v))])
(values Q R)))
(: matrix/dim : Integer Integer Number * -> (Matrix Number))
; construct a mxn matrix with elements from the values xs
; the length of xs must be m*n
(define (matrix/dim m n . xs)
(cond [(and (index? m) (index? n))
(flat-vector->matrix m n (list->vector xs))]
[else (error 'matrix/dim "expected two indices as dimensions, got ~a and ~a" m n)]))
(: matrix-block-diagonal : (Listof (Matrix Number)) -> (Matrix Number))
(define (matrix-block-diagonal as)
(define sum-m 0)
(define sum-n 0)
(define: ms : (Listof Index) '())
(define: ns : (Listof Index) '())
(for: ([a (in-list as)])
(define-values (m n) (matrix-dimensions a))
(set! sum-m (+ sum-m m))
(set! sum-n (+ sum-n n))
(set! ms (cons m ms))
(set! ns (cons n ns)))
(set! ms (reverse ms))
(set! ns (reverse ns))
(: loop : (Listof (Matrix Number)) (Listof Index) (Listof Index)
(Listof (Matrix Number)) Integer -> (Matrix Number))
(define (loop as ms ns rows left)
(cond [(null? as) (apply matrix-stack (reverse rows))]
[else
(define m (car ms))
(define n (car ns))
(define a (car as))
(define row
(matrix-augment
((inst make-matrix Number) m left 0) a (make-matrix m (- sum-n n left) 0)))
(loop (cdr as) (cdr ms) (cdr ns) (cons row rows) (+ left n))]))
(loop as ms ns '() 0))
(define-syntax (for/column: stx)
(syntax-case stx ()
[(_ : type m-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: flat-vector : (Vectorof Number) (make-vector m 0))
(for: ([i (in-range m)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! flat-vector i x))
(vector->column flat-vector)))]))
(define-syntax (for/matrix: stx)
(syntax-case stx ()
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: k : Index 0)
(for: ([i (in-range m*n)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
(set! k (assert (+ k 1) index?)))
(flat-vector->matrix m n v)))]
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(for: ([i (in-range m*n)] for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v i x))
(flat-vector->matrix m n v)))]))
(define-syntax (for*/matrix: stx)
(syntax-case stx ()
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: k : Index 0)
(for*: (for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
(set! k (assert (+ k 1) index?)))
(flat-vector->matrix m n v)))]
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: m : Index m-expr)
(define: n : Index n-expr)
(define: m*n : Index (assert (* m n) index?))
(define: v : (Vectorof Number) (make-vector m*n 0))
(define: i : Index 0)
(for*: (for:-clause ...)
(define x (let () . defs+exprs))
(vector-set! v i x)
(set! i (assert (+ i 1) index?)))
(flat-vector->matrix m n v)))]))
(define-syntax (for/matrix-sum: stx)
(syntax-case stx ()
[(_ : type (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: sum : (U False (Matrix Number)) #f)
(for: (for:-clause ...)
(define a (let () . defs+exprs))
(set! sum (if sum (matrix+ (assert sum) a) a)))
(assert sum)))]))
;;;
;;; SEQUENCES
;;;
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
(define (in-row/proc M r)
(define-values (m n) (matrix-dimensions M))
(make-do-sequence
(λ ()
(values
; pos->element
(λ: ([j : Index]) (matrix-ref M r j))
; next-pos
(λ: ([j : Index]) (assert (+ j 1) index?))
; initial-pos
0
; continue-with-pos?
(λ: ([j : Index ]) (< j n))
#f #f))))
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
(define (in-column/proc M s)
(define-values (m n) (matrix-dimensions M))
(make-do-sequence
(λ ()
(values
; pos->element
(λ: ([i : Index]) (matrix-ref M i s))
; next-pos
(λ: ([i : Index]) (assert (+ i 1) index?))
; initial-pos
0
; continue-with-pos?
(λ: ([i : Index]) (< i m))
#f #f))))
; (in-row M i]
; Returns a sequence of all elements of row i,
; that is xi0, xi1, xi2, ...
(define-sequence-syntax in-row
(λ () #'in-row/proc)
(λ (stx)
(syntax-case stx ()
[[(x) (_ M-expr r-expr)]
#'((x)
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r))
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
(raise-type-error 'in-row "expected row number, got ~a" r)))
([j 0])
(< j n)
([(x) (vector-ref d (+ (* r n) j))])
#true
#true
[(+ j 1)]))]
[[(i x) (_ M-expr r-expr)]
#'((i x)
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r)))
([j 0])
(< j n)
([(x) (vector-ref d (+ (* r n) j))]
[(i) j])
#true
#true
[(+ j 1)]))]
[[_ clause] (raise-syntax-error
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
; (in-column M j]
; Returns a sequence of all elements of column j,
; that is x0j, x1j, x2j, ...
(define-sequence-syntax in-column
(λ () #'in-column/proc)
(λ (stx)
(syntax-case stx ()
; M-expr evaluates to column
[[(x) (_ M-expr)]
#'((x)
(:do-in
([(M n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(values M1 rd cd
(mutable-array-data
(array->mutable-array M1))))])
(unless (array-matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
([j 0])
(< j n)
([(x) (vector-ref d j)])
#true
#true
[(+ j 1)]))]
; M-expr evaluats to matrix, s-expr to the column index
[[(x) (_ M-expr s-expr)]
#'((x)
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-row "expected col number, got ~a" s))
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
(raise-type-error 'in-col "expected col number, got ~a" s)))
([j 0])
(< j m)
([(x) (vector-ref d (+ (* j n) s))])
#true
#true
[(+ j 1)]))]
[[(i x) (_ M-expr s-expr)]
#'((x)
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(raise-type-error 'in-column "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-column "expected col number, got ~a" s))
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
(raise-type-error 'in-column "expected col number, got ~a" s)))
([j 0])
(< j m)
([(x) (vector-ref d (+ (* j n) s))]
[(i) j])
#true
#true
[(+ j 1)]))]
[[_ clause] (raise-syntax-error
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
(: vandermonde-matrix : (Listof Number) Integer -> (Matrix Number))
(define (vandermonde-matrix xs n)
; construct matrix M with M(i,j)=α_i^j ; where i and j begin from 0 ...
; Inefficient version:
(cond
[(not (index? n))
(error 'vandermonde-matrix "expected Index as second argument, got ~a" n)]
[else (define: m : Index (length xs))
(define: αs : (Vectorof Number) (list->vector xs))
(define: α^j : (Vectorof Number) (make-vector n 1))
(for*/matrix: : Number m n #:column
([j (in-range 0 n)]
[i (in-range 0 m)])
(define αi^j (vector-ref α^j i))
(define αi (vector-ref αs i ))
(vector-set! α^j i (* αi^j αi))
αi^j)]))

View File

@ -1,57 +0,0 @@
#lang typed/racket
(require "../unsafe.rkt"
"../../array.rkt"
"matrix-types.rkt")
(provide matrix+ matrix-
matrix.sqr matrix.magnitude)
;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an
;; explanation.
#;(: make-matrix-pointwise1 (All (A) (Symbol
((Array A) -> (Array A))
-> ((Array A) -> (Array A)))))
(define-syntax-rule (make-matrix-pointwise1 name array-op1)
(λ (arr)
(unless (array-matrix? arr) (raise-type-error name "matrix" arr))
(array-op1 arr)))
#;(: make-matrix-pointwise2 (All (A) (Symbol
((Array A) (Array A) -> (Array A))
-> ((Array A) (Array A) -> (Array A)))))
(define-syntax-rule (make-matrix-pointwise2 name array-op2)
(λ (arr brr)
(unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr))
(unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr))
(array-op2 arr brr)))
#;(: make-matrix-pointwise1/2
(All (A) (Symbol
(case-> ((Array A) -> (Array A))
((Array A) (Array A) -> (Array A)))
->
(case-> ((Array A) -> (Array A))
((Array A) (Array A) -> (Array A))))))
(define-syntax-rule (make-matrix-pointwise1/2 name array-op1/2)
(case-lambda
[(arr) ((make-matrix-pointwise1 name array-op1/2) arr)]
[(arr brr) ((make-matrix-pointwise2 name array-op1/2) arr brr)]))
;; ---------------------------------------------------------------------------------------------------
(: matrix+ (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Number) (Matrix Number) -> (Matrix Number))))
(: matrix- (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Number) -> (Matrix Number))
((Matrix Real) (Matrix Real) -> (Matrix Real))
((Matrix Number) (Matrix Number) -> (Matrix Number))))
(: matrix.sqr (case-> ((Matrix Real) -> (Matrix Real))
((Matrix Number) -> (Matrix Number))))
(: matrix.magnitude ((Matrix Number) -> (Matrix Real)))
(define matrix+ (make-matrix-pointwise2 'matrix+ array+))
(define matrix- (make-matrix-pointwise1/2 'matrix- array-))
(define matrix.sqr (make-matrix-pointwise1 'matrix.sqr array-sqr))
(define matrix.magnitude (make-matrix-pointwise1 'matrix.magnitude array-magnitude))

View File

@ -1,100 +1,16 @@
#lang racket
(provide for/matrix
for*/matrix
in-row
(provide in-row
in-column)
(require math/array
(except-in math/matrix in-row in-column))
;;; COMPREHENSIONS
; (for/matrix m n (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first n values produced becomes the first row.
; The next n values becomes the second row and so on.
; The bindings in clauses run in parallel.
(define-syntax (for/matrix stx)
(syntax-case stx ()
[(_ m-expr n-expr (clause ...) . defs+exprs)
(syntax/loc stx
(let ([m m-expr] [n n-expr])
(define flat-vector
(for/vector #:length (* m n)
(clause ...) . defs+exprs))
; TODO (efficiency): Use a flat-vector->array instead
(flat-vector->matrix m n flat-vector)))]))
; (for*/matrix m n (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first n values produced becomes the first row.
; The next n values becomes the second row and so on.
; The bindings in clauses run nested.
; (for*/matrix m n #:column (clause ...) . defs+exprs)
; Return an m x n matrix with elements from the last expr.
; The first m values produced becomes the first column.
; The next m values becomes the second column and so on.
; The bindings in clauses run nested.
(define-syntax (for*/matrix stx)
(syntax-case stx ()
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
(syntax/loc stx
(let* ([m m-expr]
[n n-expr]
[v (make-vector (* m n) 0)]
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
(for* ([i (in-range m)] [j (in-range n)])
(vector-set! v (+ (* i n) j)
(vector-ref w (+ (* j m) i))))
(flat-vector->matrix m n v)))]
[(_ m-expr n-expr (clause ...) . defs+exprs)
(syntax/loc stx
(let ([m m-expr] [n n-expr])
(flat-vector->matrix
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
; TODO: The following is uncommented until matrix+ can be imported.
; (for/matrix-sum (clause ...) . defs+exprs)
; Return the matrix sum of all matrices produced by the last expr.
; The bindings in clauses are parallel.
;(define-syntax (for/matrix-sum stx)
; (syntax-case stx ()
; [(_ (clause ...) . defs+exprs)
; (syntax/loc stx
; (let ([ms (for/list (clause ...) . defs+exprs)])
; (foldl matrix+ (first ms) (rest ms))))]))
;
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
; (for/matrix-sum ([i 3]) M))
; (flat-vector->matrix 2 2 #(3 6 9 12)))
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
; (for/matrix-sum ([i 2] [j 2]) M))
; (flat-vector->matrix 2 2 #(2 4 6 8)))
; (for*/matrix-sum (clause ...) . defs+exprs)
; Return the matrix sum of all matrices produced by the last expr.
; The bindings in clauses are in nested.
;(define-syntax (for*/matrix-sum stx)
; (syntax-case stx ()
; [(_ (clause ...) . defs+exprs)
; (syntax/loc stx
; (let ([ms (for*/list (clause ...) . defs+exprs)])
; (foldl matrix+ (first ms) (rest ms))))]))
;
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
; (for*/matrix-sum ([i 2] [j 2]) M))
; (flat-vector->matrix 2 2 #(4 8 12 16)))
;;;
;;; SEQUENCES
;;;
"matrix-types.rkt"
"matrix-basic.rkt"
"matrix-constructors.rkt"
)
(define (in-row/proc M r)
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(make-do-sequence
(λ ()
(values
@ -120,12 +36,12 @@
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(define-values (rd cd) (matrix-shape M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r))
@ -142,12 +58,12 @@
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(define-values (rd cd) (matrix-shape M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r)))
@ -167,7 +83,7 @@
(define (in-column/proc M s)
(define-values (m n) (matrix-dimensions M))
(define-values (m n) (matrix-shape M))
(make-do-sequence
(λ ()
(values
@ -190,12 +106,12 @@
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(define-values (rd cd) (matrix-shape M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-row "expected col number, got ~a" s))
@ -212,12 +128,12 @@
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-dimensions M1))
(define-values (rd cd) (matrix-shape M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (array-matrix? M)
(unless (matrix? M)
(raise-type-error 'in-column "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-column "expected col number, got ~a" s))
@ -233,28 +149,192 @@
[[_ clause] (raise-syntax-error
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
(define-syntax (for/matrix-sum: stx)
(syntax-case stx ()
[(_ : type (for:-clause ...) . defs+exprs)
(syntax/loc stx
(let ()
(define: sum : (U False (Matrix Number)) #f)
(for: (for:-clause ...)
(define a (let () . defs+exprs))
(set! sum (if sum (array+ (assert sum) a) a)))
(assert sum)))]))
#|
;;;
;;; SEQUENCES
;;;
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
(define (in-row/proc M r)
(define-values (m n) (matrix-shape M))
(make-do-sequence
(λ ()
(values
; pos->element
(λ: ([j : Index]) (matrix-ref M r j))
; next-pos
(λ: ([j : Index]) (assert (+ j 1) index?))
; initial-pos
0
; continue-with-pos?
(λ: ([j : Index ]) (< j n))
#f #f))))
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
(define (in-column/proc M s)
(define-values (m n) (matrix-shape M))
(make-do-sequence
(λ ()
(values
; pos->element
(λ: ([i : Index]) (matrix-ref M i s))
; next-pos
(λ: ([i : Index]) (assert (+ i 1) index?))
; initial-pos
0
; continue-with-pos?
(λ: ([i : Index]) (< i m))
#f #f))))
; (in-row M i]
; Returns a sequence of all elements of row i,
; that is xi0, xi1, xi2, ...
(define-sequence-syntax in-row
(λ () #'in-row/proc)
(λ (stx)
(syntax-case stx ()
[[(x) (_ M-expr r-expr)]
#'((x)
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-shape M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r))
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
(raise-type-error 'in-row "expected row number, got ~a" r)))
([j 0])
(< j n)
([(x) (vector-ref d (+ (* r n) j))])
#true
#true
[(+ j 1)]))]
[[(i x) (_ M-expr r-expr)]
#'((i x)
(:do-in
([(M r n d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-shape M1))
(values M1 r-expr rd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? r)
(raise-type-error 'in-row "expected row number, got ~a" r)))
([j 0])
(< j n)
([(x) (vector-ref d (+ (* r n) j))]
[(i) j])
#true
#true
[(+ j 1)]))]
[[_ clause] (raise-syntax-error
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
; (in-column M j]
; Returns a sequence of all elements of column j,
; that is x0j, x1j, x2j, ...
(define-sequence-syntax in-column
(λ () #'in-column/proc)
(λ (stx)
(syntax-case stx ()
; M-expr evaluates to column
[[(x) (_ M-expr)]
#'((x)
(:do-in
([(M n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-shape M1))
(values M1 rd cd
(mutable-array-data
(array->mutable-array M1))))])
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
([j 0])
(< j n)
([(x) (vector-ref d j)])
#true
#true
[(+ j 1)]))]
; M-expr evaluats to matrix, s-expr to the column index
[[(x) (_ M-expr s-expr)]
#'((x)
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-shape M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (matrix? M)
(raise-type-error 'in-row "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-row "expected col number, got ~a" s))
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
(raise-type-error 'in-col "expected col number, got ~a" s)))
([j 0])
(< j m)
([(x) (vector-ref d (+ (* j n) s))])
#true
#true
[(+ j 1)]))]
[[(i x) (_ M-expr s-expr)]
#'((x)
(:do-in
([(M s n m d)
(let ([M1 M-expr])
(define-values (rd cd) (matrix-shape M1))
(values M1 s-expr rd cd
(mutable-array-data
(array->mutable-array M1))))])
(begin
(unless (matrix? M)
(raise-type-error 'in-column "expected matrix, got ~a" M))
(unless (integer? s)
(raise-type-error 'in-column "expected col number, got ~a" s))
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
(raise-type-error 'in-column "expected col number, got ~a" s)))
([j 0])
(< j m)
([(x) (vector-ref d (+ (* j n) s))]
[(i) j])
#true
#true
[(+ j 1)]))]
[[_ clause] (raise-syntax-error
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|#
#;
(module* test #f
(require (except-in math/matrix in-row in-column)
rackunit)
(require rackunit)
; "matrix-sequences.rkt"
; These work in racket not in typed racket
(check-equal? (matrix->list (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
'[[0 1 2] [1 2 3]])
(check-equal? (matrix->list (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
'[[0 2 2] [1 1 3]])
(check-equal? (matrix->list (for*/matrix 2 2 #:column ([i 4]) i))
'[[0 2] [1 3]])
(check-equal? (matrix->list (for/matrix 2 2 ([i 4]) i))
'[[0 1] [2 3]])
(check-equal? (matrix->list (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
'[[6 8 10] [12 14 16]])
(check-equal? (for/list ([x (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
(check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
'(3 4))
(check-equal? (for/list ([(i x) (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
(check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)])
(list i x))
'((0 3) (1 4)))
(check-equal? (for/list ([x (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
(check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
'(2 4))
(check-equal? (for/list ([(i x) (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
(check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)])
(list i x))
'((0 2) (1 4))))

View File

@ -1,16 +1,21 @@
#lang typed/racket
#lang typed/racket/base
(require "../array/array-struct.rkt"
"../array/array-fold.rkt"
"../array/array-pointwise.rkt"
"../unsafe.rkt")
(provide Matrix
Column Result-Column
Column-Matrix
array-matrix?
matrix-all=
matrix?
square-matrix?
row-matrix?
col-matrix?
matrix-shape
square-matrix-size
matrix-dimensions
matrix-row-dimension
matrix-column-dimension)
(require math/array)
matrix-num-rows
matrix-num-cols)
(define-type (Matrix A) (Array A))
; matrices are represented as arrays
@ -22,39 +27,44 @@
(define-type (Column-Matrix A) (Matrix A))
; a column vector represented as a matrix
(: array-matrix? (All (A) ((Array A) -> Boolean)))
(define (array-matrix? x)
(= 2 (array-dims x)))
(define matrix?
(plambda: (A) ([arr : (Array A)])
(and (> (array-size arr) 0)
(= (array-dims arr) 2))))
(: square-matrix? : (All (A) (Matrix A) -> Boolean))
(define (square-matrix? a)
(and (array-matrix? a)
(let ([sh (array-shape a)])
(= (vector-ref sh 0) (vector-ref sh 1)))))
(define square-matrix?
(plambda: (A) ([arr : (Array A)])
(and (matrix? arr)
(let ([sh (array-shape arr)])
(= (vector-ref sh 0) (vector-ref sh 1))))))
(: square-matrix-size : (All (A) (Matrix A) -> Index))
(define (square-matrix-size a)
(vector-ref (array-shape a) 0))
(define row-matrix?
(plambda: (A) ([arr : (Array A)])
(and (matrix? arr)
(= (vector-ref (array-shape arr) 0) 1))))
(: matrix-all= : (Matrix Number) (Matrix Number) -> Boolean)
(define (matrix-all= arr0 arr1)
(array-all-and (array= arr0 arr1)))
(define col-matrix?
(plambda: (A) ([arr : (Array A)])
(and (matrix? arr)
(= (vector-ref (array-shape arr) 1) 1))))
(: matrix-dimensions : (Matrix Number) -> (Values Index Index))
(define (matrix-dimensions a)
(define sh (array-shape a))
; TODO: Remove list conversion when trbug1 is fixed
(define sh-tmp (vector->list sh))
(values (car sh-tmp) (cadr sh-tmp))
; (values (vector-ref sh 0) (vector-ref sh 1))
)
(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index)))
(define (matrix-shape a)
(cond [(matrix? a) (define sh (array-shape a))
(values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))]
[else (raise-argument-error 'matrix-shape "matrix?" a)]))
(: matrix-row-dimension : (Matrix Number) -> Index)
(define (matrix-row-dimension a)
(define sh (array-shape a))
(vector-ref sh 0))
(: square-matrix-size (All (A) ((Matrix A) -> Index)))
(define (square-matrix-size arr)
(cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)]
[else (raise-argument-error 'square-matrix-size "square-matrix?" arr)]))
(: matrix-column-dimension : (Matrix Number) -> Index)
(define (matrix-column-dimension a)
(define sh (array-shape a))
(vector-ref sh 1))
(: matrix-num-rows (All (A) ((Matrix A) -> Index)))
(define (matrix-num-rows a)
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
(: matrix-num-cols (All (A) ((Matrix A) -> Index)))
(define (matrix-num-cols a)
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))

View File

@ -0,0 +1,93 @@
#lang typed/racket/base
(require racket/list
math/array
"matrix-types.rkt"
"utils.rkt"
(except-in "untyped-matrix-arithmetic.rkt" matrix-map))
(provide matrix-map
matrix=
matrix*
matrix+
matrix-
matrix-scale
matrix-sum)
(: matrix-map (All (R A B T ...)
(case-> ((A -> R) (Array A) -> (Array R))
((A B T ... T -> R) (Array A) (Array B) (Array T) ... T -> (Array R)))))
(define matrix-map
(case-lambda:
[([f : (A -> R)] [arr : (Array A)])
(inline-matrix-map f arr)]
[([f : (A B -> R)] [arr0 : (Array A)] [arr1 : (Array B)])
(inline-matrix-map f arr0 arr1)]
[([f : (A B T ... T -> R)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T])
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))
(define gs (map unsafe-array-proc arrs))
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> T)]) (g js)) gs))))]))
(: matrix=? ((Array Number) (Array Number) -> Boolean))
(define (matrix=? arr0 arr1)
(define-values (m0 n0) (matrix-shape arr0))
(define-values (m1 n1) (matrix-shape arr1))
(and (= m0 m1)
(= n0 n1)
(let ([proc0 (unsafe-array-proc arr0)]
[proc1 (unsafe-array-proc arr1)])
(array-all-and (unsafe-build-array
((inst vector Index) m0 n0)
(λ: ([js : Indexes])
(= (proc0 js) (proc1 js))))))))
(: matrix= (case-> ((Array Number) (Array Number) -> Boolean)
((Array Number) (Array Number) (Array Number) (Array Number) * -> Boolean)))
(define matrix=
(case-lambda:
[([arr0 : (Array Number)] [arr1 : (Array Number)]) (matrix=? arr0 arr1)]
[([arr0 : (Array Number)] [arr1 : (Array Number)] . [arrs : (Array Number) *])
(and (matrix=? arr0 arr1)
(let: loop : Boolean ([arr1 : (Array Number) arr1]
[arrs : (Listof (Array Number)) arrs])
(cond [(empty? arrs) #t]
[else (and (matrix=? arr1 (first arrs))
(loop (first arrs) (rest arrs)))])))]))
(: matrix* (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(define (matrix* a . as)
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix* a (first as)) (rest as))])))
(: matrix+ (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(define (matrix+ a . as)
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix+ a (first as)) (rest as))])))
(: matrix- (case-> ((Array Real) (Array Real) * -> (Array Real))
((Array Number) (Array Number) * -> (Array Number))))
(define (matrix- a . as)
(cond [(empty? as) (inline-matrix- a)]
[else
(let loop ([a a] [as as])
(cond [(empty? as) a]
[else (loop (inline-matrix- a (first as)) (rest as))]))]))
(: matrix-scale (case-> ((Array Real) Real -> (Array Real))
((Array Number) Number -> (Array Number))))
(define (matrix-scale a x) (inline-matrix-scale a x))
(: matrix-sum (case-> ((Listof (Array Real)) -> (Array Real))
((Listof (Array Number)) -> (Array Number))))
(define (matrix-sum lst)
(cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)]
[else (apply matrix+ lst)]))

View File

@ -0,0 +1,105 @@
#lang racket/base
(provide inline-matrix*
inline-matrix+
inline-matrix-
inline-matrix-scale
inline-matrix-map
matrix-map)
(module syntax-defs racket/base
(require (for-syntax racket/base)
(only-in typed/racket/base λ: : inst Index)
math/array
"matrix-types.rkt"
"utils.rkt")
(provide (all-defined-out))
;(: matrix-multiply ((Array Number) (Array Number) -> (Array Number)))
;; This is a macro so the result can have as precise a type as possible
(define-syntax-rule (inline-matrix-multiply arr-expr brr-expr)
(let ([arr arr-expr]
[brr brr-expr])
(let-values ([(m p n) (matrix-multiply-shape arr brr)]
;; Make arr strict because its elements are reffed multiple times
[(_) (array-strict! arr)])
(let (;; Extend arr in the center dimension
[arr-proc (unsafe-array-proc (array-axis-insert arr 1 n))]
;; Transpose brr and extend in the leftmost dimension
[brr-proc (unsafe-array-proc
(array-axis-insert (array-strict (array-axis-swap brr 0 1)) 0 m))])
;; The *transpose* of brr is traversed in row-major order when this result is traversed
;; in row-major order (which is why the transpose is strict, not brr)
(array-axis-sum
(unsafe-build-array
((inst vector Index) m n p)
(λ: ([js : Indexes])
(* (arr-proc js) (brr-proc js))))
2)))))
(define-syntax (inline-matrix* stx)
(syntax-case stx ()
[(_ arr)
(syntax/loc stx arr)]
[(_ arr brr crrs ...)
(syntax/loc stx (inline-matrix* (inline-matrix-multiply arr brr) crrs ...))]))
(define-syntax (inline-matrix-map stx)
(syntax-case stx ()
[(_ f arr-expr)
(syntax/loc stx
(let*-values ([(arr) arr-expr]
[(m n) (matrix-shape arr)]
[(proc) (unsafe-array-proc arr)])
(unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js))))))]
[(_ f arr-expr brr-exprs ...)
(with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))]
[(procs ...) (generate-temporaries #'(brr-exprs ...))])
(syntax/loc stx
(let ([arr arr-expr]
[brrs brr-exprs] ...)
(let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)]
[(proc) (unsafe-array-proc arr)]
[(procs) (unsafe-array-proc brrs)] ...)
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes])
(f (proc js) (procs js) ...)))))))]))
(define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...))
(define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...))
(define-syntax-rule (inline-matrix-scale arr x) (inline-matrix-map (λ (y) (* x y)) arr))
) ; module
(require 'syntax-defs)
(module untyped-defs typed/racket/base
(require math/array
(submod ".." syntax-defs)
"utils.rkt")
(provide matrix-map)
(: matrix-map (All (R A) (case-> ((A -> R) (Array A) -> (Array R))
((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R)))))
(define matrix-map
(case-lambda:
[([f : (A -> R)] [arr : (Array A)])
(inline-matrix-map f arr)]
[([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)])
(inline-matrix-map f arr0 arr1)]
[([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *])
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))
(define gs (map (inst unsafe-array-proc A) arrs))
(unsafe-build-array
((inst vector Index) m n)
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> A)]) (g js)) gs))))]))
) ; module
(require 'untyped-defs)

View File

@ -1,6 +1,38 @@
#lang typed/racket/base
(require math/array)
(require racket/match
racket/string
math/array
"matrix-types.rkt")
(provide (all-defined-out))
(: format-matrices/error ((Listof (Array Any)) -> String))
(define (format-matrices/error as)
(string-join (map (λ: ([a : (Array Any)]) (format "~e" a)) as)))
(: matrix-shapes (Symbol (Matrix Any) (Matrix Any) * -> (Values Index Index)))
(define (matrix-shapes name arr . brrs)
(define-values (m n) (matrix-shape arr))
(unless (andmap (λ: ([brr : (Matrix Any)])
(match-define (vector bm bn) (array-shape brr))
(and (= bm m) (= bn n)))
brrs)
(error name
"matrices must have the same shape; given ~a"
(format-matrices/error (cons arr brrs))))
(values m n))
(: matrix-multiply-shape ((Matrix Any) (Matrix Any) -> (Values Index Index Index)))
(define (matrix-multiply-shape arr brr)
(define-values (ad0 ad1) (matrix-shape arr))
(define-values (bd0 bd1) (matrix-shape brr))
(unless (= ad1 bd0)
(error 'matrix-multiply
"1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
arr brr))
(values ad0 ad1 bd1))
(: ensure-matrix (All (A) Symbol (Array A) -> (Array A)))
(define (ensure-matrix name a)
(if (matrix? a) a (raise-argument-error name "matrix?" a)))

View File

@ -1385,7 +1385,7 @@ Returns an array with shape @racket[(vector (array-size arr))], with the element
@;{==================================================================================================}
@section{Folds and Other Axis Reductions}
@section{Folds, Reductions and Expansions}
@defproc*[([(array-axis-fold [arr (Array A)] [k Integer] [f (A A -> A)]) (Array A)]
[(array-axis-fold [arr (Array A)] [k Integer] [f (A B -> B)] [init B]) (Array B)])]{
@ -1547,18 +1547,6 @@ and @racket[array-all-or] is defined similarly.
(array-all-or (array= arr (array 0)))]
}
@defproc[(array->list-array [arr (Array A)] [k Integer]) (Array (Listof A))]{
Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of
axis @racket[k].
@examples[#:eval typed-eval
(define arr (index-array #(3 3)))
arr
(array->list-array arr 1)
(array-ref (array->list-array (array->list-array arr 1) 0) #())]
See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows
how @racket[array->list-array] is implemented.
}
@defproc[(array-axis-reduce [arr (Array A)] [k Integer] [h (Index (Integer -> A) -> B)]) (Array B)]{
Like @racket[array-axis-fold], but allows evaluation control (such as short-cutting @racket[and] and
@racket[or]) by moving the loop into @racket[h]. The result has the shape of @racket[arr], but with
@ -1591,6 +1579,60 @@ Every fold, including @racket[array-axis-fold], is ultimately defined using
@racket[array-axis-reduce] or its unsafe counterpart.
}
@defproc[(array-axis-expand [arr (Array A)] [k Integer] [dk Integer] [g (A Index -> B)]) (Array B)]{
Inserts a new axis number @racket[k] of length @racket[dk], using @racket[g] to generate values;
@racket[k] must be @italic{no greater than} the dimension of @racket[arr], and @racket[dk] must be
nonnegative.
Conceptually, @racket[g] is applied @racket[dk] times to each element in each row of axis @racket[k],
once for each nonnegative index @racket[jk < dk]. (In reality, @racket[g] is applied only when the
resulting array is indexed.)
Turning vector elements into rows of a new last axis using @racket[array-axis-expand] and
@racket[vector-ref]:
@interaction[#:eval typed-eval
(define arr (array #['#(a b c) '#(d e f) '#(g h i)]))
(array-axis-expand arr 1 3 vector-ref)]
Creating a @racket[vandermonde-matrix]:
@interaction[#:eval typed-eval
(array-axis-expand (list->array '(1 2 3 4)) 1 5 expt)]
This function is a dual of @racket[array-axis-reduce] in that it can be used to invert applications
of @racket[array-axis-reduce].
To do so, @racket[g] should be a destructuring function that is dual to the constructor passed to
@racket[array-axis-reduce].
Example dual pairs are @racket[vector-ref] and @racket[build-vector], and @racket[list-ref] and
@racket[build-list].
(Do not pass @racket[list-ref] to @racket[array-axis-expand] if you care about performance, though.
See @racket[list-array->array] for a more efficient solution.)
}
@defproc[(array->list-array [arr (Array A)] [k Integer 0]) (Array (Listof A))]{
Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of
axis @racket[k].
@examples[#:eval typed-eval
(define arr (index-array #(3 3)))
arr
(array->list-array arr 1)
(array-ref (array->list-array (array->list-array arr 1) 0) #())]
See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows
how @racket[array->list-array] is implemented.
}
@defproc[(list-array->array [arr (Array (Listof A))] [k Integer 0]) (Array A)]{
Returns an array in which the list elements of @racket[arr] comprise a new axis @racket[k].
Equivalent to @racket[(array-axis-expand arr k n list-ref)] where @racket[n] is the
length of the lists in @racket[arr], but with O(1) indexing.
@examples[#:eval typed-eval
(define arr (array->list-array (index-array #(3 3)) 1))
arr
(list-array->array arr 1)]
For fixed @racket[k], this function and @racket[array->list-array] are mutual inverses with respect
to their array arguments.
}
@;{==================================================================================================}

View File

@ -3,28 +3,86 @@
(require math/array
math/base
math/flonum
math/matrix)
math/matrix
"../private/matrix/matrix-column.rkt"
"test-utils.rkt")
(: random-matrix (Integer Integer Integer -> (Matrix Integer)))
;; Generates a random matrix with integer elements < k. Useful to test properties.
(define (random-matrix m n k)
(array-strict (build-array (vector m n) (λ (_) (random k)))))
;; ===================================================================================================
;; Types
(check-true (matrix? (array #[#[1]])))
(check-false (matrix? (array #[1])))
(check-false (matrix? (array 1)))
(check-false (matrix? (array #[])))
(check-true (row-matrix? (matrix [[1 2 3 4]])))
(check-false (row-matrix? (matrix [[1] [2] [3] [4]])))
(check-true (col-matrix? (matrix [[1] [2] [3] [4]])))
(check-false (col-matrix? (matrix [[1 2 3 4]])))
;; ===================================================================================================
;; Matrix multiplication
(check-equal? (matrix* (identity-matrix 2)
(matrix [[1 20] [300 4000]]))
(matrix [[1 20] [300 4000]]))
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
(let ([m0 (random-matrix 4 5 100)]
[m1 (random-matrix 5 2 100)]
[m2 (random-matrix 2 10 100)])
(check-equal? (matrix* (matrix* m0 m1) m2)
(matrix* m0 (matrix* m1 m2))))
;; ===================================================================================================
;; Construction
(check-equal?
(block-diagonal-matrix
(list
(matrix [[1 2] [3 4]])
(matrix [[1 2 3] [4 5 6]])
(matrix [[1] [3] [5]])))
(matrix
[[1 2 0 0 0 0]
[3 4 0 0 0 0]
[0 0 1 2 3 0]
[0 0 4 5 6 0]
[0 0 0 0 0 1]
[0 0 0 0 0 3]
[0 0 0 0 0 5]]))
;; ===================================================================================================
(begin
(begin "matrix-types.rkt"
(list
'array-matrix?
(array-matrix? (list*->array '[[1 2] [3 4]] real? ))
(not (array-matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
'matrix?
(matrix? (list*->array '[[1 2] [3 4]] real? ))
(not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
(list
'square-matrix?
(square-matrix? (list*->array '[[1 2] [3 4]] real? ))
(not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? ))))
(list
'square-matrix-size
(= 2 (square-matrix-size (list*->array '[[1 2 3] [4 5 6]] real? ))))
(= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?))))
(list
'matrix-all=-
(matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
(not (matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
'matrix=
(matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
#;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
(list
'matrix-dimensions
(let-values ([(m n) (matrix-dimensions (list->matrix '[[1 2 3] [4 5 6]]))])
'matrix-shape
(let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))])
(equal? (list m n) '(2 3)))))
(begin "matrix-constructors.rkt"
@ -32,47 +90,46 @@
'identity-matrix
(equal? (array->list* (identity-matrix 1)) '[[1]])
(equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]])
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]])
(equal? (array->list* (flidentity-matrix 1)) '[[1.]])
(equal? (array->list* (flidentity-matrix 2)) '[[1. 0.] [0. 1.]])
(equal? (array->list* (flidentity-matrix 3)) '[[1. 0. 0.] [0. 1. 0.] [0. 0. 1.]]))
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]]))
(list
'const-matrix
(equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0)))
(equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.))))
(list
'matrix->list
(equal? (matrix->list (list->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
(equal? (matrix->list (fllist->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
(equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
(equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
(list
'matrix->vector
(equal? (matrix->vector (vector->matrix '#(#(1 2) #(3 4)))) '#(#(1 2) #(3 4)))
(equal? (matrix->vector (flvector->matrix '#(#(1. 2.) #(3. 4.)))) '#(#(1. 2.) #(3. 4.))))
(equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4))))
'#(#(1 2) #(3 4)))
(equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.))))
'#(#(1. 2.) #(3. 4.))))
(list
'matrix-row
(equal? (matrix-row (identity-matrix 3) 0) (list->matrix '[[1 0 0]]))
(equal? (matrix-row (identity-matrix 3) 1) (list->matrix '[[0 1 0]]))
(equal? (matrix-row (identity-matrix 3) 2) (list->matrix '[[0 0 1]])))
(equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]]))
(equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]]))
(equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]])))
(list
'matrix-col
(equal? (matrix-column (identity-matrix 3) 0) (list->matrix '[[1] [0] [0]]))
(equal? (matrix-column (identity-matrix 3) 1) (list->matrix '[[0] [1] [0]]))
(equal? (matrix-column (identity-matrix 3) 2) (list->matrix '[[0] [0] [1]])))
(equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]]))
(equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]]))
(equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]])))
(list
'submatrix
(equal? (submatrix (identity-matrix 3)
(in-range 0 1) (in-range 0 2)) (list->matrix '[[1 0]]))
(in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]]))
(equal? (submatrix (identity-matrix 3)
(in-range 0 2) (in-range 0 3)) (list->matrix '[[1 0 0] [0 1 0]]))))
(in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]]))))
(begin
"matrix-pointwise.rkt"
(let ()
(define A (list->matrix '[[1 2] [3 4]]))
(define ~A (list->matrix '[[-1 -2] [-3 -4]]))
(define B (list->matrix '[[5 6] [7 8]]))
(define A+B (list->matrix '[[6 8] [10 12]]))
(define A-B (list->matrix '[[-4 -4] [-4 -4]]))
(define A (list*->matrix '[[1 2] [3 4]]))
(define ~A (list*->matrix '[[-1 -2] [-3 -4]]))
(define B (list*->matrix '[[5 6] [7 8]]))
(define A+B (list*->matrix '[[6 8] [10 12]]))
(define A-B (list*->matrix '[[-4 -4] [-4 -4]]))
(list 'matrix+ (equal? (matrix+ A B) A+B))
(list 'matrix-
(equal? (matrix- A B) A-B)
@ -81,102 +138,102 @@
(begin
"matrix-expt.rkt"
(let ()
(define A (list->matrix '[[1 2] [3 4]]))
(define A (list*->matrix '[[1 2] [3 4]]))
(list
'matrix-expt
(equal? (matrix-expt A 0) (identity-matrix 2))
(equal? (matrix-expt A 1) A)
(equal? (matrix-expt A 2) (list->matrix '[[7 10] [15 22]]))
(equal? (matrix-expt A 3) (list->matrix '[[37 54] [81 118]]))
(equal? (matrix-expt A 8) (list->matrix '[[165751 241570] [362355 528106]]))))
#;(list
(define A (fllist->matrix '[[1. 2.] [3. 4.]]))
(check-equal? (matrix->list (flmatrix-expt A 0)) (matrix->list (flidentity-matrix 2)))
(check-equal? (matrix->list (flmatrix-expt A 1)) (matrix->list A))
(check-equal? (matrix->list (flmatrix-expt A 2)) '[[7. 10.] [15. 22.]])
(check-equal? (matrix->list (flmatrix-expt A 3)) '[[37. 54.] [81. 118.]])
(check-equal? (matrix->list (flmatrix-expt A 8)) '[[165751. 241570.] [362355. 528106.]])))
(equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]]))
(equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]]))
(equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]])))))
(begin
"matrix-operations.rkt"
(list 'vandermonde-matrix
(equal? (vandermonde-matrix '(1 2 3) 5)
(list->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
(list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
#;
(list 'in-column
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 0)]) x)
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)])
x)
'(1 3))
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 1)]) x)
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)])
x)
'(2 4))
(equal? (for/list: : (Listof Number) ([x (in-column (column 5 2 3))]) x)
(equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x)
'(5 2 3)))
#;
(list 'in-row
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 0)]) x)
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)])
x)
'(1 2))
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 1)]) x)
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)])
x)
'(3 4)))
(list 'for/matrix:
(equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i)
(matrix/dim 2 4
0 1 2 3
4 5 6 7))
(matrix [[0 1 2 3] [4 5 6 7]]))
(equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i)
(matrix/dim 2 4
0 2 4 6
1 3 5 7))
(matrix [[0 2 4 6] [1 3 5 7]]))
(equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i)
(matrix/dim 3 3 10 11 12 13 14 15 16 17 18)))
(matrix [[10 11 12] [13 14 15] [16 17 18]])))
(list 'for*/matrix:
(equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j))
(matrix/dim 3 3 0 1 2 10 11 12 20 21 22)))
(matrix [[0 1 2] [10 11 12] [20 21 22]])))
(list 'matrix-block-diagonal
(equal? (matrix-block-diagonal (list (matrix/dim 2 2 1 2 3 4) (matrix/dim 1 3 5 6 7)))
(list->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
(equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]])))
(list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
(list 'matrix-augment
(equal? (matrix-augment (column 1 2 3) (column 4 5 6) (column 7 8 9))
(matrix/dim 3 3 1 4 7 2 5 8 3 6 9)))
(equal? (matrix-augment (list (col-matrix [1 2 3])
(col-matrix [4 5 6])
(col-matrix [7 8 9])))
(matrix [[1 4 7] [2 5 8] [3 6 9]])))
(list 'matrix-stack
(equal? (matrix-stack (column 1 2 3) (column 4 5 6) (column 7 8 9))
(column 1 2 3 4 5 6 7 8 9)))
(equal? (matrix-stack (list (col-matrix [1 2 3])
(col-matrix [4 5 6])
(col-matrix [7 8 9])))
(col-matrix [1 2 3 4 5 6 7 8 9])))
#;
(list 'column-dimension
(= (column-dimension #(1 2 3)) 3)
(= (column-dimension (flat-vector->matrix 1 2 #(1 2))) 1))
(let ([matrix: flat-vector->matrix])
(= (column-dimension (vector->matrix 1 2 #(1 2))) 1))
(let ([matrix: vector->matrix])
(list 'column-dot
(= (column-dot (column 1 2) (column 1 2)) 5)
(= (column-dot (column 1 2) (column 3 4)) 11)
(= (column-dot (column 3 4) (column 3 4)) 25)
(= (column-dot (column 1 2 3) (column 4 5 6))
(= (column-dot (col-matrix [1 2]) (col-matrix [1 2])) 5)
(= (column-dot (col-matrix [1 2]) (col-matrix [3 4])) 11)
(= (column-dot (col-matrix [3 4]) (col-matrix [3 4])) 25)
(= (column-dot (col-matrix [1 2 3]) (col-matrix [4 5 6]))
(+ (* 1 4) (* 2 5) (* 3 6)))
(= (column-dot (column +3i +4i) (column +3i +4i))
(= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i]))
25)))
(list 'matrix-trace
(equal? (matrix-trace (flat-vector->matrix 2 2 #(1 2 3 4))) 5))
(let ([matrix: flat-vector->matrix])
(equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5))
(let ([matrix: vector->matrix])
(list 'column-norm
(= (column-norm (column 2 4)) (sqrt 20))))
(list 'column-projection
(equal? (column-projection #(1 2 3) #(4 5 6)) (column 128/77 160/77 192/77))
(equal? (column-projection (column 1 2 3) (column 2 4 3))
(matrix-scale 19/29 (column 2 4 3))))
(= (column-norm (col-matrix [2 4])) (sqrt 20))))
(list 'column-project
(equal? (column-project #(1 2 3) #(4 5 6)) (col-matrix [128/77 160/77 192/77]))
(equal? (column-project (col-matrix [1 2 3]) (col-matrix [2 4 3]))
(matrix-scale (col-matrix [2 4 3]) 19/29)))
(list 'projection-on-orthogonal-basis
(equal? (projection-on-orthogonal-basis #(3 -2 2) (list #(-1 0 2) #( 2 5 1)))
(column -1/3 -1/3 1/3))
(col-matrix [-1/3 -1/3 1/3]))
(equal? (projection-on-orthogonal-basis
(column 3 -2 2) (list #(-1 0 2) (column 2 5 1)))
(column -1/3 -1/3 1/3)))
(col-matrix [3 -2 2]) (list #(-1 0 2) (col-matrix [2 5 1])))
(col-matrix [-1/3 -1/3 1/3])))
(list 'projection-on-orthonormal-basis
(equal? (projection-on-orthonormal-basis
#(1 2 3 4)
(list (matrix-scale 1/2 (column 1 1 1 1))
(matrix-scale 1/2 (column -1 1 -1 1))
(matrix-scale 1/2 (column 1 -1 -1 1))))
(column 2 3 2 3)))
(list (matrix-scale (col-matrix [ 1 1 1 1]) 1/2)
(matrix-scale (col-matrix [-1 1 -1 1]) 1/2)
(matrix-scale (col-matrix [ 1 -1 -1 1]) 1/2)))
(col-matrix [2 3 2 3])))
(list 'gram-schmidt-orthogonal
(equal? (gram-schmidt-orthogonal (list #(3 1) #(2 2)))
(list (column 3 1) (column -2/5 6/5))))
(list (col-matrix [3 1]) (col-matrix [-2/5 6/5]))))
(list 'vector-normalize
(equal? (column-normalize #(3 4))
(column 3/5 4/5)))
(col-matrix [3/5 4/5])))
(list 'gram-schmidt-orthonormal
(equal? (gram-schmidt-orthonormal (ann '(#(3 1) #(2 2)) (Listof (Column Number))))
(list (column-normalize #(3 1))
@ -184,74 +241,81 @@
(list 'projection-on-subspace
(equal? (projection-on-subspace #(1 2 3) '(#(2 4 3)))
(matrix-scale 19/29 (column 2 4 3))))
(matrix-scale (col-matrix [2 4 3]) 19/29)))
(list 'unit-vector
(equal? (unit-column 4 1) (column 0 1 0 0)))
(equal? (unit-column 4 1) (col-matrix [0 1 0 0])))
(list 'matrix-qr
(let-values ([(Q R) (matrix-qr (matrix/dim 3 2 1 1 0 1 1 1))])
(let-values ([(Q R) (matrix-qr (matrix [[1 1] [0 1] [1 1]]))])
(equal? (list Q R)
(list (matrix/dim 3 2 0.7071067811865475 0 0 1 0.7071067811865475 0)
(matrix/dim 2 2 1.414213562373095 1.414213562373095 0 1))))
(list (matrix [[0.7071067811865475 0]
[0 1]
[0.7071067811865475 0]])
(matrix [[1.414213562373095 1.414213562373095]
[0 1]]))))
(let ()
(define A (matrix/dim 4 4 1 2 3 4 1 2 4 5 1 2 5 6 1 2 6 7))
(define A (matrix [[1 2 3 4] [1 2 4 5] [1 2 5 6] [1 2 6 7]]))
(define-values (Q R) (matrix-qr A))
(equal? (list Q R)
(list
(flat-vector->matrix
4 4 (ann #(1/2 -0.6708203932499369 0.5477225575051662 -0.0
(vector->matrix
4 4 ((inst vector Number)
1/2 -0.6708203932499369 0.5477225575051662 -0.0
1/2 -0.22360679774997896 -0.7302967433402214 0.4082482904638629
1/2 0.22360679774997896 -0.18257418583505536 -0.8164965809277259
1/2 0.6708203932499369 0.3651483716701107 0.408248290463863)
(Vectorof Number)))
(flat-vector->matrix
4 4 (ann #(2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
0 0 0.0 4.440892098500626e-16 0 0 0 0.0)
(Vectorof Number)))))))
1/2 0.6708203932499369 0.3651483716701107 0.408248290463863))
(vector->matrix
4 4 ((inst vector Number)
2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
0 0 0.0 4.440892098500626e-16 0 0 0 0.0))))))
(list 'matrix-solve
(let* ([M (list->matrix '[[1 5] [2 3]])]
[b (list->matrix '[[5] [5]])])
(let* ([M (list*->matrix '[[1 5] [2 3]])]
[b (list*->matrix '[[5] [5]])])
(equal? (matrix* M (matrix-solve M b)) b)))
(list 'matrix-inverse
(equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M)))
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M)))
(identity-matrix 2))
(equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M))
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M))
(identity-matrix 2)))
(list 'matrix-determinant
(equal? (matrix-determinant (list->matrix '[[3]])) 3)
(equal? (matrix-determinant (list->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
(equal? (matrix-determinant (list->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0)
(equal? (matrix-determinant (list->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120)
(equal? (matrix-determinant (list->matrix '[[1 2 3 4] [-5 6 7 8] [9 10 -11 12] [13 14 15 16]])) 5280))
(equal? (matrix-determinant (list*->matrix '[[3]])) 3)
(equal? (matrix-determinant (list*->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0)
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120)
(equal? (matrix-determinant (list*->matrix '[[1 2 3 4]
[-5 6 7 8]
[9 10 -11 12]
[13 14 15 16]]))
5280))
(list 'matrix-scale
(equal? (matrix-scale 2 (list->matrix '[[1 2] [3 4]]))
(list->matrix '[[2 4] [6 8]])))
(equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2)
(list*->matrix '[[2 4] [6 8]])))
(list 'matrix-transpose
(equal? (matrix-transpose (list->matrix '[[1 2] [3 4]]))
(list->matrix '[[1 3] [2 4]])))
(equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]]))
(list*->matrix '[[1 3] [2 4]])))
(list 'matrix-hermitian
(equal? (matrix-hermitian (list->matrix '[[1+i 2-i] [3+i 4-i]]))
(list->matrix '[[1-i 3-i] [2+i 4+i]])))
(equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]]))
(list*->matrix '[[1-i 3-i] [2+i 4+i]])))
(let ()
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
(define (gauss-eliminate M u? p?)
(let-values ([(M wp) (matrix-gauss-eliminate M u? p?)])
M))
(list 'matrix-gauss-eliminate
(equal? (let ([M (list->matrix '[[1 2] [3 4]])])
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])])
(gauss-eliminate M #f #f))
(list->matrix '[[1 2] [0 -2]]))
(equal? (let ([M (list->matrix '[[2 4] [3 4]])])
(list*->matrix '[[1 2] [0 -2]]))
(equal? (let ([M (list*->matrix '[[2 4] [3 4]])])
(gauss-eliminate M #t #f))
(list->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list->matrix '[[2. 4.] [3. 4.]])])
(list*->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])])
(gauss-eliminate M #t #t))
(list->matrix '[[1. 1.3333333333333333] [0. 1.]]))
(equal? (let ([M (list->matrix '[[1 4] [2 4]])])
(list*->matrix '[[1. 1.3333333333333333] [0. 1.]]))
(equal? (let ([M (list*->matrix '[[1 4] [2 4]])])
(gauss-eliminate M #t #t))
(list->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list->matrix '[[1 2] [2 4]])])
(list*->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list*->matrix '[[1 2] [2 4]])])
(gauss-eliminate M #f #t))
(list->matrix '[[2 4] [0 0]]))))
(list*->matrix '[[2 4] [0 0]]))))
(list
'matrix-scale-row
(equal? (matrix-scale-row (identity-matrix 3) 0 2)
@ -265,7 +329,7 @@
(equal? (matrix-add-scaled-row (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real? ) 0 2 1)
(list*->array '[[9 12 15] [4 5 6] [7 8 9]] real? )))
(let ()
(define M (list->matrix '[[1 1 0 3]
(define M (list*->matrix '[[1 1 0 3]
[2 1 -1 1]
[3 -1 -1 2]
[-1 2 3 -1]]))
@ -277,12 +341,12 @@
(define V (if (list? LU) (second LU) #f))
(list
'matrix-lu
(equal? L (list->matrix
(equal? L (list*->matrix
'[[1 0 0 0]
[2 1 0 0]
[3 4 1 0]
[-1 -3 0 1]]))
(equal? V (list->matrix
(equal? V (list*->matrix
'[[1 1 0 3]
[0 -1 -1 -5]
[0 0 3 13]
@ -290,34 +354,34 @@
(equal? (matrix* L V) M)))))
(list
'matrix-rank
(equal? (matrix-rank (list->matrix '[[0 0] [0 0]])) 0)
(equal? (matrix-rank (list->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-rank (list->matrix '[[1 0] [0 3]])) 2)
(equal? (matrix-rank (list->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-rank (list->matrix '[[1 2] [3 4]])) 2))
(equal? (matrix-rank (list*->matrix '[[0 0] [0 0]])) 0)
(equal? (matrix-rank (list*->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-rank (list*->matrix '[[1 0] [0 3]])) 2)
(equal? (matrix-rank (list*->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-rank (list*->matrix '[[1 2] [3 4]])) 2))
(list
'matrix-nullity
(equal? (matrix-nullity (list->matrix '[[0 0] [0 0]])) 2)
(equal? (matrix-nullity (list->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-nullity (list->matrix '[[1 0] [0 3]])) 0)
(equal? (matrix-nullity (list->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-nullity (list->matrix '[[1 2] [3 4]])) 0))
(equal? (matrix-nullity (list*->matrix '[[0 0] [0 0]])) 2)
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 0]])) 1)
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
#;(let ()
(define-values (c1 n1)
(matrix-column+null-space (list->matrix '[[0 0] [0 0]])))
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
(define-values (c2 n2)
(matrix-column+null-space (list->matrix '[[1 2] [2 4]])))
(matrix-column+null-space (list*->matrix '[[1 2] [2 4]])))
(define-values (c3 n3)
(matrix-column+null-space (list->matrix '[[1 2] [2 5]])))
(matrix-column+null-space (list*atrix '[[1 2] [2 5]])))
(list
'matrix-column+null-space
(equal? c1 '())
(equal? n1 (list (list->matrix '[[0] [0]])
(list->matrix '[[0] [0]])))
(equal? c2 (list (list->matrix '[[1] [2]])))
(equal? n1 (list (list*->matrix '[[0] [0]])
(list*->matrix '[[0] [0]])))
(equal? c2 (list (list*->matrix '[[1] [2]])))
;(equal? n2 '([0 0]))
(equal? c3 (list (list->matrix '[[1] [2]])
(list->matrix '[[2] [5]])))
(equal? c3 (list (list*->matrix '[[1] [2]])
(list*->matrix '[[2] [5]])))
(equal? n3 '()))))
@ -331,10 +395,12 @@
(list 'matrix*
(let ()
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]]))
(equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB)))
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))
(let ()
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6 7] [8 9 10]] '[[21 24 27] [47 54 61]]))
(equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB)))))
(define-values (A B AB) (values '[[1 2] [3 4]]
'[[5 6 7] [8 9 10]]
'[[21 24 27] [47 54 61]]))
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))))
#;(begin
"matrix-2d.rkt"
(let ()
@ -343,10 +409,10 @@
(define -e1 (matrix-transpose (vector->matrix #(#(-1 0)))))
(define -e2 (matrix-transpose (vector->matrix #(#( 0 -1)))))
(define O (matrix-transpose (vector->matrix #(#( 0 0)))))
(define 2*e1 (matrix-scale 2 e1))
(define 4*e1 (matrix-scale 4 e1))
(define 3*e2 (matrix-scale 3 e2))
(define 4*e2 (matrix-scale 4 e2))
(define 2*e1 (matrix-scale e1 2))
(define 4*e1 (matrix-scale e1 4))
(define 3*e2 (matrix-scale e2 3))
(define 4*e2 (matrix-scale e2 4))
(begin
(list 'matrix-2d-rotation
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0)

View File

@ -0,0 +1,11 @@
#lang typed/racket
(require (except-in typed/rackunit check-equal?))
(provide (all-from-out typed/rackunit)
check-equal?)
;; This gets around the fact that typed/rackunit can no longer test higher-order values for equality,
;; since TR has firmed up its rules on passing `Any' types in and out of untyped code
(define-syntax-rule (check-equal? a b . message)
(check-true (equal? a b) . message))