Reviewing and refactoring `math/matrix', part 1
* Finally added `array-axis-expand' as a dual for `array-axis-reduce' in order to implement `vandermonde-matrix' elegantly * Better, shorter matrix multiply; reworked all matrix arithmetic * Split "matrix-operations.rkt" into at least 5 parts: * "matrix-operations.rkt" * "matrix-basic.rkt" * "matrix-comprehension.rkt" * "matrix-sequences.rkt" * "matrix-column.rkt" Added "matrix-constructors.rkt" Added `matrix', `row-matrix', and `col-matrix' macros A lot of other little changes Currently, `in-row' and `in-column' are broken. I intend to implement them in a way that makes them work in untyped and Typed Racket.
This commit is contained in:
parent
c2468f1f9a
commit
155ec7dc41
|
@ -9,6 +9,7 @@
|
|||
"private/array/array-convert.rkt"
|
||||
"private/array/array-fold.rkt"
|
||||
"private/array/array-special-folds.rkt"
|
||||
"private/array/array-unfold.rkt"
|
||||
"private/array/array-print.rkt"
|
||||
"private/array/array-fft.rkt"
|
||||
"private/array/array-syntax.rkt"
|
||||
|
@ -36,6 +37,7 @@
|
|||
"private/array/array-convert.rkt"
|
||||
"private/array/array-fold.rkt"
|
||||
"private/array/array-special-folds.rkt"
|
||||
"private/array/array-unfold.rkt"
|
||||
"private/array/array-print.rkt"
|
||||
"private/array/array-syntax.rkt"
|
||||
"private/array/array-fft.rkt"
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
"number-theory.rkt"
|
||||
"vector.rkt"
|
||||
"array.rkt"
|
||||
;"matrix.rkt"
|
||||
"matrix.rkt"
|
||||
"utils.rkt")
|
||||
|
||||
(provide (all-from-out
|
||||
|
@ -22,5 +22,5 @@
|
|||
"number-theory.rkt"
|
||||
"vector.rkt"
|
||||
"array.rkt"
|
||||
;"matrix.rkt"
|
||||
"matrix.rkt"
|
||||
"utils.rkt"))
|
||||
|
|
|
@ -1,21 +1,22 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require "private/matrix/matrix-pointwise.rkt"
|
||||
"private/matrix/matrix-multiply.rkt"
|
||||
(require "private/matrix/matrix-arithmetic.rkt"
|
||||
"private/matrix/matrix-constructors.rkt"
|
||||
"private/matrix/matrix-basic.rkt"
|
||||
"private/matrix/matrix-operations.rkt"
|
||||
"private/matrix/matrix-comprehension.rkt"
|
||||
"private/matrix/matrix-sequences.rkt"
|
||||
"private/matrix/matrix-expt.rkt"
|
||||
"private/matrix/matrix-types.rkt"
|
||||
"private/matrix/utils.rkt")
|
||||
|
||||
(provide (all-from-out "private/matrix/matrix-pointwise.rkt"
|
||||
"private/matrix/matrix-multiply.rkt"
|
||||
(provide (all-from-out
|
||||
"private/matrix/matrix-arithmetic.rkt"
|
||||
"private/matrix/matrix-constructors.rkt"
|
||||
"private/matrix/matrix-basic.rkt"
|
||||
"private/matrix/matrix-operations.rkt"
|
||||
"private/matrix/matrix-comprehension.rkt"
|
||||
"private/matrix/matrix-sequences.rkt"
|
||||
"private/matrix/matrix-expt.rkt"
|
||||
"private/matrix/matrix-types.rkt")
|
||||
;; From "utils.rkt"
|
||||
array-matrix?
|
||||
;; would also like matrix? : (Any -> Boolean : (Array Any)), but we can't have one until we
|
||||
;; can define array? : (Any -> Boolean : (Array Any)), and there's been trouble with that
|
||||
)
|
||||
matrix?)
|
||||
|
|
51
collects/math/private/array/array-unfold.rkt
Normal file
51
collects/math/private/array/array-unfold.rkt
Normal file
|
@ -0,0 +1,51 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require racket/fixnum
|
||||
"array-struct.rkt"
|
||||
"array-pointwise.rkt"
|
||||
"array-fold.rkt"
|
||||
"utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide unsafe-array-axis-expand
|
||||
array-axis-expand
|
||||
list-array->array)
|
||||
|
||||
(: check-array-axis (All (A) (Symbol (Array A) Integer -> Index)))
|
||||
(define (check-array-axis name arr k)
|
||||
(define dims (array-dims arr))
|
||||
(cond [(fx= dims 0) (raise-argument-error name "Array with at least one axis" 0 arr k)]
|
||||
[(or (k . < . 0) (k . > . dims))
|
||||
(raise-argument-error name (format "Index <= ~a" dims) 1 arr k)]
|
||||
[else k]))
|
||||
|
||||
(: unsafe-array-axis-expand (All (A B) ((Array A) Index Index (A Index -> B) -> (Array B))))
|
||||
(define (unsafe-array-axis-expand arr k dk f)
|
||||
(define ds (array-shape arr))
|
||||
(define new-ds (unsafe-vector-insert ds k dk))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array
|
||||
new-ds (λ: ([js : Indexes])
|
||||
(define jk (unsafe-vector-ref js k))
|
||||
(f (proc (unsafe-vector-remove js k)) jk))))
|
||||
|
||||
(: array-axis-expand (All (A B) ((Array A) Integer Integer (A Index -> B) -> (Array B))))
|
||||
(define (array-axis-expand arr k dk f)
|
||||
(let ([k (check-array-axis 'array-axis-expand arr k)])
|
||||
(cond [(not (index? dk)) (raise-argument-error 'array-axis-expand "Index" 2 arr k dk f)]
|
||||
[else (unsafe-array-axis-expand arr k dk f)])))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Specific unfolds/expansions
|
||||
|
||||
(: list-array->array (All (A) (case-> ((Array (Listof A)) -> (Array A))
|
||||
((Array (Listof A)) Integer -> (Array A)))))
|
||||
(define (list-array->array arr [k 0])
|
||||
(define dims (array-dims arr))
|
||||
(cond [(and (k . >= . 0) (k . <= . dims))
|
||||
(let ([arr (array-strict (array-map (inst list->vector A) arr))])
|
||||
;(define dks (remove-duplicates (array->list (array-map vector-length arr))))
|
||||
(define dk (array-all-min (array-map vector-length arr)))
|
||||
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A)))]
|
||||
[else
|
||||
(raise-argument-error 'list-array->array (format "Index <= ~a" dims) 1 arr k)]))
|
|
@ -23,8 +23,7 @@
|
|||
mutable-array-copy
|
||||
mutable-array
|
||||
;; Conversion
|
||||
array->mutable-array
|
||||
flat-vector->matrix)
|
||||
array->mutable-array)
|
||||
|
||||
(define-syntax (mutable-array stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
|
|
|
@ -21,23 +21,6 @@
|
|||
(raise-argument-error name (format "Index < ~a" dims) 1 arr k)]
|
||||
[else k]))
|
||||
|
||||
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
|
||||
(define (array-axis-reduce arr k f)
|
||||
(let ([k (check-array-axis 'array-axis-reduce arr k)])
|
||||
(define ds (array-shape arr))
|
||||
(define dk (unsafe-vector-ref ds k))
|
||||
(define new-ds (unsafe-vector-remove ds k))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array
|
||||
new-ds (λ: ([js : Indexes])
|
||||
(define old-js (unsafe-vector-insert js k 0))
|
||||
(f dk (λ: ([jk : Integer])
|
||||
(cond [(or (jk . < . 0) (jk . >= . dk))
|
||||
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
|
||||
[else
|
||||
(unsafe-vector-set! old-js k jk)
|
||||
(proc (vector-copy-all old-js))])))))))
|
||||
|
||||
(: unsafe-array-axis-reduce (All (A B) ((Array A) Index (Index (Index -> A) -> B) -> (Array B))))
|
||||
(begin-encourage-inline
|
||||
(define (unsafe-array-axis-reduce arr k f)
|
||||
|
@ -52,6 +35,19 @@
|
|||
(unsafe-vector-set! old-js k jk)
|
||||
(proc old-js)))))))
|
||||
|
||||
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
|
||||
(define (array-axis-reduce arr k f)
|
||||
(let ([k (check-array-axis 'array-axis-reduce arr k)])
|
||||
(unsafe-array-axis-reduce
|
||||
arr k
|
||||
(λ: ([dk : Index] [proc : (Index -> A)])
|
||||
(: safe-proc (Integer -> A))
|
||||
(define (safe-proc jk)
|
||||
(cond [(or (jk . < . 0) (jk . >= . dk))
|
||||
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
|
||||
[else (proc jk)]))
|
||||
(f dk safe-proc)))))
|
||||
|
||||
(: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B))))
|
||||
(define (array-axis-fold/init arr k f init)
|
||||
(let ([k (check-array-axis 'array-axis-fold arr k)])
|
||||
|
|
|
@ -41,10 +41,6 @@
|
|||
(define (mutable-array-copy arr)
|
||||
(unsafe-vector->array (array-shape arr) (vector-copy-all (mutable-array-data arr))))
|
||||
|
||||
(: flat-vector->matrix : (All (A) (Index Index (Vectorof A) -> (Array A))))
|
||||
(define (flat-vector->matrix m n v)
|
||||
(vector->array (vector m n) v))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Conversions
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#lang typed/racket/base
|
||||
(require math/matrix)
|
||||
|
||||
(require math/array
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt")
|
||||
|
||||
(provide matrix-2d-rotation
|
||||
matrix-2d-scaling
|
||||
|
@ -11,34 +14,30 @@
|
|||
; Transformations from:
|
||||
; http://en.wikipedia.org/wiki/Transformation_matrix
|
||||
|
||||
(: matrix-2d-rotation : Real -> (Matrix Number))
|
||||
(: matrix-2d-rotation : Real -> (Matrix Real))
|
||||
; matrix representing rotation θ radians counter clockwise
|
||||
(define (matrix-2d-rotation θ)
|
||||
(define cosθ (cos θ))
|
||||
(define sinθ (sin θ))
|
||||
(matrix/dim 2 2
|
||||
cosθ (- sinθ)
|
||||
sinθ cosθ))
|
||||
(matrix [[cosθ (- sinθ)]
|
||||
[sinθ cosθ]]))
|
||||
|
||||
(: matrix-2d-scaling : Real Real -> (Matrix Number))
|
||||
(: matrix-2d-scaling : Real Real -> (Matrix Real))
|
||||
(define (matrix-2d-scaling sx sy)
|
||||
(matrix/dim 2 2
|
||||
sx 0
|
||||
0 sy))
|
||||
(matrix [[sx 0]
|
||||
[0 sy]]))
|
||||
|
||||
(: matrix-2d-shear-x : Real -> (Matrix Number))
|
||||
(: matrix-2d-shear-x : Real -> (Matrix Real))
|
||||
(define (matrix-2d-shear-x k)
|
||||
(matrix/dim 2 2
|
||||
1 k
|
||||
0 1))
|
||||
(matrix [[1 k]
|
||||
[0 1]]))
|
||||
|
||||
(: matrix-2d-shear-y : Real -> (Matrix Number))
|
||||
(: matrix-2d-shear-y : Real -> (Matrix Real))
|
||||
(define (matrix-2d-shear-y k)
|
||||
(matrix/dim 2 2
|
||||
1 0
|
||||
k 1))
|
||||
(matrix [[1 0]
|
||||
[k 1]]))
|
||||
|
||||
(: matrix-2d-reflection : Real Real -> (Matrix Number))
|
||||
(: matrix-2d-reflection : Real Real -> (Matrix Real))
|
||||
(define (matrix-2d-reflection a b)
|
||||
; reflection about the line through (0,0) and (a,b)
|
||||
(define a2 (* a a))
|
||||
|
@ -46,11 +45,10 @@
|
|||
(define 2ab (* 2 a b))
|
||||
(define norm2 (+ a2 b2))
|
||||
(define 2ab/norm2 (/ 2ab norm2))
|
||||
(matrix/dim 2 2
|
||||
(/ (- a2 b2) norm2) 2ab/norm2
|
||||
2ab/norm2 (/ (- b2 a2) norm2)))
|
||||
(matrix [[(/ (- a2 b2) norm2) 2ab/norm2]
|
||||
[2ab/norm2 (/ (- b2 a2) norm2)]]))
|
||||
|
||||
(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Number))
|
||||
(: matrix-2d-orthogonal-projection : Real Real -> (Matrix Real))
|
||||
; orthogonal projection onto the line through (0,0) and (a,b)
|
||||
(define (matrix-2d-orthogonal-projection a b)
|
||||
(define a2 (* a a))
|
||||
|
@ -58,6 +56,5 @@
|
|||
(define ab (* a b))
|
||||
(define norm2 (+ a2 b2))
|
||||
(define ab/norm2 (/ ab norm2))
|
||||
(matrix/dim 2 2
|
||||
(/ a2 norm2) ab/norm2
|
||||
ab/norm2 (/ b2 norm2)))
|
||||
(matrix [[(/ a2 norm2) ab/norm2]
|
||||
[ab/norm2 (/ b2 norm2)]]))
|
||||
|
|
39
collects/math/private/matrix/matrix-arithmetic.rkt
Normal file
39
collects/math/private/matrix/matrix-arithmetic.rkt
Normal file
|
@ -0,0 +1,39 @@
|
|||
#lang racket/base
|
||||
|
||||
(require (for-syntax racket/base)
|
||||
typed/untyped-utils
|
||||
(prefix-in typed: "typed-matrix-arithmetic.rkt")
|
||||
(rename-in "untyped-matrix-arithmetic.rkt"
|
||||
[matrix-map untyped:matrix-map]))
|
||||
|
||||
(define-typed/untyped-identifier matrix-map
|
||||
typed:matrix-map untyped:matrix-map)
|
||||
|
||||
(define-syntax (define/inline-macro stx)
|
||||
(syntax-case stx ()
|
||||
[(_ name pat inline-fun typed:fun)
|
||||
(syntax/loc stx
|
||||
(define-syntax (name inner-stx)
|
||||
(syntax-case inner-stx ()
|
||||
[(_ . pat) (syntax/loc inner-stx (inline-fun . pat))]
|
||||
[(_ . es) (syntax/loc inner-stx (typed:fun . es))]
|
||||
[_ (syntax/loc inner-stx typed:fun)])))]))
|
||||
|
||||
(define/inline-macro matrix* (a . as) inline-matrix* typed:matrix*)
|
||||
(define/inline-macro matrix+ (a . as) inline-matrix+ typed:matrix+)
|
||||
(define/inline-macro matrix- (a . as) inline-matrix- typed:matrix-)
|
||||
(define/inline-macro matrix-scale (a x) inline-matrix-scale typed:matrix-scale)
|
||||
|
||||
(provide
|
||||
;; Equality
|
||||
(rename-out [typed:matrix= matrix=])
|
||||
;; Mapping
|
||||
inline-matrix-map
|
||||
matrix-map
|
||||
;; Multiplication
|
||||
matrix*
|
||||
;; Pointwise operators
|
||||
matrix+
|
||||
matrix-
|
||||
matrix-scale
|
||||
(rename-out [typed:matrix-sum matrix-sum]))
|
206
collects/math/private/matrix/matrix-basic.rkt
Normal file
206
collects/math/private/matrix/matrix-basic.rkt
Normal file
|
@ -0,0 +1,206 @@
|
|||
#lang typed/racket
|
||||
|
||||
(require racket/list
|
||||
racket/fixnum
|
||||
math/array
|
||||
math/flonum
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide
|
||||
;; Extraction
|
||||
matrix-ref
|
||||
matrix-diagonal
|
||||
submatrix
|
||||
matrix-row
|
||||
matrix-col
|
||||
matrix-rows
|
||||
matrix-cols
|
||||
;; Predicates
|
||||
zero-matrix?
|
||||
;; Embiggenment
|
||||
matrix-augment
|
||||
matrix-stack
|
||||
;; Norm and inner product
|
||||
matrix-norm
|
||||
matrix-dot
|
||||
;; Simple operators
|
||||
matrix-transpose
|
||||
matrix-conjugate
|
||||
matrix-hermitian
|
||||
matrix-trace)
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Extraction
|
||||
|
||||
(: matrix-ref (All (A) (Array A) Integer Integer -> A))
|
||||
(define (matrix-ref a i j)
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(cond [(or (i . < . 0) (i . >= . m))
|
||||
(raise-argument-error 'matrix-ref (format "Index < ~a" m) 1 a i j)]
|
||||
[(or (j . < . 0) (j . >= . n))
|
||||
(raise-argument-error 'matrix-ref (format "Index < ~a" n) 2 a i j)]
|
||||
[else
|
||||
(unsafe-array-ref a ((inst vector Index) i j))]))
|
||||
|
||||
(: matrix-diagonal (All (A) ((Array A) -> (Array A))))
|
||||
(define (matrix-diagonal a)
|
||||
(define m (square-matrix-size a))
|
||||
(define proc (unsafe-array-proc a))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m)
|
||||
(λ: ([js : Indexes])
|
||||
(define i (unsafe-vector-ref js 0))
|
||||
(proc ((inst vector Index) i i)))))
|
||||
|
||||
(: submatrix (All (A) (Matrix A) Slice-Spec Slice-Spec -> (Matrix A)))
|
||||
(define (submatrix a row-range col-range)
|
||||
(array-slice-ref (ensure-matrix 'submatrix a) (list row-range col-range)))
|
||||
|
||||
(: matrix-row (All (A) (Matrix A) Integer -> (Matrix A)))
|
||||
(define (matrix-row a i)
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(cond [(or (i . < . 0) (i . >= . m))
|
||||
(raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)]
|
||||
[else
|
||||
(define proc (unsafe-array-proc a))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) 1 n)
|
||||
(λ: ([ij : Indexes])
|
||||
(unsafe-vector-set! ij 0 i)
|
||||
(define res (proc ij))
|
||||
(unsafe-vector-set! ij 0 0)
|
||||
res))]))
|
||||
|
||||
(: matrix-col (All (A) (Matrix A) Index -> (Matrix A)))
|
||||
(define (matrix-col a j)
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(cond [(or (j . < . 0) (j . >= . n))
|
||||
(raise-argument-error 'matrix-row (format "Index < ~a" n) 1 a j)]
|
||||
[else
|
||||
(define proc (unsafe-array-proc a))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m 1)
|
||||
(λ: ([ij : Indexes])
|
||||
(unsafe-vector-set! ij 1 j)
|
||||
(define res (proc ij))
|
||||
(unsafe-vector-set! ij 1 0)
|
||||
res))]))
|
||||
|
||||
(: matrix-rows (All (A) (Array A) -> (Listof (Array A))))
|
||||
(define (matrix-rows a)
|
||||
(array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0))
|
||||
|
||||
(: matrix-cols (All (A) (Array A) -> (Listof (Array A))))
|
||||
(define (matrix-cols a)
|
||||
(array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Predicates
|
||||
|
||||
(: zero-matrix? ((Array Number) -> Boolean))
|
||||
(define (zero-matrix? a)
|
||||
(array-all-and (array-map zero? a)))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Embiggenment (this is a perfectly cromulent word)
|
||||
|
||||
(: matrix-augment (All (A) (Listof (Array A)) -> (Array A)))
|
||||
(define (matrix-augment as)
|
||||
(cond [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)]
|
||||
[else
|
||||
(define m (matrix-num-rows (first as)))
|
||||
(cond [(andmap (λ: ([a : (Array A)]) (= m (matrix-num-rows a))) (rest as))
|
||||
(array-append* as 1)]
|
||||
[else
|
||||
(error 'matrix-augment
|
||||
"matrices must have the same number of rows; given ~a"
|
||||
(format-matrices/error as))])]))
|
||||
|
||||
(: matrix-stack (All (A) (Listof (Array A)) -> (Array A)))
|
||||
(define (matrix-stack as)
|
||||
(cond [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)]
|
||||
[else
|
||||
(define n (matrix-num-cols (first as)))
|
||||
(cond [(andmap (λ: ([a : (Array A)]) (= n (matrix-num-cols a))) (rest as))
|
||||
(array-append* as 0)]
|
||||
[else
|
||||
(error 'matrix-stack
|
||||
"matrices must have the same number of columns; given ~a"
|
||||
(format-matrices/error as))])]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Matrix norms and Frobenius inner product
|
||||
|
||||
(: maximum-norm ((Array Number) -> Real))
|
||||
(define (maximum-norm a)
|
||||
(array-all-max (array-magnitude a)))
|
||||
|
||||
(: taxicab-norm ((Array Number) -> Real))
|
||||
(define (taxicab-norm a)
|
||||
(array-all-sum (array-magnitude a)))
|
||||
|
||||
(: frobenius-norm ((Array Number) -> Real))
|
||||
(define (frobenius-norm a)
|
||||
(let ([a (array-strict (array-magnitude a))])
|
||||
;; Compute this divided by the maximum to avoid underflow and overflow
|
||||
(define mx (array-all-max a))
|
||||
(cond [(and (rational? mx) (positive? mx))
|
||||
(assert
|
||||
(* mx (sqrt (array-all-sum (inline-array-map (λ: ([x : Number]) (sqr (/ x mx))) a))))
|
||||
real?)]
|
||||
[else mx])))
|
||||
|
||||
(: p-norm ((Array Number) Positive-Real -> Real))
|
||||
(define (p-norm a p)
|
||||
(let ([a (array-strict (array-magnitude a))])
|
||||
;; Compute this divided by the maximum to avoid underflow and overflow
|
||||
(define mx (array-all-max a))
|
||||
(cond [(and (rational? mx) (positive? mx))
|
||||
(assert
|
||||
(* mx (expt (array-all-sum (inline-array-map (λ: ([x : Real]) (expt (/ x mx) p)) a))
|
||||
(/ p)))
|
||||
real?)]
|
||||
[else mx])))
|
||||
|
||||
(: matrix-norm (case-> ((Array Number) -> Real)
|
||||
((Array Number) Real -> Real)))
|
||||
;; Computes the p norm of a matrix
|
||||
(define (matrix-norm a [p 2])
|
||||
(cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)]
|
||||
[(p . = . 2) (frobenius-norm a)]
|
||||
[(p . = . +inf.0) (maximum-norm a)]
|
||||
[(p . = . 1) (taxicab-norm a)]
|
||||
[(p . > . 1) (p-norm a p)]
|
||||
[else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)]))
|
||||
|
||||
(: matrix-dot (case-> ((Array Real) (Array Real) -> Real)
|
||||
((Array Number) (Array Number) -> Number)))
|
||||
;; Computes the Frobenius inner product of two matrices
|
||||
(define (matrix-dot a b)
|
||||
(cond [(not (matrix? a)) (raise-argument-error 'matrix-dot "matrix?" 0 a b)]
|
||||
[(not (matrix? b)) (raise-argument-error 'matrix-dot "matrix?" 1 a b)]
|
||||
[else (array-all-sum (array* a (array-conjugate b)))]))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Operators
|
||||
|
||||
(: matrix-transpose (All (A) (Array A) -> (Array A)))
|
||||
(define (matrix-transpose a)
|
||||
(array-axis-swap (ensure-matrix 'matrix-transpose a) 0 1))
|
||||
|
||||
(: matrix-conjugate (case-> ((Array Real) -> (Array Real))
|
||||
((Array Number) -> (Array Number))))
|
||||
(define (matrix-conjugate a)
|
||||
(array-conjugate (ensure-matrix 'matrix-conjugate a)))
|
||||
|
||||
(: matrix-hermitian (case-> ((Array Real) -> (Array Real))
|
||||
((Array Number) -> (Array Number))))
|
||||
(define (matrix-hermitian a)
|
||||
(array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1))
|
||||
|
||||
(: matrix-trace (case-> ((Array Real) -> Real)
|
||||
((Array Number) -> Number)))
|
||||
(define (matrix-trace a)
|
||||
(array-all-sum (matrix-diagonal a)))
|
126
collects/math/private/matrix/matrix-column.rkt
Normal file
126
collects/math/private/matrix/matrix-column.rkt
Normal file
|
@ -0,0 +1,126 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require math/array
|
||||
math/base
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-arithmetic.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide unit-column
|
||||
column-height
|
||||
unsafe-column->vector
|
||||
column-scale
|
||||
column+
|
||||
column-dot
|
||||
column-norm
|
||||
column-project
|
||||
column-project/unit
|
||||
column-normalize)
|
||||
|
||||
(: unit-column : Integer Integer -> (Result-Column Number))
|
||||
(define (unit-column m i)
|
||||
(cond
|
||||
[(and (index? m) (index? i))
|
||||
(define v (make-vector m 0))
|
||||
(if (< i m)
|
||||
(vector-set! v i 1)
|
||||
(error 'unit-vector "dimension must be largest"))
|
||||
(vector->matrix m 1 v)]
|
||||
[else
|
||||
(error 'unit-vector "expected two indices")]))
|
||||
|
||||
(: column-height : (Column Number) -> Index)
|
||||
(define (column-height v)
|
||||
(if (vector? v)
|
||||
(vector-length v)
|
||||
(matrix-num-rows v)))
|
||||
|
||||
(: unsafe-column->vector : (Column Number) -> (Vectorof Number))
|
||||
(define (unsafe-column->vector v)
|
||||
(if (vector? v) v
|
||||
(let ()
|
||||
(define-values (m n) (matrix-shape v))
|
||||
(if (= n 1)
|
||||
(mutable-array-data (array->mutable-array v))
|
||||
(error 'unsafe-column->vector
|
||||
"expected a column (vector or mx1 matrix), got ~a" v)))))
|
||||
|
||||
(: column-scale : (Column Number) Number -> (Result-Column Number))
|
||||
(define (column-scale a s)
|
||||
(if (vector? a)
|
||||
(let*: ([n (vector-length a)]
|
||||
[v : (Vectorof Number) (make-vector n 0)])
|
||||
(for: ([i (in-range 0 n)]
|
||||
[x : Number (in-vector a)])
|
||||
(vector-set! v i (* s x)))
|
||||
(->col-matrix v))
|
||||
(matrix-scale a s)))
|
||||
|
||||
(: column+ : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
(define (column+ v w)
|
||||
(cond [(and (vector? v) (vector? w))
|
||||
(let ([n (vector-length v)]
|
||||
[m (vector-length w)])
|
||||
(unless (= m n)
|
||||
(error 'column+
|
||||
"expected two column vectors of the same length, got ~a and ~a" v w))
|
||||
(define: v+w : (Vectorof Number) (make-vector n 0))
|
||||
(for: ([i (in-range 0 n)]
|
||||
[x : Number (in-vector v)]
|
||||
[y : Number (in-vector w)])
|
||||
(vector-set! v+w i (+ x y)))
|
||||
(->col-matrix v+w))]
|
||||
[else
|
||||
(unless (= (column-height v) (column-height w))
|
||||
(error 'column+
|
||||
"expected two column vectors of the same length, got ~a and ~a" v w))
|
||||
(array+ (->col-matrix v) (->col-matrix w))]))
|
||||
|
||||
(: column-dot : (Column Number) (Column Number) -> Number)
|
||||
(define (column-dot c d)
|
||||
(define v (unsafe-column->vector c))
|
||||
(define w (unsafe-column->vector d))
|
||||
(define m (column-height v))
|
||||
(define s (column-height w))
|
||||
(cond
|
||||
[(not (= m s)) (error 'column-dot
|
||||
"expected two mx1 matrices with same number of rows, got ~a and ~a"
|
||||
c d)]
|
||||
[else
|
||||
(for/sum: : Number ([i (in-range 0 m)])
|
||||
(assert i index?)
|
||||
; Note: If d is a vector of reals,
|
||||
; then the conjugate is a no-op
|
||||
(* (unsafe-vector-ref v i)
|
||||
(conjugate (unsafe-vector-ref w i))))]))
|
||||
|
||||
(: column-norm : (Column Number) -> Real)
|
||||
(define (column-norm v)
|
||||
(define norm (sqrt (column-dot v v)))
|
||||
(assert norm real?))
|
||||
|
||||
(: column-project : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
; (column-project v w)
|
||||
; Return the projection og vector v on vector w.
|
||||
(define (column-project v w)
|
||||
(let ([w.w (column-dot w w)])
|
||||
(if (zero? w.w)
|
||||
(error 'column-project "projection on the zero vector not defined")
|
||||
(matrix-scale (->col-matrix w) (/ (column-dot v w) w.w)))))
|
||||
|
||||
(: column-project/unit : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
; (column-project-on-unit v w)
|
||||
; Return the projection og vector v on a unit vector w.
|
||||
(define (column-project/unit v w)
|
||||
(matrix-scale (->col-matrix w) (column-dot v w)))
|
||||
|
||||
(: column-normalize : (Column Number) -> (Result-Column Number))
|
||||
; (column-vector-normalize v)
|
||||
; Return unit vector with same direction as v.
|
||||
; If v is the zero vector, the zero vector is returned.
|
||||
(define (column-normalize w)
|
||||
(let ([norm (column-norm w)]
|
||||
[w (->col-matrix w)])
|
||||
(cond [(zero? norm) w]
|
||||
[else (matrix-scale w (/ norm))])))
|
140
collects/math/private/matrix/matrix-comprehension.rkt
Normal file
140
collects/math/private/matrix/matrix-comprehension.rkt
Normal file
|
@ -0,0 +1,140 @@
|
|||
#lang racket
|
||||
|
||||
(require math/array
|
||||
typed/racket/base
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt")
|
||||
|
||||
(provide for/matrix
|
||||
for*/matrix
|
||||
for/matrix:
|
||||
for*/matrix:)
|
||||
|
||||
;;; COMPREHENSIONS
|
||||
|
||||
; (for/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run in parallel.
|
||||
(define-syntax (for/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(define flat-vector
|
||||
(for/vector #:length (* m n)
|
||||
(clause ...) . defs+exprs))
|
||||
(vector->matrix m n flat-vector)))]))
|
||||
|
||||
; (for*/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run nested.
|
||||
; (for*/matrix m n #:column (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first m values produced becomes the first column.
|
||||
; The next m values becomes the second column and so on.
|
||||
; The bindings in clauses run nested.
|
||||
(define-syntax (for*/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let* ([m m-expr]
|
||||
[n n-expr]
|
||||
[v (make-vector (* m n) 0)]
|
||||
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
|
||||
(for* ([i (in-range m)] [j (in-range n)])
|
||||
(vector-set! v (+ (* i n) j)
|
||||
(vector-ref w (+ (* j m) i))))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(vector->matrix
|
||||
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
|
||||
|
||||
|
||||
(define-syntax (for/column: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: flat-vector : (Vectorof Number) (make-vector m 0))
|
||||
(for: ([i (in-range m)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! flat-vector i x))
|
||||
(vector->col-matrix flat-vector)))]))
|
||||
|
||||
(define-syntax (for/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x))
|
||||
(vector->matrix m n v)))]))
|
||||
|
||||
(define-syntax (for*/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: i : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x)
|
||||
(set! i (assert (+ i 1) index?)))
|
||||
(vector->matrix m n v)))]))
|
||||
#;
|
||||
(module* test #f
|
||||
(require rackunit)
|
||||
; "matrix-sequences.rkt"
|
||||
; These work in racket not in typed racket
|
||||
(check-equal? (matrix->list* (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 1 2] [1 2 3]])
|
||||
(check-equal? (matrix->list* (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 2 2] [1 1 3]])
|
||||
(check-equal? (matrix->list* (for*/matrix 2 2 #:column ([i 4]) i))
|
||||
'[[0 2] [1 3]])
|
||||
(check-equal? (matrix->list* (for/matrix 2 2 ([i 4]) i))
|
||||
'[[0 1] [2 3]])
|
||||
(check-equal? (matrix->list* (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
|
||||
'[[6 8 10] [12 14 16]]))
|
|
@ -1,62 +1,376 @@
|
|||
#lang typed/racket
|
||||
#lang racket/base
|
||||
|
||||
(require math/array
|
||||
"../unsafe.rkt"
|
||||
"matrix-types.rkt")
|
||||
|
||||
(provide identity-matrix flidentity-matrix
|
||||
matrix->list list->matrix fllist->matrix
|
||||
matrix->vector vector->matrix flvector->matrix
|
||||
flat-vector->matrix
|
||||
(provide
|
||||
;; Constructors
|
||||
identity-matrix
|
||||
make-matrix
|
||||
matrix-row
|
||||
matrix-column
|
||||
submatrix)
|
||||
build-matrix
|
||||
diagonal-matrix/zero
|
||||
diagonal-matrix
|
||||
block-diagonal-matrix/zero
|
||||
block-diagonal-matrix
|
||||
vandermonde-matrix
|
||||
;; Basic conversion
|
||||
list->matrix
|
||||
matrix->list
|
||||
vector->matrix
|
||||
matrix->vector
|
||||
->row-matrix
|
||||
->col-matrix
|
||||
;; Nested conversion
|
||||
list*->matrix
|
||||
matrix->list*
|
||||
vector*->matrix
|
||||
matrix->vector*
|
||||
;; Syntax
|
||||
matrix
|
||||
row-matrix
|
||||
col-matrix)
|
||||
|
||||
(: identity-matrix (Integer -> (Matrix Real)))
|
||||
(module typed-defs typed/racket/base
|
||||
(require racket/fixnum
|
||||
racket/list
|
||||
racket/vector
|
||||
math/array
|
||||
"../array/utils.rkt"
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide (all-defined-out))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Constructors
|
||||
|
||||
(: identity-matrix (Integer -> (Matrix (U 0 1))))
|
||||
(define (identity-matrix m) (diagonal-array 2 m 1 0))
|
||||
|
||||
(: flidentity-matrix (Integer -> (Matrix Float)))
|
||||
(define (flidentity-matrix m) (diagonal-array 2 m 1.0 0.0))
|
||||
|
||||
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
|
||||
(define (make-matrix m n x)
|
||||
(make-array (vector m n) x))
|
||||
|
||||
(: list->matrix : (Listof* Number) -> (Matrix Number))
|
||||
(define (list->matrix rows)
|
||||
(list*->array rows number?))
|
||||
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
|
||||
(define (build-matrix m n proc)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
|
||||
[else
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes])
|
||||
(proc (unsafe-vector-ref js 0)
|
||||
(unsafe-vector-ref js 1))))]))
|
||||
|
||||
(: fllist->matrix : (Listof* Flonum) -> (Matrix Flonum))
|
||||
(define (fllist->matrix rows)
|
||||
(list*->array rows flonum? ))
|
||||
(: diagonal-matrix/zero (All (A) (Array A) A -> (Array A)))
|
||||
(define (diagonal-matrix/zero a zero)
|
||||
(define ds (array-shape a))
|
||||
(cond [(= 1 (vector-length ds))
|
||||
(define m (unsafe-vector-ref ds 0))
|
||||
(define proc (unsafe-array-proc a))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m m)
|
||||
(λ: ([js : Indexes])
|
||||
(define i (unsafe-vector-ref js 0))
|
||||
(cond [(= i (unsafe-vector-ref js 1)) (proc ((inst vector Index) i))]
|
||||
[else zero])))]
|
||||
[else
|
||||
(raise-argument-error 'diagonal-matrix "Array with one dimension" a)]))
|
||||
|
||||
(: matrix->list : (All (A) (Matrix A) -> (Listof* A)))
|
||||
(: diagonal-matrix (case-> ((Array Real) -> (Array Real))
|
||||
((Array Number) -> (Array Number))))
|
||||
(define (diagonal-matrix a)
|
||||
(diagonal-matrix/zero a 0))
|
||||
|
||||
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero* as zero)
|
||||
(define num (vector-length as))
|
||||
(define-values (ms ns)
|
||||
(let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty]
|
||||
[ns : (Listof Index) empty]
|
||||
) ([a (in-vector as)])
|
||||
(define-values (m n) (matrix-shape a))
|
||||
(values (cons m ms) (cons n ns)))])
|
||||
(values (reverse ms) (reverse ns))))
|
||||
(define res-m (assert (apply + ms) index?))
|
||||
(define res-n (assert (apply + ns) index?))
|
||||
(define vs ((inst make-vector Index) res-m 0))
|
||||
(define hs ((inst make-vector Index) res-n 0))
|
||||
(define is ((inst make-vector Index) res-m 0))
|
||||
(define js ((inst make-vector Index) res-n 0))
|
||||
(define-values (_res-i _res-j)
|
||||
(for/fold: ([res-i : Nonnegative-Fixnum 0]
|
||||
[res-j : Nonnegative-Fixnum 0]
|
||||
) ([m (in-list ms)]
|
||||
[n (in-list ns)]
|
||||
[k : Nonnegative-Fixnum (in-range num)])
|
||||
(let ([k (assert k index?)])
|
||||
(for: ([i : Nonnegative-Fixnum (in-range m)])
|
||||
(vector-set! vs (unsafe-fx+ res-i i) k)
|
||||
(vector-set! is (unsafe-fx+ res-i i) (assert i index?)))
|
||||
(for: ([j : Nonnegative-Fixnum (in-range n)])
|
||||
(vector-set! hs (unsafe-fx+ res-j j) k)
|
||||
(vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
|
||||
(values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
|
||||
(define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) res-m res-n)
|
||||
(λ: ([ij : Indexes])
|
||||
(define i (unsafe-vector-ref ij 0))
|
||||
(define j (unsafe-vector-ref ij 1))
|
||||
(define v (unsafe-vector-ref vs i))
|
||||
(cond [(fx= v (vector-ref hs j))
|
||||
(define proc (unsafe-vector-ref procs v))
|
||||
(define iv (unsafe-vector-ref is i))
|
||||
(define jv (unsafe-vector-ref js j))
|
||||
(unsafe-vector-set! ij 0 iv)
|
||||
(unsafe-vector-set! ij 1 jv)
|
||||
(define res (proc ij))
|
||||
(unsafe-vector-set! ij 0 i)
|
||||
(unsafe-vector-set! ij 1 j)
|
||||
res]
|
||||
[else
|
||||
zero]))))
|
||||
|
||||
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
|
||||
(define (block-diagonal-matrix/zero as zero)
|
||||
(let ([as (list->vector as)])
|
||||
(define num (vector-length as))
|
||||
(cond [(= num 0)
|
||||
(raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
|
||||
[(= num 1)
|
||||
(unsafe-vector-ref as 0)]
|
||||
[else
|
||||
(block-diagonal-matrix/zero* as zero)])))
|
||||
|
||||
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
|
||||
((Listof (Array Number)) -> (Array Number))))
|
||||
(define (block-diagonal-matrix as)
|
||||
(block-diagonal-matrix/zero as 0))
|
||||
|
||||
(: expt-hack (case-> (Real Integer -> Real)
|
||||
(Number Integer -> Number)))
|
||||
;; Stop using this when TR correctly derives expt : Real Integer -> Real
|
||||
(define (expt-hack x n)
|
||||
(cond [(real? x) (assert (expt x n) real?)]
|
||||
[else (expt x n)]))
|
||||
|
||||
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
|
||||
((Listof Number) Integer -> (Array Number))))
|
||||
(define (vandermonde-matrix xs n)
|
||||
(cond [(empty? xs)
|
||||
(raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
|
||||
[(or (not (index? n)) (zero? n))
|
||||
(raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
|
||||
[else
|
||||
(array-axis-expand (list->array xs) 1 n expt-hack)]))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Flat conversion
|
||||
|
||||
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
|
||||
(define (list->matrix m n xs)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)]
|
||||
[else (list->array (vector m n) xs)]))
|
||||
|
||||
(: matrix->list (All (A) ((Array A) -> (Listof A))))
|
||||
(define (matrix->list a)
|
||||
(array->list* a))
|
||||
(array->list (ensure-matrix 'matrix->list a)))
|
||||
|
||||
(: vector->matrix : (Vectorof* Number) -> (Matrix Number))
|
||||
(define (vector->matrix rows)
|
||||
(vector*->array rows number? ))
|
||||
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->matrix m n v)
|
||||
(cond [(or (not (index? m)) (= m 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)]
|
||||
[(or (not (index? n)) (= n 0))
|
||||
(raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)]
|
||||
[else (vector->array (vector m n) v)]))
|
||||
|
||||
(: flvector->matrix : (Vectorof* Flonum) -> (Matrix Flonum))
|
||||
(define (flvector->matrix rows)
|
||||
(vector*->array rows flonum? ))
|
||||
|
||||
(: matrix->vector : (All (A) (Matrix A) -> (Vectorof* A)))
|
||||
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
|
||||
(define (matrix->vector a)
|
||||
(array->vector* a))
|
||||
(array->vector (ensure-matrix 'matrix->vector a)))
|
||||
|
||||
(: submatrix : (Matrix Number) (Sequenceof Index) (Sequenceof Index) -> (Matrix Number))
|
||||
(define (submatrix a row-range col-range)
|
||||
(array-slice-ref a (list row-range col-range)))
|
||||
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->row-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->row-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) 1 (length xs)) xs)]))
|
||||
|
||||
(: matrix-row : (Matrix Number) Index -> (Matrix Number))
|
||||
(define (matrix-row a i)
|
||||
(define-values (m n) (matrix-dimensions a))
|
||||
(array-slice-ref a (list (in-range i (add1 i)) (in-range n))))
|
||||
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
|
||||
(define (list->col-matrix xs)
|
||||
(cond [(empty? xs) (raise-argument-error 'list->col-matrix "nonempty List" xs)]
|
||||
[else (list->array ((inst vector Index) (length xs) 1) xs)]))
|
||||
|
||||
(: matrix-column : (Matrix Number) Index -> (Matrix Number))
|
||||
(define (matrix-column a j)
|
||||
(define-values (m n)(matrix-dimensions a))
|
||||
(array-slice-ref a (list (in-range m) (in-range j (add1 j)))))
|
||||
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->row-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) 1 n) xs)]))
|
||||
|
||||
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
|
||||
(define (vector->col-matrix xs)
|
||||
(define n (vector-length xs))
|
||||
(cond [(zero? n) (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)]
|
||||
[else (vector->array ((inst vector Index) n 1) xs)]))
|
||||
|
||||
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
|
||||
(define (find-nontrivial-axis ds)
|
||||
(define dims (vector-length ds))
|
||||
(let: loop : (Values Index Index) ([k : Nonnegative-Fixnum 0])
|
||||
(cond [(k . < . dims) (define dk (unsafe-vector-ref ds k))
|
||||
(if (dk . > . 1) (values k dk) (loop (fx+ k 1)))]
|
||||
[else (values 0 0)])))
|
||||
|
||||
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->row-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(row-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k n) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 n)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 1))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
|
||||
(define (array->col-matrix arr)
|
||||
(define (fail)
|
||||
(raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
|
||||
(define ds (array-shape arr))
|
||||
(define dims (vector-length ds))
|
||||
(define num-ones (vector-count (λ: ([d : Index]) (= d 1)) ds))
|
||||
(cond [(zero? (array-size arr)) (fail)]
|
||||
[(col-matrix? arr) arr]
|
||||
[(= num-ones dims)
|
||||
(define: js : (Vectorof Index) (make-vector dims 0))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) 1 1)
|
||||
(λ: ([ij : Indexes]) (proc js)))]
|
||||
[(= num-ones (- dims 1))
|
||||
(define-values (k m) (find-nontrivial-axis ds))
|
||||
(define js (make-thread-local-indexes dims))
|
||||
(define proc (unsafe-array-proc arr))
|
||||
(unsafe-build-array ((inst vector Index) m 1)
|
||||
(λ: ([ij : Indexes])
|
||||
(let ([js (js)])
|
||||
(unsafe-vector-set! js k (unsafe-vector-ref ij 0))
|
||||
(proc js))))]
|
||||
[else (fail)]))
|
||||
|
||||
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->row-matrix xs)
|
||||
(cond [(list? xs) (list->row-matrix xs)]
|
||||
[(array? xs) (array->row-matrix xs)]
|
||||
[else (vector->row-matrix xs)]))
|
||||
|
||||
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
|
||||
(define (->col-matrix xs)
|
||||
(cond [(list? xs) (list->col-matrix xs)]
|
||||
[(array? xs) (array->col-matrix xs)]
|
||||
[else (vector->col-matrix xs)]))
|
||||
|
||||
;; =================================================================================================
|
||||
;; Nested conversion
|
||||
|
||||
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
|
||||
(define (list*-shape xss fail)
|
||||
(define m (length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define n (length (first xss)))
|
||||
(cond [(and (n . > . 0) (andmap (λ: ([xs : (Listof A)]) (= n (length xs))) (rest xss)))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
|
||||
-> (Values Positive-Index Positive-Index)))
|
||||
(define (vector*-shape xss fail)
|
||||
(define m (vector-length xss))
|
||||
(cond [(m . > . 0)
|
||||
(define ns ((inst vector-map Index (Vectorof A)) vector-length xss))
|
||||
(define n (vector-length (unsafe-vector-ref xss 0)))
|
||||
(cond [(and (n . > . 0)
|
||||
(let: loop : Boolean ([i : Nonnegative-Fixnum 1])
|
||||
(cond [(i . fx< . m)
|
||||
(if (= n (vector-length (unsafe-vector-ref xss i)))
|
||||
(loop (fx+ i 1))
|
||||
#f)]
|
||||
[else #t])))
|
||||
(values m n)]
|
||||
[else (fail)])]
|
||||
[else (fail)]))
|
||||
|
||||
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
|
||||
(define (list*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'list*->matrix
|
||||
"nested lists with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (list*-shape xss fail))
|
||||
(list->array ((inst vector Index) m n) (apply append xss)))
|
||||
|
||||
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
|
||||
(define (matrix->list* a)
|
||||
(cond [(matrix? a) (array->list (array->list-array a 1))]
|
||||
[else (raise-argument-error 'matrix->list* "matrix?" a)]))
|
||||
|
||||
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
|
||||
(define (vector*->matrix xss)
|
||||
(define (fail)
|
||||
(raise-argument-error 'vector*->matrix
|
||||
"nested vectors with rectangular shape and at least one matrix element"
|
||||
xss))
|
||||
(define-values (m n) (vector*-shape xss fail))
|
||||
(vector->matrix m n (apply vector-append (vector->list xss))))
|
||||
|
||||
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
|
||||
(define (matrix->vector* a)
|
||||
(cond [(matrix? a) (array->vector ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))]
|
||||
[else (raise-argument-error 'matrix->vector* "matrix?" a)]))
|
||||
) ; module
|
||||
|
||||
(require (for-syntax racket/base
|
||||
syntax/parse)
|
||||
(only-in typed/racket/base :)
|
||||
math/array
|
||||
(submod "." typed-defs))
|
||||
|
||||
(define-syntax (matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...])
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...]))]
|
||||
[(_ [[x0 xs0 ...] [x xs ...] ...] : T)
|
||||
(syntax/loc stx (array #[#[x0 xs0 ...] #[x xs ...] ...] : T))]
|
||||
[(_ [xs ... (~and [] r) ys ...] (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty row" stx #'r)]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'matrix "given empty matrix" stx #'c)]))
|
||||
|
||||
(define-syntax (row-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x xs ...]]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x xs ...]] : T))]
|
||||
[(_ (~and [] r) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty row" stx #'r)]))
|
||||
|
||||
(define-syntax (col-matrix stx)
|
||||
(syntax-parse stx #:literals (:)
|
||||
[(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
|
||||
[(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
|
||||
[(_ (~and [] c) (~optional (~seq : T)))
|
||||
(raise-syntax-error 'row-matrix "given empty column" stx #'c)]))
|
||||
|
|
|
@ -1,22 +1,21 @@
|
|||
#lang typed/racket
|
||||
|
||||
(require "../../array.rkt"
|
||||
(require math/array
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-multiply.rkt"
|
||||
"matrix-types.rkt")
|
||||
"matrix-arithmetic.rkt")
|
||||
|
||||
(provide matrix-expt)
|
||||
|
||||
(: matrix-expt : (Matrix Number) Integer -> (Matrix Number))
|
||||
(define (matrix-expt a n)
|
||||
(unless (array-matrix? a)
|
||||
(raise-type-error 'matrix-expt "(Matrix Number)" a))
|
||||
(unless (square-matrix? a)
|
||||
(error 'matrix-expt "Square matrix expected, got ~a" a))
|
||||
(cond
|
||||
[(= n 0) (identity-matrix (square-matrix-size a))]
|
||||
[(= n 1) a]
|
||||
(cond [(not (square-matrix? a)) (raise-argument-error 'matrix-expt "square-matrix?" 0 a n)]
|
||||
[(negative? n) (raise-argument-error 'matrix-expt "Natural" 1 a n)]
|
||||
[(zero? n) (identity-matrix (square-matrix-size a))]
|
||||
[else
|
||||
(let: loop : (Matrix Number) ([n : Positive-Integer n])
|
||||
(cond [(= n 1) a]
|
||||
[(= n 2) (matrix* a a)]
|
||||
[(even? n) (let ([a^n/2 (matrix-expt a (quotient n 2))])
|
||||
(matrix* a^n/2 a^n/2))]
|
||||
[else (matrix* a (matrix-expt a (sub1 n)))]))
|
||||
[else (matrix* a (matrix-expt a (sub1 n)))]))]))
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
#lang typed/racket
|
||||
|
||||
(require "../unsafe.rkt"
|
||||
"../../array.rkt"
|
||||
"matrix-types.rkt")
|
||||
|
||||
(provide matrix*)
|
||||
|
||||
;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an
|
||||
;; explanation.
|
||||
|
||||
#;(: make-matrix-multiply (All (A) (Symbol
|
||||
((Array A) Integer -> (Array A))
|
||||
((Array A) (Array A) -> (Array A))
|
||||
-> ((Array A) (Array A) -> (Array A)))))
|
||||
(define-syntax-rule (make-matrix-multiply name array-axis-sum array*)
|
||||
(λ (arr brr)
|
||||
(unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr))
|
||||
(unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr))
|
||||
(match-define (vector ad0 ad1) (array-shape arr))
|
||||
(match-define (vector bd0 bd1) (array-shape brr))
|
||||
(unless (= ad1 bd0)
|
||||
(error name
|
||||
"1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
|
||||
arr brr))
|
||||
;; Get strict versions of both because each element in both is evaluated multiple times
|
||||
(let ([arr (array->mutable-array arr)]
|
||||
[brr (array->mutable-array brr)])
|
||||
;; This next part could be done with array-permute, but it's much slower that way
|
||||
(define avs (mutable-array-data arr))
|
||||
(define bvs (mutable-array-data brr))
|
||||
;; Extend arr in the center dimension
|
||||
(define: ds-ext : (Vectorof Index) (vector ad0 bd1 ad1))
|
||||
(define arr-ext
|
||||
(unsafe-build-array
|
||||
ds-ext (λ: ([js : (Vectorof Index)])
|
||||
(define j0 (unsafe-vector-ref js 0))
|
||||
(define j1 (unsafe-vector-ref js 2))
|
||||
;(unsafe-array-ref* arr j0 j1) [twice as slow]
|
||||
(unsafe-vector-ref avs (unsafe-fx+ j1 (unsafe-fx* j0 ad1))))))
|
||||
;; Transpose brr and extend in the leftmost dimension
|
||||
;; Note that ds-ext = (vector ad0 bd1 bd0) because bd0 = ad1
|
||||
(define brr-ext
|
||||
(unsafe-build-array
|
||||
ds-ext (λ: ([js : (Vectorof Index)])
|
||||
(define j0 (unsafe-vector-ref js 2))
|
||||
(define j1 (unsafe-vector-ref js 1))
|
||||
;(unsafe-array-ref* brr j0 j1) [twice as slow]
|
||||
(unsafe-vector-ref bvs (unsafe-fx+ j1 (unsafe-fx* j0 bd1))))))
|
||||
(array-axis-sum (array* arr-ext brr-ext) 2))))
|
||||
|
||||
;; ---------------------------------------------------------------------------------------------------
|
||||
|
||||
(: matrix* (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
|
||||
((Matrix Number) (Matrix Number) -> (Matrix Number))))
|
||||
(define matrix* (make-matrix-multiply 'matrix* array-axis-sum array*))
|
||||
|
||||
;(: matrix-fl* ((Array Float) (Array Float) -> (Array Float)))
|
||||
;(define matrix-fl* (make-matrix-multiply 'matrix-fl* array-axis-flsum array-fl*))
|
||||
|
|
@ -1,14 +1,16 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require math/array
|
||||
(require racket/list
|
||||
math/array
|
||||
(only-in typed/racket conjugate)
|
||||
"../unsafe.rkt"
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-pointwise.rkt"
|
||||
"matrix-arithmetic.rkt"
|
||||
"matrix-basic.rkt"
|
||||
"matrix-column.rkt"
|
||||
(for-syntax racket))
|
||||
|
||||
|
||||
; TODO:
|
||||
; 1. compute null space from QR factorization
|
||||
; (better numerical stability than from Gauss elimnation)
|
||||
|
@ -25,21 +27,6 @@
|
|||
; But TR has problems with #:when so what is the proper expansion ?
|
||||
|
||||
(provide
|
||||
; basic
|
||||
matrix-ref
|
||||
matrix-scale
|
||||
matrix-row-vector?
|
||||
matrix-column-vector?
|
||||
matrix/dim ; construct
|
||||
matrix-augment ; horizontally
|
||||
matrix-stack ; vertically
|
||||
matrix-block-diagonal
|
||||
; norms
|
||||
matrix-norm
|
||||
; operators
|
||||
matrix-transpose
|
||||
matrix-conjugate
|
||||
matrix-hermitian
|
||||
matrix-inverse
|
||||
; row and column
|
||||
matrix-scale-row
|
||||
|
@ -56,7 +43,6 @@
|
|||
matrix-rank
|
||||
matrix-nullity
|
||||
matrix-determinant
|
||||
matrix-trace
|
||||
; spaces
|
||||
;matrix-column+null-space
|
||||
; solvers
|
||||
|
@ -64,17 +50,6 @@
|
|||
matrix-solve-many
|
||||
; spaces
|
||||
matrix-column-space
|
||||
; column vectors
|
||||
column ; construct
|
||||
unit-column
|
||||
result-column ; convert to lazy
|
||||
column-dimension
|
||||
column-dot
|
||||
column-norm
|
||||
column-projection
|
||||
column-normalize
|
||||
scale-column
|
||||
column+
|
||||
; projection
|
||||
projection-on-orthogonal-basis
|
||||
projection-on-orthonormal-basis
|
||||
|
@ -84,70 +59,8 @@
|
|||
; factorization
|
||||
matrix-lu
|
||||
matrix-qr
|
||||
; comprehensions
|
||||
for/matrix:
|
||||
for*/matrix:
|
||||
for/matrix-sum:
|
||||
; sequences
|
||||
in-row
|
||||
in-column
|
||||
; special matrices
|
||||
vandermonde-matrix
|
||||
)
|
||||
|
||||
;;;
|
||||
;;; Basic
|
||||
;;;
|
||||
|
||||
(: matrix-ref : (Matrix Number) Integer Integer -> Number)
|
||||
(define (matrix-ref M i j)
|
||||
((inst array-ref Number) M (vector i j)))
|
||||
|
||||
(: matrix-scale : Number (Matrix Number) -> (Matrix Number))
|
||||
(define (matrix-scale s a)
|
||||
(array-scale a s))
|
||||
|
||||
(: matrix-row-vector? : (Matrix Number) -> Boolean)
|
||||
(define (matrix-row-vector? a)
|
||||
(= (matrix-row-dimension a) 1))
|
||||
|
||||
(: matrix-column-vector? : (Matrix Number) -> Boolean)
|
||||
(define (matrix-column-vector? a)
|
||||
(= (matrix-column-dimension a) 1))
|
||||
|
||||
|
||||
;;;
|
||||
;;; Norms
|
||||
;;;
|
||||
|
||||
(: matrix-norm : (Matrix Number) -> Real)
|
||||
(define (matrix-norm a)
|
||||
(define n
|
||||
(sqrt
|
||||
(array-ref
|
||||
(array-axis-sum
|
||||
(array-axis-sum
|
||||
(matrix.sqr (matrix.magnitude a)) 0) 0)
|
||||
'#())))
|
||||
(assert n real?))
|
||||
|
||||
;;;
|
||||
;;; Operators
|
||||
;;;
|
||||
|
||||
(: matrix-transpose : (Matrix Number) -> (Matrix Number))
|
||||
(define (matrix-transpose a)
|
||||
(array-axis-swap a 0 1))
|
||||
|
||||
(: matrix-conjugate : (Matrix Number) -> (Matrix Number))
|
||||
(define (matrix-conjugate a)
|
||||
(array-conjugate a))
|
||||
|
||||
(: matrix-hermitian : (Matrix Number) -> (Matrix Number))
|
||||
(define (matrix-hermitian a)
|
||||
(matrix-transpose
|
||||
(array-conjugate a)))
|
||||
|
||||
;;;
|
||||
;;; Row and column
|
||||
;;;
|
||||
|
@ -296,7 +209,7 @@
|
|||
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
|
||||
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
|
||||
(define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
|
||||
-> (Values (Matrix Number) (Listof Integer))))
|
||||
(define (loop i j ; i from 0 to m
|
||||
|
@ -362,20 +275,20 @@
|
|||
; TODO: Use QR or SVD instead for inexact matrices
|
||||
; See answer: http://scicomp.stackexchange.com/questions/1861/understanding-how-numpy-does-svd
|
||||
; rank = dimension of column space = dimension of row space
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
|
||||
(- n (length cols-without-pivot)))
|
||||
|
||||
(: matrix-nullity : (Matrix Number) -> Integer)
|
||||
(define (matrix-nullity M)
|
||||
; nullity = dimension of null space
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(define-values (_ cols-without-pivot) (matrix-gauss-eliminate M))
|
||||
(length cols-without-pivot))
|
||||
|
||||
(: matrix-determinant : (Matrix Number) -> Number)
|
||||
(define (matrix-determinant M)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(cond
|
||||
[(= m 1) (matrix-ref M 0 0)]
|
||||
[(= m 2) (let ([a (matrix-ref M 0 0)]
|
||||
|
@ -406,18 +319,12 @@
|
|||
(set! product (* product (matrix-ref M i i))))
|
||||
product))]))
|
||||
|
||||
(: matrix-trace : (Matrix Number) -> Number)
|
||||
(define (matrix-trace M)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(for/sum: : Number ([i (in-range 0 m)])
|
||||
(matrix-ref M i i)))
|
||||
|
||||
(: matrix-column-space : (Matrix Number) -> (Listof (Matrix Number)))
|
||||
; Returns
|
||||
; 1) a list of column vectors spanning the column space
|
||||
; 2) a list of column vectors spanning the null space
|
||||
(define (matrix-column-space M)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(: M1 (Matrix Number))
|
||||
(: cols-without-pivot (Listof Integer))
|
||||
(define-values (M1 cols-without-pivot) (matrix-gauss-eliminate M #t))
|
||||
|
@ -426,7 +333,7 @@
|
|||
(for/list:
|
||||
([i : Index n]
|
||||
#:when (not (member i cols-without-pivot)))
|
||||
(matrix-column M1 i)))
|
||||
(matrix-col M1 i)))
|
||||
column-space)
|
||||
|
||||
(: matrix-row-echelon-form :
|
||||
|
@ -442,7 +349,7 @@
|
|||
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
|
||||
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
|
||||
(define (matrix-gauss-jordan-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
|
||||
-> (Values (Matrix Number) (Listof Integer))))
|
||||
(define (loop i j ; i from 0 to m
|
||||
|
@ -512,23 +419,11 @@
|
|||
(let-values ([(M wp) (matrix-gauss-jordan-eliminate M unitize-pivot-row?)])
|
||||
M))
|
||||
|
||||
(: matrix-augment : (Matrix Number) (Matrix Number) * -> (Matrix Number))
|
||||
(define (matrix-augment a . as)
|
||||
(array-append* (cons a as) 1))
|
||||
|
||||
(: matrix-stack : (Matrix Number) * -> (Matrix Number))
|
||||
(define (matrix-stack . as)
|
||||
(if (null? as)
|
||||
(error 'matrix-stack
|
||||
"expected non-empty list of matrices")
|
||||
(array-append* as 0)))
|
||||
|
||||
|
||||
(: matrix-inverse : (Matrix Number) -> (Matrix Number))
|
||||
(define (matrix-inverse M)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(unless (= m n) (error 'matrix-inverse "matrix not square"))
|
||||
(let ([MI (matrix-augment M (identity-matrix m))])
|
||||
(let ([MI (matrix-augment (list M (identity-matrix m)))])
|
||||
(define 2m (* 2 m))
|
||||
(if (index? 2m)
|
||||
(submatrix (matrix-reduced-row-echelon-form MI #t)
|
||||
|
@ -539,8 +434,8 @@
|
|||
; Return a column-vector x such that Mx = b.
|
||||
; If no such vector exists return #f.
|
||||
(define (matrix-solve M b)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (s t) (matrix-dimensions b))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(define-values (s t) (matrix-shape b))
|
||||
(define m+1 (+ m 1))
|
||||
(cond
|
||||
[(not (= t 1)) (error 'matrix-solve "expected column vector (i.e. r x 1 - matrix), got: ~a " b)]
|
||||
|
@ -548,18 +443,14 @@
|
|||
[(index? m+1)
|
||||
(submatrix
|
||||
(matrix-reduced-row-echelon-form
|
||||
(matrix-augment M b) #t)
|
||||
(matrix-augment (list M b)) #t)
|
||||
(in-range 0 m) (in-range m m+1))]
|
||||
[else (error 'matrix-solve "internatl error")]))
|
||||
|
||||
(: matrix-solve-many : (Matrix Number) (Listof (Matrix Number)) -> (Matrix Number))
|
||||
(define (matrix-solve-many M bs)
|
||||
; TODO: Rewrite matrix-augment* to use array-append when it is ready
|
||||
(: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number))
|
||||
(define (matrix-augment* vs)
|
||||
(foldl matrix-augment (car vs) (cdr vs)))
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (s t) (matrix-dimensions (car bs)))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(define-values (s t) (matrix-shape (car bs)))
|
||||
(define k (length bs))
|
||||
(define m+1 (+ m 1))
|
||||
(define m+k (+ m k))
|
||||
|
@ -567,8 +458,8 @@
|
|||
[(not (= t 1)) (error 'matrix-solve-many "expected column vector (i.e. r x 1 - matrix), got: ~a " (car bs))]
|
||||
[(not (= m s)) (error 'matrix-solve-many "expected column vectors with same number of rows as the matrix")]
|
||||
[(and (index? m+1) (index? m+k))
|
||||
(define bs-as-matrix (matrix-augment* bs))
|
||||
(define MB (matrix-augment M bs-as-matrix))
|
||||
(define bs-as-matrix (matrix-augment bs))
|
||||
(define MB (matrix-augment (list M bs-as-matrix)))
|
||||
(define reduced-MB (matrix-reduced-row-echelon-form MB #t))
|
||||
(submatrix reduced-MB
|
||||
(in-range 0 m+k)
|
||||
|
@ -584,7 +475,7 @@
|
|||
(: matrix-lu :
|
||||
(Matrix Number) -> (U False (List (Matrix Number) (Matrix Number))))
|
||||
(define (matrix-lu M)
|
||||
(define-values (m _) (matrix-dimensions M))
|
||||
(define-values (m _) (matrix-shape M))
|
||||
(define: ms : (Listof Number) '())
|
||||
(define V
|
||||
(let/ec: return : (U False (Matrix Number))
|
||||
|
@ -639,111 +530,6 @@
|
|||
(list L V))))
|
||||
|
||||
|
||||
(: column-dimension : (Column Number) -> Index)
|
||||
(define (column-dimension v)
|
||||
(if (vector? v)
|
||||
(unsafe-vector-length v)
|
||||
(matrix-row-dimension v)))
|
||||
|
||||
|
||||
(: unsafe-column->vector : (Column Number) -> (Vectorof Number))
|
||||
(define (unsafe-column->vector v)
|
||||
(if (vector? v) v
|
||||
(let ()
|
||||
(define-values (m n) (matrix-dimensions v))
|
||||
(if (= n 1)
|
||||
(mutable-array-data (array->mutable-array v))
|
||||
(error 'unsafe-column->vector
|
||||
"expected a column (vector or mx1 matrix), got ~a" v)))))
|
||||
|
||||
(: vector->column : (Vectorof Number) -> (Result-Column Number))
|
||||
(define (vector->column v)
|
||||
(define m (vector-length v))
|
||||
(flat-vector->matrix m 1 v))
|
||||
|
||||
(: column : Number * -> (Result-Column Number))
|
||||
(define (column . xs)
|
||||
(vector->column
|
||||
(list->vector xs)))
|
||||
|
||||
(: result-column : (Column Number) -> (Result-Column Number))
|
||||
(define (result-column c)
|
||||
(if (vector? c)
|
||||
(vector->column c)
|
||||
c))
|
||||
|
||||
(: scale-column : Number (Column Number) -> (Result-Column Number))
|
||||
(define (scale-column s a)
|
||||
(if (vector? a)
|
||||
(let*: ([n (vector-length a)]
|
||||
[v : (Vectorof Number) (make-vector n 0)])
|
||||
(for: ([i (in-range 0 n)]
|
||||
[x : Number (in-vector a)])
|
||||
(vector-set! v i (* s x)))
|
||||
(vector->column v))
|
||||
(matrix-scale s a)))
|
||||
|
||||
(: column+ : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
(define (column+ v w)
|
||||
(cond [(and (vector? v) (vector? w))
|
||||
(let ([n (vector-length v)]
|
||||
[m (vector-length w)])
|
||||
(unless (= m n)
|
||||
(error 'column+
|
||||
"expected two column vectors of the same length, got ~a and ~a" v w))
|
||||
(define: v+w : (Vectorof Number) (make-vector n 0))
|
||||
(for: ([i (in-range 0 n)]
|
||||
[x : Number (in-vector v)]
|
||||
[y : Number (in-vector w)])
|
||||
(vector-set! v+w i (+ x y)))
|
||||
(result-column v+w))]
|
||||
[else
|
||||
(unless (= (column-dimension v) (column-dimension w))
|
||||
(error 'column+
|
||||
"expected two column vectors of the same length, got ~a and ~a" v w))
|
||||
(matrix+ (result-column v) (result-column w))]))
|
||||
|
||||
|
||||
(: column-dot : (Column Number) (Column Number) -> Number)
|
||||
(define (column-dot c d)
|
||||
(define v (unsafe-column->vector c))
|
||||
(define w (unsafe-column->vector d))
|
||||
(define m (column-dimension v))
|
||||
(define s (column-dimension w))
|
||||
(cond
|
||||
[(not (= m s)) (error 'column-dot
|
||||
"expected two mx1 matrices with same number of rows, got ~a and ~a"
|
||||
c d)]
|
||||
[else
|
||||
(for/sum: : Number ([i (in-range 0 m)])
|
||||
(assert i index?)
|
||||
; Note: If d is a vector of reals,
|
||||
; then the conjugate is a no-op
|
||||
(* (unsafe-vector-ref v i)
|
||||
(conjugate (unsafe-vector-ref w i))))]))
|
||||
|
||||
(: column-norm : (Column Number) -> Real)
|
||||
(define (column-norm v)
|
||||
(define norm (sqrt (column-dot v v)))
|
||||
(assert norm real?))
|
||||
|
||||
|
||||
(: column-projection : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
; (column-projection v w)
|
||||
; Return the projection og vector v on vector w.
|
||||
(define (column-projection v w)
|
||||
(let ([w.w (column-dot w w)])
|
||||
(if (zero? w.w)
|
||||
(error 'column-projection "projection on the zero vector not defined")
|
||||
(matrix-scale (/ (column-dot v w) w.w) (result-column w)))))
|
||||
|
||||
(: column-projection-on-unit : (Column Number) (Column Number) -> (Result-Column Number))
|
||||
; (column-projection-on-unit v w)
|
||||
; Return the projection og vector v on a unit vector w.
|
||||
(define (column-projection-on-unit v w)
|
||||
(matrix-scale (column-dot v w) (result-column w)))
|
||||
|
||||
|
||||
(: projection-on-orthogonal-basis :
|
||||
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
|
||||
; (projection-on-orthogonal-basis v bs)
|
||||
|
@ -754,20 +540,9 @@
|
|||
(if (null? bs)
|
||||
(error 'projection-on-orthogonal-basis
|
||||
"received empty list of basis vectors")
|
||||
(for/matrix-sum: : Number ([b (in-list bs)])
|
||||
(column-projection v (result-column b)))))
|
||||
|
||||
; #;(for/matrix-sum ([b bs])
|
||||
; (matrix-scale (column-dot v b) b))
|
||||
; (define: sum : (U False (Result-Column Number)) #f)
|
||||
; (for ([b1 (in-list bs)])
|
||||
; (define: b : (Result-Column Number) (result-column b1))
|
||||
; (cond [(not sum) (set! sum (column-projection v b))]
|
||||
; [else (set! sum (matrix+ (assert sum) (column-projection v b)))]))
|
||||
; (cond [sum (assert sum)]
|
||||
; [else (error 'projection-on-orthogonal-basis
|
||||
; "received empty list of basis vectors")])
|
||||
|
||||
(matrix-sum (map (λ: ([b : (Column Number)])
|
||||
(column-project v (->col-matrix b)))
|
||||
bs))))
|
||||
|
||||
; (projection-on-orthonormal-basis v bs)
|
||||
; Project the vector v on the orthonormal basis vectors in bs.
|
||||
|
@ -776,35 +551,29 @@
|
|||
(: projection-on-orthonormal-basis :
|
||||
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
|
||||
(define (projection-on-orthonormal-basis v bs)
|
||||
#;(for/matrix-sum ([b bs]) (matrix-scale (column-dot v b) b))
|
||||
#;(for/matrix-sum ([b bs]) (matrix-scale b (column-dot v b)))
|
||||
(define: sum : (U False (Result-Column Number)) #f)
|
||||
(for ([b1 (in-list bs)])
|
||||
(define: b : (Result-Column Number) (result-column b1))
|
||||
(cond [(not sum) (set! sum (column-projection-on-unit v b))]
|
||||
[else (set! sum (matrix+ (assert sum) (column-projection-on-unit v b)))]))
|
||||
(define: b : (Result-Column Number) (->col-matrix b1))
|
||||
(cond [(not sum) (set! sum (column-project/unit v b))]
|
||||
[else (set! sum (array+ (assert sum) (column-project/unit v b)))]))
|
||||
(cond [sum (assert sum)]
|
||||
[else (error 'projection-on-orthogonal-basis
|
||||
[else (error 'projection-on-orthonormal-basis
|
||||
"received empty list of basis vectors")]))
|
||||
|
||||
|
||||
(: zero-column-vector? : (Matrix Number) -> Boolean)
|
||||
(define (zero-column-vector? v)
|
||||
(define-values (m n) (matrix-dimensions v))
|
||||
(for/and: ([i (in-range 0 m)])
|
||||
(zero? (matrix-ref v i 0))))
|
||||
|
||||
(: gram-schmidt-orthogonal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
|
||||
; (gram-schmidt-orthogonal ws)
|
||||
; Given a list ws of column vectors, produce
|
||||
; an orthogonal basis for the span of the
|
||||
; vectors in ws.
|
||||
(define (gram-schmidt-orthogonal ws1)
|
||||
(define ws (map result-column ws1))
|
||||
(define ws (map (λ: ([w : (Column Number)]) (->col-matrix w)) ws1))
|
||||
(cond
|
||||
[(null? ws) '()]
|
||||
[(null? (cdr ws)) (list (car ws))]
|
||||
[else
|
||||
(: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number)) -> (Listof (Result-Column Number)))
|
||||
(: loop : (Listof (Result-Column Number)) (Listof (Column-Matrix Number))
|
||||
-> (Listof (Result-Column Number)))
|
||||
(define (loop vs ws)
|
||||
(cond [(null? ws) vs]
|
||||
[else
|
||||
|
@ -812,84 +581,32 @@
|
|||
(let ([w-proj (projection-on-orthogonal-basis w vs)])
|
||||
; Note: We project onto vs (not on the original ws)
|
||||
; in order to get numerical stability.
|
||||
(let ([w-minus-proj (matrix- w w-proj)])
|
||||
(if (zero-column-vector? w-minus-proj)
|
||||
(let ([w-minus-proj (array-strict (array- w w-proj))])
|
||||
(if (zero-matrix? w-minus-proj)
|
||||
(loop vs (cdr ws)) ; w in span{vs} => omit it
|
||||
(loop (cons (matrix- w w-proj) vs) (cdr ws)))))]))
|
||||
(loop (cons w-minus-proj vs) (cdr ws)))))]))
|
||||
(reverse (loop (list (car ws)) (cdr ws)))]))
|
||||
|
||||
|
||||
|
||||
(: column-normalize : (Column Number) -> (Result-Column Number))
|
||||
; (column-vector-normalize v)
|
||||
; Return unit vector with same direction as v.
|
||||
; If v is the zero vector, the zero vector is returned.
|
||||
(define (column-normalize w)
|
||||
(let ([norm (column-norm w)]
|
||||
[w (result-column w)])
|
||||
(cond [(zero? norm) w]
|
||||
[else (matrix-scale (/ norm) w)])))
|
||||
|
||||
(: gram-schmidt-orthonormal : (Listof (Column Number)) -> (Listof (Result-Column Number)))
|
||||
; (gram-schmidt-orthonormal ws)
|
||||
; Given a list ws of column vectors, produce
|
||||
; an orthonormal basis for the span of the
|
||||
; vectors in ws.
|
||||
(define (gram-schmidt-orthonormal ws)
|
||||
(map column-normalize
|
||||
(gram-schmidt-orthogonal ws)))
|
||||
(map column-normalize (gram-schmidt-orthogonal ws)))
|
||||
|
||||
(: projection-on-subspace :
|
||||
(Column Number) (Listof (Column Number)) -> (Result-Column Number))
|
||||
; (projection-on-subspace v ws)
|
||||
; Returns the projection of v on span{w_i}, w_i in ws.
|
||||
(define (projection-on-subspace v ws)
|
||||
(projection-on-orthogonal-basis
|
||||
v (gram-schmidt-orthogonal ws)))
|
||||
|
||||
(: unit-column : Integer Integer -> (Result-Column Number))
|
||||
(define (unit-column m i)
|
||||
(cond
|
||||
[(and (index? m) (index? i))
|
||||
(define v (make-vector m 0))
|
||||
(if (< i m)
|
||||
(vector-set! v i 1)
|
||||
(error 'unit-vector "dimension must be largest"))
|
||||
(flat-vector->matrix m 1 v)]
|
||||
[else
|
||||
(error 'unit-vector "expected two indices")]))
|
||||
|
||||
|
||||
(: take : (All (A) ((Listof A) Index -> (Listof A))))
|
||||
(define (take xs n)
|
||||
(if (= n 0)
|
||||
'()
|
||||
(let ([n-1 (- n 1)])
|
||||
(if (index? n-1)
|
||||
(cons (car xs) (take (cdr xs) n-1))
|
||||
(error 'take "can not take more elements than the length of the list")))))
|
||||
|
||||
; (list 'take (equal? (take (list 0 1 2 3 4) 2) '(0 1)))
|
||||
|
||||
(: matrix->columns : (Matrix Number) -> (Listof (Matrix Number)))
|
||||
(define (matrix->columns M)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(for/list: : (Listof (Matrix Number))
|
||||
([j (in-range 0 n)])
|
||||
(matrix-column M (assert j index?))))
|
||||
|
||||
(: matrix-augment* : (Listof (Matrix Number)) -> (Matrix Number))
|
||||
(define (matrix-augment* Ms)
|
||||
(define MM (car Ms))
|
||||
(for: ([M (in-list (cdr Ms))])
|
||||
(set! MM (matrix-augment MM M)))
|
||||
MM)
|
||||
(projection-on-orthogonal-basis v (gram-schmidt-orthogonal ws)))
|
||||
|
||||
(: extend-span-to-basis :
|
||||
(Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
|
||||
; Extend the basis in vs to with rdimensional basis
|
||||
(define (extend-span-to-basis vs r)
|
||||
(define-values (m n) (matrix-dimensions (car vs)))
|
||||
(define-values (m n) (matrix-shape (car vs)))
|
||||
(: loop : (Listof (Matrix Number)) (Listof (Matrix Number)) Integer -> (Listof (Matrix Number)))
|
||||
(define (loop vs ws i)
|
||||
(if (>= i m)
|
||||
|
@ -897,15 +614,15 @@
|
|||
(let ()
|
||||
(define ei (unit-column m i))
|
||||
(define pi (projection-on-subspace ei vs))
|
||||
(if (matrix-all= ei pi)
|
||||
(if (matrix= ei pi)
|
||||
(loop vs ws (+ i 1))
|
||||
(let ([w (matrix- ei pi)])
|
||||
(let ([w (array- ei pi)])
|
||||
(loop (cons w vs) (cons w ws) (+ i 1)))))))
|
||||
(: norm> : (Matrix Number) (Matrix Number) -> Boolean)
|
||||
(define (norm> v w)
|
||||
(> (column-norm v) (column-norm w)))
|
||||
(if (index? r)
|
||||
((inst take (Matrix Number)) (sort (loop vs '() 0) norm>) r)
|
||||
(take (sort (loop vs '() 0) norm>) r)
|
||||
(error 'extend-span-to-basis "expected index as second argument, got ~a" r)))
|
||||
|
||||
(: matrix-qr : (Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
|
||||
|
@ -915,13 +632,13 @@
|
|||
; 2) columns of Q is are orthonormal
|
||||
; 3) R is upper-triangular
|
||||
; Note: columnspace(A)=columnspace(Q) !
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(let* ([basis-for-column-space
|
||||
(gram-schmidt-orthonormal (matrix->columns M))]
|
||||
(gram-schmidt-orthonormal (matrix-cols M))]
|
||||
[extension
|
||||
(extend-span-to-basis
|
||||
basis-for-column-space (- n (length basis-for-column-space)))]
|
||||
[Q (matrix-augment*
|
||||
[Q (matrix-augment
|
||||
(append basis-for-column-space
|
||||
(map column-normalize
|
||||
extension)))]
|
||||
|
@ -938,305 +655,5 @@
|
|||
(set! sum (+ sum (* (matrix-ref Q k i)
|
||||
(matrix-ref M k j)))))
|
||||
(vector-set! v (+ (* i n) j) sum))))
|
||||
(flat-vector->matrix n n v))])
|
||||
(vector->matrix n n v))])
|
||||
(values Q R)))
|
||||
|
||||
(: matrix/dim : Integer Integer Number * -> (Matrix Number))
|
||||
; construct a mxn matrix with elements from the values xs
|
||||
; the length of xs must be m*n
|
||||
(define (matrix/dim m n . xs)
|
||||
(cond [(and (index? m) (index? n))
|
||||
(flat-vector->matrix m n (list->vector xs))]
|
||||
[else (error 'matrix/dim "expected two indices as dimensions, got ~a and ~a" m n)]))
|
||||
|
||||
(: matrix-block-diagonal : (Listof (Matrix Number)) -> (Matrix Number))
|
||||
(define (matrix-block-diagonal as)
|
||||
(define sum-m 0)
|
||||
(define sum-n 0)
|
||||
(define: ms : (Listof Index) '())
|
||||
(define: ns : (Listof Index) '())
|
||||
(for: ([a (in-list as)])
|
||||
(define-values (m n) (matrix-dimensions a))
|
||||
(set! sum-m (+ sum-m m))
|
||||
(set! sum-n (+ sum-n n))
|
||||
(set! ms (cons m ms))
|
||||
(set! ns (cons n ns)))
|
||||
(set! ms (reverse ms))
|
||||
(set! ns (reverse ns))
|
||||
(: loop : (Listof (Matrix Number)) (Listof Index) (Listof Index)
|
||||
(Listof (Matrix Number)) Integer -> (Matrix Number))
|
||||
(define (loop as ms ns rows left)
|
||||
(cond [(null? as) (apply matrix-stack (reverse rows))]
|
||||
[else
|
||||
(define m (car ms))
|
||||
(define n (car ns))
|
||||
(define a (car as))
|
||||
(define row
|
||||
(matrix-augment
|
||||
((inst make-matrix Number) m left 0) a (make-matrix m (- sum-n n left) 0)))
|
||||
(loop (cdr as) (cdr ms) (cdr ns) (cons row rows) (+ left n))]))
|
||||
(loop as ms ns '() 0))
|
||||
|
||||
(define-syntax (for/column: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: flat-vector : (Vectorof Number) (make-vector m 0))
|
||||
(for: ([i (in-range m)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! flat-vector i x))
|
||||
(vector->column flat-vector)))]))
|
||||
|
||||
(define-syntax (for/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(flat-vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(for: ([i (in-range m*n)] for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x))
|
||||
(flat-vector->matrix m n v)))]))
|
||||
|
||||
(define-syntax (for*/matrix: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type m-expr n-expr #:column (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: k : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v (+ (* n (remainder k m)) (quotient k m)) x)
|
||||
(set! k (assert (+ k 1) index?)))
|
||||
(flat-vector->matrix m n v)))]
|
||||
[(_ : type m-expr n-expr (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: m : Index m-expr)
|
||||
(define: n : Index n-expr)
|
||||
(define: m*n : Index (assert (* m n) index?))
|
||||
(define: v : (Vectorof Number) (make-vector m*n 0))
|
||||
(define: i : Index 0)
|
||||
(for*: (for:-clause ...)
|
||||
(define x (let () . defs+exprs))
|
||||
(vector-set! v i x)
|
||||
(set! i (assert (+ i 1) index?)))
|
||||
(flat-vector->matrix m n v)))]))
|
||||
|
||||
|
||||
|
||||
(define-syntax (for/matrix-sum: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: sum : (U False (Matrix Number)) #f)
|
||||
(for: (for:-clause ...)
|
||||
(define a (let () . defs+exprs))
|
||||
(set! sum (if sum (matrix+ (assert sum) a) a)))
|
||||
(assert sum)))]))
|
||||
|
||||
;;;
|
||||
;;; SEQUENCES
|
||||
;;;
|
||||
|
||||
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-row/proc M r)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([j : Index]) (matrix-ref M r j))
|
||||
; next-pos
|
||||
(λ: ([j : Index]) (assert (+ j 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([j : Index ]) (< j n))
|
||||
#f #f))))
|
||||
|
||||
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-column/proc M s)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([i : Index]) (matrix-ref M i s))
|
||||
; next-pos
|
||||
(λ: ([i : Index]) (assert (+ i 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([i : Index]) (< i m))
|
||||
#f #f))))
|
||||
|
||||
; (in-row M i]
|
||||
; Returns a sequence of all elements of row i,
|
||||
; that is xi0, xi1, xi2, ...
|
||||
(define-sequence-syntax in-row
|
||||
(λ () #'in-row/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
[[(x) (_ M-expr r-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr r-expr)]
|
||||
#'((i x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
||||
|
||||
; (in-column M j]
|
||||
; Returns a sequence of all elements of column j,
|
||||
; that is x0j, x1j, x2j, ...
|
||||
|
||||
(define-sequence-syntax in-column
|
||||
(λ () #'in-column/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
; M-expr evaluates to column
|
||||
[[(x) (_ M-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(values M1 rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(unless (array-matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d j)])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
; M-expr evaluats to matrix, s-expr to the column index
|
||||
[[(x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
||||
|
||||
(: vandermonde-matrix : (Listof Number) Integer -> (Matrix Number))
|
||||
(define (vandermonde-matrix xs n)
|
||||
; construct matrix M with M(i,j)=α_i^j ; where i and j begin from 0 ...
|
||||
; Inefficient version:
|
||||
(cond
|
||||
[(not (index? n))
|
||||
(error 'vandermonde-matrix "expected Index as second argument, got ~a" n)]
|
||||
[else (define: m : Index (length xs))
|
||||
(define: αs : (Vectorof Number) (list->vector xs))
|
||||
(define: α^j : (Vectorof Number) (make-vector n 1))
|
||||
(for*/matrix: : Number m n #:column
|
||||
([j (in-range 0 n)]
|
||||
[i (in-range 0 m)])
|
||||
(define αi^j (vector-ref α^j i))
|
||||
(define αi (vector-ref αs i ))
|
||||
(vector-set! α^j i (* αi^j αi))
|
||||
αi^j)]))
|
||||
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
#lang typed/racket
|
||||
|
||||
(require "../unsafe.rkt"
|
||||
"../../array.rkt"
|
||||
"matrix-types.rkt")
|
||||
|
||||
(provide matrix+ matrix-
|
||||
matrix.sqr matrix.magnitude)
|
||||
|
||||
;; The `make-matrix-*' operators have to be macros; see ../array/array-pointwise.rkt for an
|
||||
;; explanation.
|
||||
|
||||
#;(: make-matrix-pointwise1 (All (A) (Symbol
|
||||
((Array A) -> (Array A))
|
||||
-> ((Array A) -> (Array A)))))
|
||||
(define-syntax-rule (make-matrix-pointwise1 name array-op1)
|
||||
(λ (arr)
|
||||
(unless (array-matrix? arr) (raise-type-error name "matrix" arr))
|
||||
(array-op1 arr)))
|
||||
|
||||
#;(: make-matrix-pointwise2 (All (A) (Symbol
|
||||
((Array A) (Array A) -> (Array A))
|
||||
-> ((Array A) (Array A) -> (Array A)))))
|
||||
(define-syntax-rule (make-matrix-pointwise2 name array-op2)
|
||||
(λ (arr brr)
|
||||
(unless (array-matrix? arr) (raise-type-error name "matrix" 0 arr brr))
|
||||
(unless (array-matrix? brr) (raise-type-error name "matrix" 1 arr brr))
|
||||
(array-op2 arr brr)))
|
||||
|
||||
#;(: make-matrix-pointwise1/2
|
||||
(All (A) (Symbol
|
||||
(case-> ((Array A) -> (Array A))
|
||||
((Array A) (Array A) -> (Array A)))
|
||||
->
|
||||
(case-> ((Array A) -> (Array A))
|
||||
((Array A) (Array A) -> (Array A))))))
|
||||
(define-syntax-rule (make-matrix-pointwise1/2 name array-op1/2)
|
||||
(case-lambda
|
||||
[(arr) ((make-matrix-pointwise1 name array-op1/2) arr)]
|
||||
[(arr brr) ((make-matrix-pointwise2 name array-op1/2) arr brr)]))
|
||||
|
||||
;; ---------------------------------------------------------------------------------------------------
|
||||
|
||||
(: matrix+ (case-> ((Matrix Real) (Matrix Real) -> (Matrix Real))
|
||||
((Matrix Number) (Matrix Number) -> (Matrix Number))))
|
||||
(: matrix- (case-> ((Matrix Real) -> (Matrix Real))
|
||||
((Matrix Number) -> (Matrix Number))
|
||||
((Matrix Real) (Matrix Real) -> (Matrix Real))
|
||||
((Matrix Number) (Matrix Number) -> (Matrix Number))))
|
||||
(: matrix.sqr (case-> ((Matrix Real) -> (Matrix Real))
|
||||
((Matrix Number) -> (Matrix Number))))
|
||||
(: matrix.magnitude ((Matrix Number) -> (Matrix Real)))
|
||||
|
||||
(define matrix+ (make-matrix-pointwise2 'matrix+ array+))
|
||||
(define matrix- (make-matrix-pointwise1/2 'matrix- array-))
|
||||
(define matrix.sqr (make-matrix-pointwise1 'matrix.sqr array-sqr))
|
||||
(define matrix.magnitude (make-matrix-pointwise1 'matrix.magnitude array-magnitude))
|
|
@ -1,100 +1,16 @@
|
|||
#lang racket
|
||||
(provide for/matrix
|
||||
for*/matrix
|
||||
in-row
|
||||
|
||||
(provide in-row
|
||||
in-column)
|
||||
|
||||
(require math/array
|
||||
(except-in math/matrix in-row in-column))
|
||||
|
||||
;;; COMPREHENSIONS
|
||||
|
||||
; (for/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run in parallel.
|
||||
(define-syntax (for/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(define flat-vector
|
||||
(for/vector #:length (* m n)
|
||||
(clause ...) . defs+exprs))
|
||||
; TODO (efficiency): Use a flat-vector->array instead
|
||||
(flat-vector->matrix m n flat-vector)))]))
|
||||
|
||||
; (for*/matrix m n (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first n values produced becomes the first row.
|
||||
; The next n values becomes the second row and so on.
|
||||
; The bindings in clauses run nested.
|
||||
; (for*/matrix m n #:column (clause ...) . defs+exprs)
|
||||
; Return an m x n matrix with elements from the last expr.
|
||||
; The first m values produced becomes the first column.
|
||||
; The next m values becomes the second column and so on.
|
||||
; The bindings in clauses run nested.
|
||||
(define-syntax (for*/matrix stx)
|
||||
(syntax-case stx ()
|
||||
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let* ([m m-expr]
|
||||
[n n-expr]
|
||||
[v (make-vector (* m n) 0)]
|
||||
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
|
||||
(for* ([i (in-range m)] [j (in-range n)])
|
||||
(vector-set! v (+ (* i n) j)
|
||||
(vector-ref w (+ (* j m) i))))
|
||||
(flat-vector->matrix m n v)))]
|
||||
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ([m m-expr] [n n-expr])
|
||||
(flat-vector->matrix
|
||||
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
|
||||
|
||||
; TODO: The following is uncommented until matrix+ can be imported.
|
||||
|
||||
; (for/matrix-sum (clause ...) . defs+exprs)
|
||||
; Return the matrix sum of all matrices produced by the last expr.
|
||||
; The bindings in clauses are parallel.
|
||||
|
||||
;(define-syntax (for/matrix-sum stx)
|
||||
; (syntax-case stx ()
|
||||
; [(_ (clause ...) . defs+exprs)
|
||||
; (syntax/loc stx
|
||||
; (let ([ms (for/list (clause ...) . defs+exprs)])
|
||||
; (foldl matrix+ (first ms) (rest ms))))]))
|
||||
;
|
||||
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
||||
; (for/matrix-sum ([i 3]) M))
|
||||
; (flat-vector->matrix 2 2 #(3 6 9 12)))
|
||||
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
||||
; (for/matrix-sum ([i 2] [j 2]) M))
|
||||
; (flat-vector->matrix 2 2 #(2 4 6 8)))
|
||||
|
||||
; (for*/matrix-sum (clause ...) . defs+exprs)
|
||||
; Return the matrix sum of all matrices produced by the last expr.
|
||||
; The bindings in clauses are in nested.
|
||||
|
||||
;(define-syntax (for*/matrix-sum stx)
|
||||
; (syntax-case stx ()
|
||||
; [(_ (clause ...) . defs+exprs)
|
||||
; (syntax/loc stx
|
||||
; (let ([ms (for*/list (clause ...) . defs+exprs)])
|
||||
; (foldl matrix+ (first ms) (rest ms))))]))
|
||||
;
|
||||
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
||||
; (for*/matrix-sum ([i 2] [j 2]) M))
|
||||
; (flat-vector->matrix 2 2 #(4 8 12 16)))
|
||||
|
||||
|
||||
;;;
|
||||
;;; SEQUENCES
|
||||
;;;
|
||||
"matrix-types.rkt"
|
||||
"matrix-basic.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
)
|
||||
|
||||
(define (in-row/proc M r)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
|
@ -120,12 +36,12 @@
|
|||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
||||
|
@ -142,12 +58,12 @@
|
|||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
|
@ -167,7 +83,7 @@
|
|||
|
||||
|
||||
(define (in-column/proc M s)
|
||||
(define-values (m n) (matrix-dimensions M))
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
|
@ -190,12 +106,12 @@
|
|||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
||||
|
@ -212,12 +128,12 @@
|
|||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-dimensions M1))
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (array-matrix? M)
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
||||
|
@ -233,28 +149,192 @@
|
|||
[[_ clause] (raise-syntax-error
|
||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
||||
|
||||
(define-syntax (for/matrix-sum: stx)
|
||||
(syntax-case stx ()
|
||||
[(_ : type (for:-clause ...) . defs+exprs)
|
||||
(syntax/loc stx
|
||||
(let ()
|
||||
(define: sum : (U False (Matrix Number)) #f)
|
||||
(for: (for:-clause ...)
|
||||
(define a (let () . defs+exprs))
|
||||
(set! sum (if sum (array+ (assert sum) a) a)))
|
||||
(assert sum)))]))
|
||||
#|
|
||||
;;;
|
||||
;;; SEQUENCES
|
||||
;;;
|
||||
|
||||
(: in-row/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-row/proc M r)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([j : Index]) (matrix-ref M r j))
|
||||
; next-pos
|
||||
(λ: ([j : Index]) (assert (+ j 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([j : Index ]) (< j n))
|
||||
#f #f))))
|
||||
|
||||
(: in-column/proc : (Matrix Number) Integer -> (Sequenceof Number))
|
||||
(define (in-column/proc M s)
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(make-do-sequence
|
||||
(λ ()
|
||||
(values
|
||||
; pos->element
|
||||
(λ: ([i : Index]) (matrix-ref M i s))
|
||||
; next-pos
|
||||
(λ: ([i : Index]) (assert (+ i 1) index?))
|
||||
; initial-pos
|
||||
0
|
||||
; continue-with-pos?
|
||||
(λ: ([i : Index]) (< i m))
|
||||
#f #f))))
|
||||
|
||||
; (in-row M i]
|
||||
; Returns a sequence of all elements of row i,
|
||||
; that is xi0, xi1, xi2, ...
|
||||
(define-sequence-syntax in-row
|
||||
(λ () #'in-row/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
[[(x) (_ M-expr r-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r))
|
||||
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr r-expr)]
|
||||
#'((i x)
|
||||
(:do-in
|
||||
([(M r n d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 r-expr rd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? r)
|
||||
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d (+ (* r n) j))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
||||
|
||||
; (in-column M j]
|
||||
; Returns a sequence of all elements of column j,
|
||||
; that is x0j, x1j, x2j, ...
|
||||
|
||||
(define-sequence-syntax in-column
|
||||
(λ () #'in-column/proc)
|
||||
(λ (stx)
|
||||
(syntax-case stx ()
|
||||
; M-expr evaluates to column
|
||||
[[(x) (_ M-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
([j 0])
|
||||
(< j n)
|
||||
([(x) (vector-ref d j)])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
; M-expr evaluats to matrix, s-expr to the column index
|
||||
[[(x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-row "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[(i x) (_ M-expr s-expr)]
|
||||
#'((x)
|
||||
(:do-in
|
||||
([(M s n m d)
|
||||
(let ([M1 M-expr])
|
||||
(define-values (rd cd) (matrix-shape M1))
|
||||
(values M1 s-expr rd cd
|
||||
(mutable-array-data
|
||||
(array->mutable-array M1))))])
|
||||
(begin
|
||||
(unless (matrix? M)
|
||||
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
||||
(unless (integer? s)
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s))
|
||||
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
||||
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
||||
([j 0])
|
||||
(< j m)
|
||||
([(x) (vector-ref d (+ (* j n) s))]
|
||||
[(i) j])
|
||||
#true
|
||||
#true
|
||||
[(+ j 1)]))]
|
||||
[[_ clause] (raise-syntax-error
|
||||
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
||||
|#
|
||||
#;
|
||||
(module* test #f
|
||||
(require (except-in math/matrix in-row in-column)
|
||||
rackunit)
|
||||
(require rackunit)
|
||||
; "matrix-sequences.rkt"
|
||||
; These work in racket not in typed racket
|
||||
(check-equal? (matrix->list (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 1 2] [1 2 3]])
|
||||
(check-equal? (matrix->list (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
|
||||
'[[0 2 2] [1 1 3]])
|
||||
(check-equal? (matrix->list (for*/matrix 2 2 #:column ([i 4]) i))
|
||||
'[[0 2] [1 3]])
|
||||
(check-equal? (matrix->list (for/matrix 2 2 ([i 4]) i))
|
||||
'[[0 1] [2 3]])
|
||||
(check-equal? (matrix->list (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
|
||||
'[[6 8 10] [12 14 16]])
|
||||
(check-equal? (for/list ([x (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
(check-equal? (for/list ([x (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
'(3 4))
|
||||
(check-equal? (for/list ([(i x) (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(check-equal? (for/list ([(i x) (in-row (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(list i x))
|
||||
'((0 3) (1 4)))
|
||||
(check-equal? (for/list ([x (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
(check-equal? (for/list ([x (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
||||
'(2 4))
|
||||
(check-equal? (for/list ([(i x) (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(check-equal? (for/list ([(i x) (in-column (vector->matrix 2 2 #(1 2 3 4)) 1)])
|
||||
(list i x))
|
||||
'((0 2) (1 4))))
|
||||
|
|
|
@ -1,16 +1,21 @@
|
|||
#lang typed/racket
|
||||
#lang typed/racket/base
|
||||
|
||||
(require "../array/array-struct.rkt"
|
||||
"../array/array-fold.rkt"
|
||||
"../array/array-pointwise.rkt"
|
||||
"../unsafe.rkt")
|
||||
|
||||
(provide Matrix
|
||||
Column Result-Column
|
||||
Column-Matrix
|
||||
array-matrix?
|
||||
matrix-all=
|
||||
matrix?
|
||||
square-matrix?
|
||||
row-matrix?
|
||||
col-matrix?
|
||||
matrix-shape
|
||||
square-matrix-size
|
||||
matrix-dimensions
|
||||
matrix-row-dimension
|
||||
matrix-column-dimension)
|
||||
|
||||
(require math/array)
|
||||
matrix-num-rows
|
||||
matrix-num-cols)
|
||||
|
||||
(define-type (Matrix A) (Array A))
|
||||
; matrices are represented as arrays
|
||||
|
@ -22,39 +27,44 @@
|
|||
(define-type (Column-Matrix A) (Matrix A))
|
||||
; a column vector represented as a matrix
|
||||
|
||||
(: array-matrix? (All (A) ((Array A) -> Boolean)))
|
||||
(define (array-matrix? x)
|
||||
(= 2 (array-dims x)))
|
||||
(define matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (> (array-size arr) 0)
|
||||
(= (array-dims arr) 2))))
|
||||
|
||||
(: square-matrix? : (All (A) (Matrix A) -> Boolean))
|
||||
(define (square-matrix? a)
|
||||
(and (array-matrix? a)
|
||||
(let ([sh (array-shape a)])
|
||||
(= (vector-ref sh 0) (vector-ref sh 1)))))
|
||||
(define square-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(let ([sh (array-shape arr)])
|
||||
(= (vector-ref sh 0) (vector-ref sh 1))))))
|
||||
|
||||
(: square-matrix-size : (All (A) (Matrix A) -> Index))
|
||||
(define (square-matrix-size a)
|
||||
(vector-ref (array-shape a) 0))
|
||||
(define row-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(= (vector-ref (array-shape arr) 0) 1))))
|
||||
|
||||
(: matrix-all= : (Matrix Number) (Matrix Number) -> Boolean)
|
||||
(define (matrix-all= arr0 arr1)
|
||||
(array-all-and (array= arr0 arr1)))
|
||||
(define col-matrix?
|
||||
(plambda: (A) ([arr : (Array A)])
|
||||
(and (matrix? arr)
|
||||
(= (vector-ref (array-shape arr) 1) 1))))
|
||||
|
||||
(: matrix-dimensions : (Matrix Number) -> (Values Index Index))
|
||||
(define (matrix-dimensions a)
|
||||
(define sh (array-shape a))
|
||||
; TODO: Remove list conversion when trbug1 is fixed
|
||||
(define sh-tmp (vector->list sh))
|
||||
(values (car sh-tmp) (cadr sh-tmp))
|
||||
; (values (vector-ref sh 0) (vector-ref sh 1))
|
||||
)
|
||||
(: matrix-shape : (All (A) (Matrix A) -> (Values Index Index)))
|
||||
(define (matrix-shape a)
|
||||
(cond [(matrix? a) (define sh (array-shape a))
|
||||
(values (unsafe-vector-ref sh 0) (unsafe-vector-ref sh 1))]
|
||||
[else (raise-argument-error 'matrix-shape "matrix?" a)]))
|
||||
|
||||
(: matrix-row-dimension : (Matrix Number) -> Index)
|
||||
(define (matrix-row-dimension a)
|
||||
(define sh (array-shape a))
|
||||
(vector-ref sh 0))
|
||||
(: square-matrix-size (All (A) ((Matrix A) -> Index)))
|
||||
(define (square-matrix-size arr)
|
||||
(cond [(square-matrix? arr) (unsafe-vector-ref (array-shape arr) 0)]
|
||||
[else (raise-argument-error 'square-matrix-size "square-matrix?" arr)]))
|
||||
|
||||
(: matrix-column-dimension : (Matrix Number) -> Index)
|
||||
(define (matrix-column-dimension a)
|
||||
(define sh (array-shape a))
|
||||
(vector-ref sh 1))
|
||||
(: matrix-num-rows (All (A) ((Matrix A) -> Index)))
|
||||
(define (matrix-num-rows a)
|
||||
(cond [(matrix? a) (vector-ref (array-shape a) 0)]
|
||||
[else (raise-argument-error 'matrix-col-length "matrix?" a)]))
|
||||
|
||||
(: matrix-num-cols (All (A) ((Matrix A) -> Index)))
|
||||
(define (matrix-num-cols a)
|
||||
(cond [(matrix? a) (vector-ref (array-shape a) 1)]
|
||||
[else (raise-argument-error 'matrix-row-length "matrix?" a)]))
|
||||
|
|
93
collects/math/private/matrix/typed-matrix-arithmetic.rkt
Normal file
93
collects/math/private/matrix/typed-matrix-arithmetic.rkt
Normal file
|
@ -0,0 +1,93 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require racket/list
|
||||
math/array
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt"
|
||||
(except-in "untyped-matrix-arithmetic.rkt" matrix-map))
|
||||
|
||||
(provide matrix-map
|
||||
matrix=
|
||||
matrix*
|
||||
matrix+
|
||||
matrix-
|
||||
matrix-scale
|
||||
matrix-sum)
|
||||
|
||||
(: matrix-map (All (R A B T ...)
|
||||
(case-> ((A -> R) (Array A) -> (Array R))
|
||||
((A B T ... T -> R) (Array A) (Array B) (Array T) ... T -> (Array R)))))
|
||||
(define matrix-map
|
||||
(case-lambda:
|
||||
[([f : (A -> R)] [arr : (Array A)])
|
||||
(inline-matrix-map f arr)]
|
||||
[([f : (A B -> R)] [arr0 : (Array A)] [arr1 : (Array B)])
|
||||
(inline-matrix-map f arr0 arr1)]
|
||||
[([f : (A B T ... T -> R)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T])
|
||||
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
|
||||
(define g0 (unsafe-array-proc arr0))
|
||||
(define g1 (unsafe-array-proc arr1))
|
||||
(define gs (map unsafe-array-proc arrs))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
|
||||
(map (λ: ([g : (Indexes -> T)]) (g js)) gs))))]))
|
||||
|
||||
(: matrix=? ((Array Number) (Array Number) -> Boolean))
|
||||
(define (matrix=? arr0 arr1)
|
||||
(define-values (m0 n0) (matrix-shape arr0))
|
||||
(define-values (m1 n1) (matrix-shape arr1))
|
||||
(and (= m0 m1)
|
||||
(= n0 n1)
|
||||
(let ([proc0 (unsafe-array-proc arr0)]
|
||||
[proc1 (unsafe-array-proc arr1)])
|
||||
(array-all-and (unsafe-build-array
|
||||
((inst vector Index) m0 n0)
|
||||
(λ: ([js : Indexes])
|
||||
(= (proc0 js) (proc1 js))))))))
|
||||
|
||||
(: matrix= (case-> ((Array Number) (Array Number) -> Boolean)
|
||||
((Array Number) (Array Number) (Array Number) (Array Number) * -> Boolean)))
|
||||
(define matrix=
|
||||
(case-lambda:
|
||||
[([arr0 : (Array Number)] [arr1 : (Array Number)]) (matrix=? arr0 arr1)]
|
||||
[([arr0 : (Array Number)] [arr1 : (Array Number)] . [arrs : (Array Number) *])
|
||||
(and (matrix=? arr0 arr1)
|
||||
(let: loop : Boolean ([arr1 : (Array Number) arr1]
|
||||
[arrs : (Listof (Array Number)) arrs])
|
||||
(cond [(empty? arrs) #t]
|
||||
[else (and (matrix=? arr1 (first arrs))
|
||||
(loop (first arrs) (rest arrs)))])))]))
|
||||
|
||||
(: matrix* (case-> ((Array Real) (Array Real) * -> (Array Real))
|
||||
((Array Number) (Array Number) * -> (Array Number))))
|
||||
(define (matrix* a . as)
|
||||
(let loop ([a a] [as as])
|
||||
(cond [(empty? as) a]
|
||||
[else (loop (inline-matrix* a (first as)) (rest as))])))
|
||||
|
||||
(: matrix+ (case-> ((Array Real) (Array Real) * -> (Array Real))
|
||||
((Array Number) (Array Number) * -> (Array Number))))
|
||||
(define (matrix+ a . as)
|
||||
(let loop ([a a] [as as])
|
||||
(cond [(empty? as) a]
|
||||
[else (loop (inline-matrix+ a (first as)) (rest as))])))
|
||||
|
||||
(: matrix- (case-> ((Array Real) (Array Real) * -> (Array Real))
|
||||
((Array Number) (Array Number) * -> (Array Number))))
|
||||
(define (matrix- a . as)
|
||||
(cond [(empty? as) (inline-matrix- a)]
|
||||
[else
|
||||
(let loop ([a a] [as as])
|
||||
(cond [(empty? as) a]
|
||||
[else (loop (inline-matrix- a (first as)) (rest as))]))]))
|
||||
|
||||
(: matrix-scale (case-> ((Array Real) Real -> (Array Real))
|
||||
((Array Number) Number -> (Array Number))))
|
||||
(define (matrix-scale a x) (inline-matrix-scale a x))
|
||||
|
||||
(: matrix-sum (case-> ((Listof (Array Real)) -> (Array Real))
|
||||
((Listof (Array Number)) -> (Array Number))))
|
||||
(define (matrix-sum lst)
|
||||
(cond [(empty? lst) (raise-argument-error 'matrix-sum "nonempty List" lst)]
|
||||
[else (apply matrix+ lst)]))
|
105
collects/math/private/matrix/untyped-matrix-arithmetic.rkt
Normal file
105
collects/math/private/matrix/untyped-matrix-arithmetic.rkt
Normal file
|
@ -0,0 +1,105 @@
|
|||
#lang racket/base
|
||||
|
||||
(provide inline-matrix*
|
||||
inline-matrix+
|
||||
inline-matrix-
|
||||
inline-matrix-scale
|
||||
inline-matrix-map
|
||||
matrix-map)
|
||||
|
||||
(module syntax-defs racket/base
|
||||
(require (for-syntax racket/base)
|
||||
(only-in typed/racket/base λ: : inst Index)
|
||||
math/array
|
||||
"matrix-types.rkt"
|
||||
"utils.rkt")
|
||||
|
||||
(provide (all-defined-out))
|
||||
|
||||
;(: matrix-multiply ((Array Number) (Array Number) -> (Array Number)))
|
||||
;; This is a macro so the result can have as precise a type as possible
|
||||
(define-syntax-rule (inline-matrix-multiply arr-expr brr-expr)
|
||||
(let ([arr arr-expr]
|
||||
[brr brr-expr])
|
||||
(let-values ([(m p n) (matrix-multiply-shape arr brr)]
|
||||
;; Make arr strict because its elements are reffed multiple times
|
||||
[(_) (array-strict! arr)])
|
||||
(let (;; Extend arr in the center dimension
|
||||
[arr-proc (unsafe-array-proc (array-axis-insert arr 1 n))]
|
||||
;; Transpose brr and extend in the leftmost dimension
|
||||
[brr-proc (unsafe-array-proc
|
||||
(array-axis-insert (array-strict (array-axis-swap brr 0 1)) 0 m))])
|
||||
;; The *transpose* of brr is traversed in row-major order when this result is traversed
|
||||
;; in row-major order (which is why the transpose is strict, not brr)
|
||||
(array-axis-sum
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n p)
|
||||
(λ: ([js : Indexes])
|
||||
(* (arr-proc js) (brr-proc js))))
|
||||
2)))))
|
||||
|
||||
(define-syntax (inline-matrix* stx)
|
||||
(syntax-case stx ()
|
||||
[(_ arr)
|
||||
(syntax/loc stx arr)]
|
||||
[(_ arr brr crrs ...)
|
||||
(syntax/loc stx (inline-matrix* (inline-matrix-multiply arr brr) crrs ...))]))
|
||||
|
||||
(define-syntax (inline-matrix-map stx)
|
||||
(syntax-case stx ()
|
||||
[(_ f arr-expr)
|
||||
(syntax/loc stx
|
||||
(let*-values ([(arr) arr-expr]
|
||||
[(m n) (matrix-shape arr)]
|
||||
[(proc) (unsafe-array-proc arr)])
|
||||
(unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (f (proc js))))))]
|
||||
[(_ f arr-expr brr-exprs ...)
|
||||
(with-syntax ([(brrs ...) (generate-temporaries #'(brr-exprs ...))]
|
||||
[(procs ...) (generate-temporaries #'(brr-exprs ...))])
|
||||
(syntax/loc stx
|
||||
(let ([arr arr-expr]
|
||||
[brrs brr-exprs] ...)
|
||||
(let-values ([(m n) (matrix-shapes 'matrix-map arr brrs ...)]
|
||||
[(proc) (unsafe-array-proc arr)]
|
||||
[(procs) (unsafe-array-proc brrs)] ...)
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes])
|
||||
(f (proc js) (procs js) ...)))))))]))
|
||||
|
||||
(define-syntax-rule (inline-matrix+ arr0 arrs ...) (inline-matrix-map + arr0 arrs ...))
|
||||
(define-syntax-rule (inline-matrix- arr0 arrs ...) (inline-matrix-map - arr0 arrs ...))
|
||||
(define-syntax-rule (inline-matrix-scale arr x) (inline-matrix-map (λ (y) (* x y)) arr))
|
||||
|
||||
) ; module
|
||||
|
||||
(require 'syntax-defs)
|
||||
|
||||
(module untyped-defs typed/racket/base
|
||||
(require math/array
|
||||
(submod ".." syntax-defs)
|
||||
"utils.rkt")
|
||||
|
||||
(provide matrix-map)
|
||||
|
||||
(: matrix-map (All (R A) (case-> ((A -> R) (Array A) -> (Array R))
|
||||
((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R)))))
|
||||
(define matrix-map
|
||||
(case-lambda:
|
||||
[([f : (A -> R)] [arr : (Array A)])
|
||||
(inline-matrix-map f arr)]
|
||||
[([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)])
|
||||
(inline-matrix-map f arr0 arr1)]
|
||||
[([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *])
|
||||
(define-values (m n) (apply matrix-shapes 'matrix-map arr0 arr1 arrs))
|
||||
(define g0 (unsafe-array-proc arr0))
|
||||
(define g1 (unsafe-array-proc arr1))
|
||||
(define gs (map (inst unsafe-array-proc A) arrs))
|
||||
(unsafe-build-array
|
||||
((inst vector Index) m n)
|
||||
(λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
|
||||
(map (λ: ([g : (Indexes -> A)]) (g js)) gs))))]))
|
||||
|
||||
) ; module
|
||||
|
||||
(require 'untyped-defs)
|
|
@ -1,6 +1,38 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require math/array)
|
||||
(require racket/match
|
||||
racket/string
|
||||
math/array
|
||||
"matrix-types.rkt")
|
||||
|
||||
(provide (all-defined-out))
|
||||
|
||||
(: format-matrices/error ((Listof (Array Any)) -> String))
|
||||
(define (format-matrices/error as)
|
||||
(string-join (map (λ: ([a : (Array Any)]) (format "~e" a)) as)))
|
||||
|
||||
(: matrix-shapes (Symbol (Matrix Any) (Matrix Any) * -> (Values Index Index)))
|
||||
(define (matrix-shapes name arr . brrs)
|
||||
(define-values (m n) (matrix-shape arr))
|
||||
(unless (andmap (λ: ([brr : (Matrix Any)])
|
||||
(match-define (vector bm bn) (array-shape brr))
|
||||
(and (= bm m) (= bn n)))
|
||||
brrs)
|
||||
(error name
|
||||
"matrices must have the same shape; given ~a"
|
||||
(format-matrices/error (cons arr brrs))))
|
||||
(values m n))
|
||||
|
||||
(: matrix-multiply-shape ((Matrix Any) (Matrix Any) -> (Values Index Index Index)))
|
||||
(define (matrix-multiply-shape arr brr)
|
||||
(define-values (ad0 ad1) (matrix-shape arr))
|
||||
(define-values (bd0 bd1) (matrix-shape brr))
|
||||
(unless (= ad1 bd0)
|
||||
(error 'matrix-multiply
|
||||
"1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
|
||||
arr brr))
|
||||
(values ad0 ad1 bd1))
|
||||
|
||||
(: ensure-matrix (All (A) Symbol (Array A) -> (Array A)))
|
||||
(define (ensure-matrix name a)
|
||||
(if (matrix? a) a (raise-argument-error name "matrix?" a)))
|
||||
|
|
|
@ -1385,7 +1385,7 @@ Returns an array with shape @racket[(vector (array-size arr))], with the element
|
|||
@;{==================================================================================================}
|
||||
|
||||
|
||||
@section{Folds and Other Axis Reductions}
|
||||
@section{Folds, Reductions and Expansions}
|
||||
|
||||
@defproc*[([(array-axis-fold [arr (Array A)] [k Integer] [f (A A -> A)]) (Array A)]
|
||||
[(array-axis-fold [arr (Array A)] [k Integer] [f (A B -> B)] [init B]) (Array B)])]{
|
||||
|
@ -1547,18 +1547,6 @@ and @racket[array-all-or] is defined similarly.
|
|||
(array-all-or (array= arr (array 0)))]
|
||||
}
|
||||
|
||||
@defproc[(array->list-array [arr (Array A)] [k Integer]) (Array (Listof A))]{
|
||||
Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of
|
||||
axis @racket[k].
|
||||
@examples[#:eval typed-eval
|
||||
(define arr (index-array #(3 3)))
|
||||
arr
|
||||
(array->list-array arr 1)
|
||||
(array-ref (array->list-array (array->list-array arr 1) 0) #())]
|
||||
See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows
|
||||
how @racket[array->list-array] is implemented.
|
||||
}
|
||||
|
||||
@defproc[(array-axis-reduce [arr (Array A)] [k Integer] [h (Index (Integer -> A) -> B)]) (Array B)]{
|
||||
Like @racket[array-axis-fold], but allows evaluation control (such as short-cutting @racket[and] and
|
||||
@racket[or]) by moving the loop into @racket[h]. The result has the shape of @racket[arr], but with
|
||||
|
@ -1591,6 +1579,60 @@ Every fold, including @racket[array-axis-fold], is ultimately defined using
|
|||
@racket[array-axis-reduce] or its unsafe counterpart.
|
||||
}
|
||||
|
||||
@defproc[(array-axis-expand [arr (Array A)] [k Integer] [dk Integer] [g (A Index -> B)]) (Array B)]{
|
||||
Inserts a new axis number @racket[k] of length @racket[dk], using @racket[g] to generate values;
|
||||
@racket[k] must be @italic{no greater than} the dimension of @racket[arr], and @racket[dk] must be
|
||||
nonnegative.
|
||||
|
||||
Conceptually, @racket[g] is applied @racket[dk] times to each element in each row of axis @racket[k],
|
||||
once for each nonnegative index @racket[jk < dk]. (In reality, @racket[g] is applied only when the
|
||||
resulting array is indexed.)
|
||||
|
||||
Turning vector elements into rows of a new last axis using @racket[array-axis-expand] and
|
||||
@racket[vector-ref]:
|
||||
@interaction[#:eval typed-eval
|
||||
(define arr (array #['#(a b c) '#(d e f) '#(g h i)]))
|
||||
(array-axis-expand arr 1 3 vector-ref)]
|
||||
Creating a @racket[vandermonde-matrix]:
|
||||
@interaction[#:eval typed-eval
|
||||
(array-axis-expand (list->array '(1 2 3 4)) 1 5 expt)]
|
||||
|
||||
This function is a dual of @racket[array-axis-reduce] in that it can be used to invert applications
|
||||
of @racket[array-axis-reduce].
|
||||
To do so, @racket[g] should be a destructuring function that is dual to the constructor passed to
|
||||
@racket[array-axis-reduce].
|
||||
Example dual pairs are @racket[vector-ref] and @racket[build-vector], and @racket[list-ref] and
|
||||
@racket[build-list].
|
||||
|
||||
(Do not pass @racket[list-ref] to @racket[array-axis-expand] if you care about performance, though.
|
||||
See @racket[list-array->array] for a more efficient solution.)
|
||||
}
|
||||
|
||||
@defproc[(array->list-array [arr (Array A)] [k Integer 0]) (Array (Listof A))]{
|
||||
Returns an array of lists, computed as if by applying @racket[list] to the elements in each row of
|
||||
axis @racket[k].
|
||||
@examples[#:eval typed-eval
|
||||
(define arr (index-array #(3 3)))
|
||||
arr
|
||||
(array->list-array arr 1)
|
||||
(array-ref (array->list-array (array->list-array arr 1) 0) #())]
|
||||
See @racket[mean] for more useful examples, and @racket[array-axis-reduce] for an example that shows
|
||||
how @racket[array->list-array] is implemented.
|
||||
}
|
||||
|
||||
@defproc[(list-array->array [arr (Array (Listof A))] [k Integer 0]) (Array A)]{
|
||||
Returns an array in which the list elements of @racket[arr] comprise a new axis @racket[k].
|
||||
Equivalent to @racket[(array-axis-expand arr k n list-ref)] where @racket[n] is the
|
||||
length of the lists in @racket[arr], but with O(1) indexing.
|
||||
|
||||
@examples[#:eval typed-eval
|
||||
(define arr (array->list-array (index-array #(3 3)) 1))
|
||||
arr
|
||||
(list-array->array arr 1)]
|
||||
|
||||
For fixed @racket[k], this function and @racket[array->list-array] are mutual inverses with respect
|
||||
to their array arguments.
|
||||
}
|
||||
|
||||
@;{==================================================================================================}
|
||||
|
||||
|
|
|
@ -3,28 +3,86 @@
|
|||
(require math/array
|
||||
math/base
|
||||
math/flonum
|
||||
math/matrix)
|
||||
math/matrix
|
||||
"../private/matrix/matrix-column.rkt"
|
||||
"test-utils.rkt")
|
||||
|
||||
(: random-matrix (Integer Integer Integer -> (Matrix Integer)))
|
||||
;; Generates a random matrix with integer elements < k. Useful to test properties.
|
||||
(define (random-matrix m n k)
|
||||
(array-strict (build-array (vector m n) (λ (_) (random k)))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Types
|
||||
|
||||
(check-true (matrix? (array #[#[1]])))
|
||||
(check-false (matrix? (array #[1])))
|
||||
(check-false (matrix? (array 1)))
|
||||
(check-false (matrix? (array #[])))
|
||||
|
||||
(check-true (row-matrix? (matrix [[1 2 3 4]])))
|
||||
(check-false (row-matrix? (matrix [[1] [2] [3] [4]])))
|
||||
|
||||
(check-true (col-matrix? (matrix [[1] [2] [3] [4]])))
|
||||
(check-false (col-matrix? (matrix [[1 2 3 4]])))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Matrix multiplication
|
||||
|
||||
(check-equal? (matrix* (identity-matrix 2)
|
||||
(matrix [[1 20] [300 4000]]))
|
||||
(matrix [[1 20] [300 4000]]))
|
||||
|
||||
(check-equal? (matrix* (matrix [[1 2 3] [4 5 6] [7 8 9]])
|
||||
(matrix [[1 2 3] [4 5 6] [7 8 9]]))
|
||||
(matrix [[30 36 42] [66 81 96] [102 126 150]]))
|
||||
|
||||
(let ([m0 (random-matrix 4 5 100)]
|
||||
[m1 (random-matrix 5 2 100)]
|
||||
[m2 (random-matrix 2 10 100)])
|
||||
(check-equal? (matrix* (matrix* m0 m1) m2)
|
||||
(matrix* m0 (matrix* m1 m2))))
|
||||
|
||||
;; ===================================================================================================
|
||||
;; Construction
|
||||
|
||||
(check-equal?
|
||||
(block-diagonal-matrix
|
||||
(list
|
||||
(matrix [[1 2] [3 4]])
|
||||
(matrix [[1 2 3] [4 5 6]])
|
||||
(matrix [[1] [3] [5]])))
|
||||
(matrix
|
||||
[[1 2 0 0 0 0]
|
||||
[3 4 0 0 0 0]
|
||||
[0 0 1 2 3 0]
|
||||
[0 0 4 5 6 0]
|
||||
[0 0 0 0 0 1]
|
||||
[0 0 0 0 0 3]
|
||||
[0 0 0 0 0 5]]))
|
||||
|
||||
;; ===================================================================================================
|
||||
|
||||
(begin
|
||||
(begin "matrix-types.rkt"
|
||||
(list
|
||||
'array-matrix?
|
||||
(array-matrix? (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (array-matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
|
||||
'matrix?
|
||||
(matrix? (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (matrix? (list*->array '[[[1 2] [3 4]] [[1 2] [3 4]]] real? ))))
|
||||
(list
|
||||
'square-matrix?
|
||||
(square-matrix? (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (square-matrix? (list*->array '[[1 2 3] [4 5 6]] real? ))))
|
||||
(list
|
||||
'square-matrix-size
|
||||
(= 2 (square-matrix-size (list*->array '[[1 2 3] [4 5 6]] real? ))))
|
||||
(= 3 (square-matrix-size (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real?))))
|
||||
(list
|
||||
'matrix-all=-
|
||||
(matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
|
||||
(not (matrix-all= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
|
||||
'matrix=
|
||||
(matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2] [3 4]] real? ))
|
||||
#;(not (matrix= (list*->array '[[1 2] [3 4]] real?) (list*->array '[[1 2]] real? ))))
|
||||
(list
|
||||
'matrix-dimensions
|
||||
(let-values ([(m n) (matrix-dimensions (list->matrix '[[1 2 3] [4 5 6]]))])
|
||||
'matrix-shape
|
||||
(let-values ([(m n) (matrix-shape (list*->matrix '[[1 2 3] [4 5 6]]))])
|
||||
(equal? (list m n) '(2 3)))))
|
||||
|
||||
(begin "matrix-constructors.rkt"
|
||||
|
@ -32,47 +90,46 @@
|
|||
'identity-matrix
|
||||
(equal? (array->list* (identity-matrix 1)) '[[1]])
|
||||
(equal? (array->list* (identity-matrix 2)) '[[1 0] [0 1]])
|
||||
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]])
|
||||
(equal? (array->list* (flidentity-matrix 1)) '[[1.]])
|
||||
(equal? (array->list* (flidentity-matrix 2)) '[[1. 0.] [0. 1.]])
|
||||
(equal? (array->list* (flidentity-matrix 3)) '[[1. 0. 0.] [0. 1. 0.] [0. 0. 1.]]))
|
||||
(equal? (array->list* (identity-matrix 3)) '[[1 0 0] [0 1 0] [0 0 1]]))
|
||||
(list
|
||||
'const-matrix
|
||||
(equal? (array->list* (make-matrix 2 3 0)) '((0 0 0) (0 0 0)))
|
||||
(equal? (array->list* (make-matrix 2 3 0.)) '((0. 0. 0.) (0. 0. 0.))))
|
||||
(list
|
||||
'matrix->list
|
||||
(equal? (matrix->list (list->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
|
||||
(equal? (matrix->list (fllist->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
|
||||
(equal? (matrix->list* (list*->matrix '((1 2) (3 4)))) '((1 2) (3 4)))
|
||||
(equal? (matrix->list* (list*->matrix '((1. 2.) (3. 4.)))) '((1. 2.) (3. 4.))))
|
||||
(list
|
||||
'matrix->vector
|
||||
(equal? (matrix->vector (vector->matrix '#(#(1 2) #(3 4)))) '#(#(1 2) #(3 4)))
|
||||
(equal? (matrix->vector (flvector->matrix '#(#(1. 2.) #(3. 4.)))) '#(#(1. 2.) #(3. 4.))))
|
||||
(equal? (matrix->vector* ((inst vector*->matrix Integer) '#(#(1 2) #(3 4))))
|
||||
'#(#(1 2) #(3 4)))
|
||||
(equal? (matrix->vector* ((inst vector*->matrix Flonum) '#(#(1. 2.) #(3. 4.))))
|
||||
'#(#(1. 2.) #(3. 4.))))
|
||||
(list
|
||||
'matrix-row
|
||||
(equal? (matrix-row (identity-matrix 3) 0) (list->matrix '[[1 0 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 1) (list->matrix '[[0 1 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 2) (list->matrix '[[0 0 1]])))
|
||||
(equal? (matrix-row (identity-matrix 3) 0) (list*->matrix '[[1 0 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 1) (list*->matrix '[[0 1 0]]))
|
||||
(equal? (matrix-row (identity-matrix 3) 2) (list*->matrix '[[0 0 1]])))
|
||||
(list
|
||||
'matrix-col
|
||||
(equal? (matrix-column (identity-matrix 3) 0) (list->matrix '[[1] [0] [0]]))
|
||||
(equal? (matrix-column (identity-matrix 3) 1) (list->matrix '[[0] [1] [0]]))
|
||||
(equal? (matrix-column (identity-matrix 3) 2) (list->matrix '[[0] [0] [1]])))
|
||||
(equal? (matrix-col (identity-matrix 3) 0) (list*->matrix '[[1] [0] [0]]))
|
||||
(equal? (matrix-col (identity-matrix 3) 1) (list*->matrix '[[0] [1] [0]]))
|
||||
(equal? (matrix-col (identity-matrix 3) 2) (list*->matrix '[[0] [0] [1]])))
|
||||
(list
|
||||
'submatrix
|
||||
(equal? (submatrix (identity-matrix 3)
|
||||
(in-range 0 1) (in-range 0 2)) (list->matrix '[[1 0]]))
|
||||
(in-range 0 1) (in-range 0 2)) (list*->matrix '[[1 0]]))
|
||||
(equal? (submatrix (identity-matrix 3)
|
||||
(in-range 0 2) (in-range 0 3)) (list->matrix '[[1 0 0] [0 1 0]]))))
|
||||
(in-range 0 2) (in-range 0 3)) (list*->matrix '[[1 0 0] [0 1 0]]))))
|
||||
|
||||
(begin
|
||||
"matrix-pointwise.rkt"
|
||||
(let ()
|
||||
(define A (list->matrix '[[1 2] [3 4]]))
|
||||
(define ~A (list->matrix '[[-1 -2] [-3 -4]]))
|
||||
(define B (list->matrix '[[5 6] [7 8]]))
|
||||
(define A+B (list->matrix '[[6 8] [10 12]]))
|
||||
(define A-B (list->matrix '[[-4 -4] [-4 -4]]))
|
||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
||||
(define ~A (list*->matrix '[[-1 -2] [-3 -4]]))
|
||||
(define B (list*->matrix '[[5 6] [7 8]]))
|
||||
(define A+B (list*->matrix '[[6 8] [10 12]]))
|
||||
(define A-B (list*->matrix '[[-4 -4] [-4 -4]]))
|
||||
(list 'matrix+ (equal? (matrix+ A B) A+B))
|
||||
(list 'matrix-
|
||||
(equal? (matrix- A B) A-B)
|
||||
|
@ -81,102 +138,102 @@
|
|||
(begin
|
||||
"matrix-expt.rkt"
|
||||
(let ()
|
||||
(define A (list->matrix '[[1 2] [3 4]]))
|
||||
(define A (list*->matrix '[[1 2] [3 4]]))
|
||||
(list
|
||||
'matrix-expt
|
||||
(equal? (matrix-expt A 0) (identity-matrix 2))
|
||||
(equal? (matrix-expt A 1) A)
|
||||
(equal? (matrix-expt A 2) (list->matrix '[[7 10] [15 22]]))
|
||||
(equal? (matrix-expt A 3) (list->matrix '[[37 54] [81 118]]))
|
||||
(equal? (matrix-expt A 8) (list->matrix '[[165751 241570] [362355 528106]]))))
|
||||
#;(list
|
||||
(define A (fllist->matrix '[[1. 2.] [3. 4.]]))
|
||||
(check-equal? (matrix->list (flmatrix-expt A 0)) (matrix->list (flidentity-matrix 2)))
|
||||
(check-equal? (matrix->list (flmatrix-expt A 1)) (matrix->list A))
|
||||
(check-equal? (matrix->list (flmatrix-expt A 2)) '[[7. 10.] [15. 22.]])
|
||||
(check-equal? (matrix->list (flmatrix-expt A 3)) '[[37. 54.] [81. 118.]])
|
||||
(check-equal? (matrix->list (flmatrix-expt A 8)) '[[165751. 241570.] [362355. 528106.]])))
|
||||
(equal? (matrix-expt A 2) (list*->matrix '[[7 10] [15 22]]))
|
||||
(equal? (matrix-expt A 3) (list*->matrix '[[37 54] [81 118]]))
|
||||
(equal? (matrix-expt A 8) (list*->matrix '[[165751 241570] [362355 528106]])))))
|
||||
|
||||
(begin
|
||||
"matrix-operations.rkt"
|
||||
(list 'vandermonde-matrix
|
||||
(equal? (vandermonde-matrix '(1 2 3) 5)
|
||||
(list->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
|
||||
(list*->matrix '[[1 1 1 1 1] [1 2 4 8 16] [1 3 9 27 81]])))
|
||||
#;
|
||||
(list 'in-column
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 0)]) x)
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 0)])
|
||||
x)
|
||||
'(1 3))
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix/dim 2 2 1 2 3 4) 1)]) x)
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-column (matrix [[1 2] [3 4]]) 1)])
|
||||
x)
|
||||
'(2 4))
|
||||
(equal? (for/list: : (Listof Number) ([x (in-column (column 5 2 3))]) x)
|
||||
(equal? (for/list: : (Listof Number) ([x (in-column (col-matrix [5 2 3]))]) x)
|
||||
'(5 2 3)))
|
||||
#;
|
||||
(list 'in-row
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 0)]) x)
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 0)])
|
||||
x)
|
||||
'(1 2))
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix/dim 2 2 1 2 3 4) 1)]) x)
|
||||
(equal? (for/list: : (Listof Number) ([x : Number (in-row (matrix [[1 2] [3 4]]) 1)])
|
||||
x)
|
||||
'(3 4)))
|
||||
(list 'for/matrix:
|
||||
(equal? (for/matrix: : Number 2 4 ([i (in-naturals)]) i)
|
||||
(matrix/dim 2 4
|
||||
0 1 2 3
|
||||
4 5 6 7))
|
||||
(matrix [[0 1 2 3] [4 5 6 7]]))
|
||||
(equal? (for/matrix: : Number 2 4 #:column ([i (in-naturals)]) i)
|
||||
(matrix/dim 2 4
|
||||
0 2 4 6
|
||||
1 3 5 7))
|
||||
(matrix [[0 2 4 6] [1 3 5 7]]))
|
||||
(equal? (for/matrix: : Number 3 3 ([i (in-range 10 100)]) i)
|
||||
(matrix/dim 3 3 10 11 12 13 14 15 16 17 18)))
|
||||
(matrix [[10 11 12] [13 14 15] [16 17 18]])))
|
||||
(list 'for*/matrix:
|
||||
(equal? (for*/matrix: : Number 3 3 ([i (in-range 3)] [j (in-range 3)]) (+ (* i 10) j))
|
||||
(matrix/dim 3 3 0 1 2 10 11 12 20 21 22)))
|
||||
(matrix [[0 1 2] [10 11 12] [20 21 22]])))
|
||||
(list 'matrix-block-diagonal
|
||||
(equal? (matrix-block-diagonal (list (matrix/dim 2 2 1 2 3 4) (matrix/dim 1 3 5 6 7)))
|
||||
(list->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
|
||||
(equal? (block-diagonal-matrix (list (matrix [[1 2] [3 4]]) (matrix [[5 6 7]])))
|
||||
(list*->matrix '[[1 2 0 0 0] [3 4 0 0 0] [0 0 5 6 7]])))
|
||||
(list 'matrix-augment
|
||||
(equal? (matrix-augment (column 1 2 3) (column 4 5 6) (column 7 8 9))
|
||||
(matrix/dim 3 3 1 4 7 2 5 8 3 6 9)))
|
||||
(equal? (matrix-augment (list (col-matrix [1 2 3])
|
||||
(col-matrix [4 5 6])
|
||||
(col-matrix [7 8 9])))
|
||||
(matrix [[1 4 7] [2 5 8] [3 6 9]])))
|
||||
(list 'matrix-stack
|
||||
(equal? (matrix-stack (column 1 2 3) (column 4 5 6) (column 7 8 9))
|
||||
(column 1 2 3 4 5 6 7 8 9)))
|
||||
(equal? (matrix-stack (list (col-matrix [1 2 3])
|
||||
(col-matrix [4 5 6])
|
||||
(col-matrix [7 8 9])))
|
||||
(col-matrix [1 2 3 4 5 6 7 8 9])))
|
||||
#;
|
||||
(list 'column-dimension
|
||||
(= (column-dimension #(1 2 3)) 3)
|
||||
(= (column-dimension (flat-vector->matrix 1 2 #(1 2))) 1))
|
||||
(let ([matrix: flat-vector->matrix])
|
||||
(= (column-dimension (vector->matrix 1 2 #(1 2))) 1))
|
||||
(let ([matrix: vector->matrix])
|
||||
(list 'column-dot
|
||||
(= (column-dot (column 1 2) (column 1 2)) 5)
|
||||
(= (column-dot (column 1 2) (column 3 4)) 11)
|
||||
(= (column-dot (column 3 4) (column 3 4)) 25)
|
||||
(= (column-dot (column 1 2 3) (column 4 5 6))
|
||||
(= (column-dot (col-matrix [1 2]) (col-matrix [1 2])) 5)
|
||||
(= (column-dot (col-matrix [1 2]) (col-matrix [3 4])) 11)
|
||||
(= (column-dot (col-matrix [3 4]) (col-matrix [3 4])) 25)
|
||||
(= (column-dot (col-matrix [1 2 3]) (col-matrix [4 5 6]))
|
||||
(+ (* 1 4) (* 2 5) (* 3 6)))
|
||||
(= (column-dot (column +3i +4i) (column +3i +4i))
|
||||
(= (column-dot (col-matrix [+3i +4i]) (col-matrix [+3i +4i]))
|
||||
25)))
|
||||
(list 'matrix-trace
|
||||
(equal? (matrix-trace (flat-vector->matrix 2 2 #(1 2 3 4))) 5))
|
||||
(let ([matrix: flat-vector->matrix])
|
||||
(equal? (matrix-trace (vector->matrix 2 2 #(1 2 3 4))) 5))
|
||||
(let ([matrix: vector->matrix])
|
||||
(list 'column-norm
|
||||
(= (column-norm (column 2 4)) (sqrt 20))))
|
||||
(list 'column-projection
|
||||
(equal? (column-projection #(1 2 3) #(4 5 6)) (column 128/77 160/77 192/77))
|
||||
(equal? (column-projection (column 1 2 3) (column 2 4 3))
|
||||
(matrix-scale 19/29 (column 2 4 3))))
|
||||
(= (column-norm (col-matrix [2 4])) (sqrt 20))))
|
||||
(list 'column-project
|
||||
(equal? (column-project #(1 2 3) #(4 5 6)) (col-matrix [128/77 160/77 192/77]))
|
||||
(equal? (column-project (col-matrix [1 2 3]) (col-matrix [2 4 3]))
|
||||
(matrix-scale (col-matrix [2 4 3]) 19/29)))
|
||||
(list 'projection-on-orthogonal-basis
|
||||
(equal? (projection-on-orthogonal-basis #(3 -2 2) (list #(-1 0 2) #( 2 5 1)))
|
||||
(column -1/3 -1/3 1/3))
|
||||
(col-matrix [-1/3 -1/3 1/3]))
|
||||
(equal? (projection-on-orthogonal-basis
|
||||
(column 3 -2 2) (list #(-1 0 2) (column 2 5 1)))
|
||||
(column -1/3 -1/3 1/3)))
|
||||
(col-matrix [3 -2 2]) (list #(-1 0 2) (col-matrix [2 5 1])))
|
||||
(col-matrix [-1/3 -1/3 1/3])))
|
||||
(list 'projection-on-orthonormal-basis
|
||||
(equal? (projection-on-orthonormal-basis
|
||||
#(1 2 3 4)
|
||||
(list (matrix-scale 1/2 (column 1 1 1 1))
|
||||
(matrix-scale 1/2 (column -1 1 -1 1))
|
||||
(matrix-scale 1/2 (column 1 -1 -1 1))))
|
||||
(column 2 3 2 3)))
|
||||
(list (matrix-scale (col-matrix [ 1 1 1 1]) 1/2)
|
||||
(matrix-scale (col-matrix [-1 1 -1 1]) 1/2)
|
||||
(matrix-scale (col-matrix [ 1 -1 -1 1]) 1/2)))
|
||||
(col-matrix [2 3 2 3])))
|
||||
(list 'gram-schmidt-orthogonal
|
||||
(equal? (gram-schmidt-orthogonal (list #(3 1) #(2 2)))
|
||||
(list (column 3 1) (column -2/5 6/5))))
|
||||
(list (col-matrix [3 1]) (col-matrix [-2/5 6/5]))))
|
||||
(list 'vector-normalize
|
||||
(equal? (column-normalize #(3 4))
|
||||
(column 3/5 4/5)))
|
||||
(col-matrix [3/5 4/5])))
|
||||
(list 'gram-schmidt-orthonormal
|
||||
(equal? (gram-schmidt-orthonormal (ann '(#(3 1) #(2 2)) (Listof (Column Number))))
|
||||
(list (column-normalize #(3 1))
|
||||
|
@ -184,74 +241,81 @@
|
|||
|
||||
(list 'projection-on-subspace
|
||||
(equal? (projection-on-subspace #(1 2 3) '(#(2 4 3)))
|
||||
(matrix-scale 19/29 (column 2 4 3))))
|
||||
(matrix-scale (col-matrix [2 4 3]) 19/29)))
|
||||
(list 'unit-vector
|
||||
(equal? (unit-column 4 1) (column 0 1 0 0)))
|
||||
(equal? (unit-column 4 1) (col-matrix [0 1 0 0])))
|
||||
(list 'matrix-qr
|
||||
(let-values ([(Q R) (matrix-qr (matrix/dim 3 2 1 1 0 1 1 1))])
|
||||
(let-values ([(Q R) (matrix-qr (matrix [[1 1] [0 1] [1 1]]))])
|
||||
(equal? (list Q R)
|
||||
(list (matrix/dim 3 2 0.7071067811865475 0 0 1 0.7071067811865475 0)
|
||||
(matrix/dim 2 2 1.414213562373095 1.414213562373095 0 1))))
|
||||
(list (matrix [[0.7071067811865475 0]
|
||||
[0 1]
|
||||
[0.7071067811865475 0]])
|
||||
(matrix [[1.414213562373095 1.414213562373095]
|
||||
[0 1]]))))
|
||||
(let ()
|
||||
(define A (matrix/dim 4 4 1 2 3 4 1 2 4 5 1 2 5 6 1 2 6 7))
|
||||
(define A (matrix [[1 2 3 4] [1 2 4 5] [1 2 5 6] [1 2 6 7]]))
|
||||
(define-values (Q R) (matrix-qr A))
|
||||
(equal? (list Q R)
|
||||
(list
|
||||
(flat-vector->matrix
|
||||
4 4 (ann #(1/2 -0.6708203932499369 0.5477225575051662 -0.0
|
||||
(vector->matrix
|
||||
4 4 ((inst vector Number)
|
||||
1/2 -0.6708203932499369 0.5477225575051662 -0.0
|
||||
1/2 -0.22360679774997896 -0.7302967433402214 0.4082482904638629
|
||||
1/2 0.22360679774997896 -0.18257418583505536 -0.8164965809277259
|
||||
1/2 0.6708203932499369 0.3651483716701107 0.408248290463863)
|
||||
(Vectorof Number)))
|
||||
(flat-vector->matrix
|
||||
4 4 (ann #(2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
|
||||
0 0 0.0 4.440892098500626e-16 0 0 0 0.0)
|
||||
(Vectorof Number)))))))
|
||||
1/2 0.6708203932499369 0.3651483716701107 0.408248290463863))
|
||||
(vector->matrix
|
||||
4 4 ((inst vector Number)
|
||||
2 4 9 11 0 0.0 2.23606797749979 2.23606797749979
|
||||
0 0 0.0 4.440892098500626e-16 0 0 0 0.0))))))
|
||||
(list 'matrix-solve
|
||||
(let* ([M (list->matrix '[[1 5] [2 3]])]
|
||||
[b (list->matrix '[[5] [5]])])
|
||||
(let* ([M (list*->matrix '[[1 5] [2 3]])]
|
||||
[b (list*->matrix '[[5] [5]])])
|
||||
(equal? (matrix* M (matrix-solve M b)) b)))
|
||||
(list 'matrix-inverse
|
||||
(equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M)))
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* M (matrix-inverse M)))
|
||||
(identity-matrix 2))
|
||||
(equal? (let ([M (list->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M))
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) (matrix* (matrix-inverse M) M))
|
||||
(identity-matrix 2)))
|
||||
(list 'matrix-determinant
|
||||
(equal? (matrix-determinant (list->matrix '[[3]])) 3)
|
||||
(equal? (matrix-determinant (list->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
|
||||
(equal? (matrix-determinant (list->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0)
|
||||
(equal? (matrix-determinant (list->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120)
|
||||
(equal? (matrix-determinant (list->matrix '[[1 2 3 4] [-5 6 7 8] [9 10 -11 12] [13 14 15 16]])) 5280))
|
||||
(equal? (matrix-determinant (list*->matrix '[[3]])) 3)
|
||||
(equal? (matrix-determinant (list*->matrix '[[1 2] [3 4]])) (- (* 1 4) (* 2 3)))
|
||||
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 5 6] [7 8 9]])) 0)
|
||||
(equal? (matrix-determinant (list*->matrix '[[1 2 3] [4 -5 6] [7 8 9]])) 120)
|
||||
(equal? (matrix-determinant (list*->matrix '[[1 2 3 4]
|
||||
[-5 6 7 8]
|
||||
[9 10 -11 12]
|
||||
[13 14 15 16]]))
|
||||
5280))
|
||||
(list 'matrix-scale
|
||||
(equal? (matrix-scale 2 (list->matrix '[[1 2] [3 4]]))
|
||||
(list->matrix '[[2 4] [6 8]])))
|
||||
(equal? (matrix-scale (list*->matrix '[[1 2] [3 4]]) 2)
|
||||
(list*->matrix '[[2 4] [6 8]])))
|
||||
(list 'matrix-transpose
|
||||
(equal? (matrix-transpose (list->matrix '[[1 2] [3 4]]))
|
||||
(list->matrix '[[1 3] [2 4]])))
|
||||
(equal? (matrix-transpose (list*->matrix '[[1 2] [3 4]]))
|
||||
(list*->matrix '[[1 3] [2 4]])))
|
||||
(list 'matrix-hermitian
|
||||
(equal? (matrix-hermitian (list->matrix '[[1+i 2-i] [3+i 4-i]]))
|
||||
(list->matrix '[[1-i 3-i] [2+i 4+i]])))
|
||||
(equal? (matrix-hermitian (list*->matrix '[[1+i 2-i] [3+i 4-i]]))
|
||||
(list*->matrix '[[1-i 3-i] [2+i 4+i]])))
|
||||
(let ()
|
||||
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
||||
(define (gauss-eliminate M u? p?)
|
||||
(let-values ([(M wp) (matrix-gauss-eliminate M u? p?)])
|
||||
M))
|
||||
(list 'matrix-gauss-eliminate
|
||||
(equal? (let ([M (list->matrix '[[1 2] [3 4]])])
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])])
|
||||
(gauss-eliminate M #f #f))
|
||||
(list->matrix '[[1 2] [0 -2]]))
|
||||
(equal? (let ([M (list->matrix '[[2 4] [3 4]])])
|
||||
(list*->matrix '[[1 2] [0 -2]]))
|
||||
(equal? (let ([M (list*->matrix '[[2 4] [3 4]])])
|
||||
(gauss-eliminate M #t #f))
|
||||
(list->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list->matrix '[[2. 4.] [3. 4.]])])
|
||||
(list*->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])])
|
||||
(gauss-eliminate M #t #t))
|
||||
(list->matrix '[[1. 1.3333333333333333] [0. 1.]]))
|
||||
(equal? (let ([M (list->matrix '[[1 4] [2 4]])])
|
||||
(list*->matrix '[[1. 1.3333333333333333] [0. 1.]]))
|
||||
(equal? (let ([M (list*->matrix '[[1 4] [2 4]])])
|
||||
(gauss-eliminate M #t #t))
|
||||
(list->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list->matrix '[[1 2] [2 4]])])
|
||||
(list*->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [2 4]])])
|
||||
(gauss-eliminate M #f #t))
|
||||
(list->matrix '[[2 4] [0 0]]))))
|
||||
(list*->matrix '[[2 4] [0 0]]))))
|
||||
(list
|
||||
'matrix-scale-row
|
||||
(equal? (matrix-scale-row (identity-matrix 3) 0 2)
|
||||
|
@ -265,7 +329,7 @@
|
|||
(equal? (matrix-add-scaled-row (list*->array '[[1 2 3] [4 5 6] [7 8 9]] real? ) 0 2 1)
|
||||
(list*->array '[[9 12 15] [4 5 6] [7 8 9]] real? )))
|
||||
(let ()
|
||||
(define M (list->matrix '[[1 1 0 3]
|
||||
(define M (list*->matrix '[[1 1 0 3]
|
||||
[2 1 -1 1]
|
||||
[3 -1 -1 2]
|
||||
[-1 2 3 -1]]))
|
||||
|
@ -277,12 +341,12 @@
|
|||
(define V (if (list? LU) (second LU) #f))
|
||||
(list
|
||||
'matrix-lu
|
||||
(equal? L (list->matrix
|
||||
(equal? L (list*->matrix
|
||||
'[[1 0 0 0]
|
||||
[2 1 0 0]
|
||||
[3 4 1 0]
|
||||
[-1 -3 0 1]]))
|
||||
(equal? V (list->matrix
|
||||
(equal? V (list*->matrix
|
||||
'[[1 1 0 3]
|
||||
[0 -1 -1 -5]
|
||||
[0 0 3 13]
|
||||
|
@ -290,34 +354,34 @@
|
|||
(equal? (matrix* L V) M)))))
|
||||
(list
|
||||
'matrix-rank
|
||||
(equal? (matrix-rank (list->matrix '[[0 0] [0 0]])) 0)
|
||||
(equal? (matrix-rank (list->matrix '[[1 0] [0 0]])) 1)
|
||||
(equal? (matrix-rank (list->matrix '[[1 0] [0 3]])) 2)
|
||||
(equal? (matrix-rank (list->matrix '[[1 2] [2 4]])) 1)
|
||||
(equal? (matrix-rank (list->matrix '[[1 2] [3 4]])) 2))
|
||||
(equal? (matrix-rank (list*->matrix '[[0 0] [0 0]])) 0)
|
||||
(equal? (matrix-rank (list*->matrix '[[1 0] [0 0]])) 1)
|
||||
(equal? (matrix-rank (list*->matrix '[[1 0] [0 3]])) 2)
|
||||
(equal? (matrix-rank (list*->matrix '[[1 2] [2 4]])) 1)
|
||||
(equal? (matrix-rank (list*->matrix '[[1 2] [3 4]])) 2))
|
||||
(list
|
||||
'matrix-nullity
|
||||
(equal? (matrix-nullity (list->matrix '[[0 0] [0 0]])) 2)
|
||||
(equal? (matrix-nullity (list->matrix '[[1 0] [0 0]])) 1)
|
||||
(equal? (matrix-nullity (list->matrix '[[1 0] [0 3]])) 0)
|
||||
(equal? (matrix-nullity (list->matrix '[[1 2] [2 4]])) 1)
|
||||
(equal? (matrix-nullity (list->matrix '[[1 2] [3 4]])) 0))
|
||||
(equal? (matrix-nullity (list*->matrix '[[0 0] [0 0]])) 2)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 0]])) 1)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 0] [0 3]])) 0)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [2 4]])) 1)
|
||||
(equal? (matrix-nullity (list*->matrix '[[1 2] [3 4]])) 0))
|
||||
#;(let ()
|
||||
(define-values (c1 n1)
|
||||
(matrix-column+null-space (list->matrix '[[0 0] [0 0]])))
|
||||
(matrix-column+null-space (list*rix '[[0 0] [0 0]])))
|
||||
(define-values (c2 n2)
|
||||
(matrix-column+null-space (list->matrix '[[1 2] [2 4]])))
|
||||
(matrix-column+null-space (list*->matrix '[[1 2] [2 4]])))
|
||||
(define-values (c3 n3)
|
||||
(matrix-column+null-space (list->matrix '[[1 2] [2 5]])))
|
||||
(matrix-column+null-space (list*atrix '[[1 2] [2 5]])))
|
||||
(list
|
||||
'matrix-column+null-space
|
||||
(equal? c1 '())
|
||||
(equal? n1 (list (list->matrix '[[0] [0]])
|
||||
(list->matrix '[[0] [0]])))
|
||||
(equal? c2 (list (list->matrix '[[1] [2]])))
|
||||
(equal? n1 (list (list*->matrix '[[0] [0]])
|
||||
(list*->matrix '[[0] [0]])))
|
||||
(equal? c2 (list (list*->matrix '[[1] [2]])))
|
||||
;(equal? n2 '([0 0]))
|
||||
(equal? c3 (list (list->matrix '[[1] [2]])
|
||||
(list->matrix '[[2] [5]])))
|
||||
(equal? c3 (list (list*->matrix '[[1] [2]])
|
||||
(list*->matrix '[[2] [5]])))
|
||||
(equal? n3 '()))))
|
||||
|
||||
|
||||
|
@ -331,10 +395,12 @@
|
|||
(list 'matrix*
|
||||
(let ()
|
||||
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6] [7 8]] '[[19 22] [43 50]]))
|
||||
(equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB)))
|
||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))
|
||||
(let ()
|
||||
(define-values (A B AB) (values '[[1 2] [3 4]] '[[5 6 7] [8 9 10]] '[[21 24 27] [47 54 61]]))
|
||||
(equal? (matrix* (list->matrix A) (list->matrix B)) (list->matrix AB)))))
|
||||
(define-values (A B AB) (values '[[1 2] [3 4]]
|
||||
'[[5 6 7] [8 9 10]]
|
||||
'[[21 24 27] [47 54 61]]))
|
||||
(equal? (matrix* (list*->matrix A) (list*->matrix B)) (list*->matrix AB)))))
|
||||
#;(begin
|
||||
"matrix-2d.rkt"
|
||||
(let ()
|
||||
|
@ -343,10 +409,10 @@
|
|||
(define -e1 (matrix-transpose (vector->matrix #(#(-1 0)))))
|
||||
(define -e2 (matrix-transpose (vector->matrix #(#( 0 -1)))))
|
||||
(define O (matrix-transpose (vector->matrix #(#( 0 0)))))
|
||||
(define 2*e1 (matrix-scale 2 e1))
|
||||
(define 4*e1 (matrix-scale 4 e1))
|
||||
(define 3*e2 (matrix-scale 3 e2))
|
||||
(define 4*e2 (matrix-scale 4 e2))
|
||||
(define 2*e1 (matrix-scale e1 2))
|
||||
(define 4*e1 (matrix-scale e1 4))
|
||||
(define 3*e2 (matrix-scale e2 3))
|
||||
(define 4*e2 (matrix-scale e2 4))
|
||||
(begin
|
||||
(list 'matrix-2d-rotation
|
||||
(<= (matrix-norm (matrix- (matrix* (matrix-2d-rotation (/ pi 2)) e1) e2 )) epsilon.0)
|
||||
|
|
11
collects/math/tests/test-utils.rkt
Normal file
11
collects/math/tests/test-utils.rkt
Normal file
|
@ -0,0 +1,11 @@
|
|||
#lang typed/racket
|
||||
|
||||
(require (except-in typed/rackunit check-equal?))
|
||||
|
||||
(provide (all-from-out typed/rackunit)
|
||||
check-equal?)
|
||||
|
||||
;; This gets around the fact that typed/rackunit can no longer test higher-order values for equality,
|
||||
;; since TR has firmed up its rules on passing `Any' types in and out of untyped code
|
||||
(define-syntax-rule (check-equal? a b . message)
|
||||
(check-true (equal? a b) . message))
|
Loading…
Reference in New Issue
Block a user