
* Gram-Schmidt using the vector type
* QR decomposition
* Operator 1-norm and maximum norm; stub for the 2-norm and for the angle between subspaces (`matrix-basis-angle`)
* `matrix-absolute-error` and `matrix-relative-error`, plus predicates based on them, such as `matrix-identity?`
* Lots of code shuffled around
* Types that can have contracts, and an exhaustive test making sure every value exported by `math/matrix` has a contract when used from untyped code
* Some more tests (more are still needed)
153 lines
5.9 KiB
Racket
#lang typed/racket/base
|
|
|
|
(require racket/fixnum
|
|
racket/list
|
|
racket/vector
|
|
"matrix-types.rkt"
|
|
"../unsafe.rkt"
|
|
"../array/array-struct.rkt"
|
|
"../array/array-constructors.rkt"
|
|
"../array/array-unfold.rkt"
|
|
"../array/utils.rkt")
|
|
|
|
(provide identity-matrix
|
|
make-matrix
|
|
build-matrix
|
|
diagonal-matrix/zero
|
|
diagonal-matrix
|
|
block-diagonal-matrix/zero
|
|
block-diagonal-matrix
|
|
vandermonde-matrix)
|
|
|
|
;; ===================================================================================================
|
|
;; Basic constructors
|
|
|
|
;; Returns the m×m identity matrix, built with diagonal-array (2 dimensions,
;; size m, 1 on the diagonal, 0 off it).
;; Raises an argument error unless m is a positive index. This matches the size
;; checks in build-matrix and vandermonde-matrix, and keeps a 0 or negative size
;; from being reported as a diagonal-array error.
(: identity-matrix (Integer -> (Matrix (U 0 1))))
(define (identity-matrix m)
  (cond [(or (not (index? m)) (= m 0))
         (raise-argument-error 'identity-matrix "Positive-Index" m)]
        [else
         (diagonal-array 2 m 1 0)]))
|
|
|
|
;; Returns an m×n matrix whose every entry is x (make-array over shape #(m n)).
;; Raises an argument error unless m and n are both positive indexes. The
;; argument positions follow build-matrix: 0 for m, 1 for n.
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
(define (make-matrix m n x)
  (cond [(or (not (index? m)) (= m 0))
         (raise-argument-error 'make-matrix "Positive-Index" 0 m n x)]
        [(or (not (index? n)) (= n 0))
         (raise-argument-error 'make-matrix "Positive-Index" 1 m n x)]
        [else
         (make-array (vector m n) x)]))
|
|
|
|
;; Builds an m×n matrix whose entry at row i, column j is (proc i j).
;; Raises an argument error (argument position 0 for m, 1 for n) unless both
;; dimensions are positive indexes.
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
(define (build-matrix m n proc)
  (cond
    [(not (and (index? m) (> m 0)))
     (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
    [(not (and (index? n) (> n 0)))
     (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
    [else
     (define shape ((inst vector Index) m n))
     (unsafe-build-array
      shape
      (λ: ([ij : Indexes])
        ;; ij is a 2-element index vector: (row col).
        (define row (unsafe-vector-ref ij 0))
        (define col (unsafe-vector-ref ij 1))
        (proc row col)))]))
|
|
|
|
;; ===================================================================================================
|
|
;; Diagonal matrices
|
|
|
|
;; Returns a square matrix whose side is (length xs). The elements of xs lie
;; along the diagonal, in order, and `zero` fills every off-diagonal entry.
;; Raises an argument error naming diagonal-matrix/zero when xs is empty.
(: diagonal-matrix/zero (All (A) ((Listof A) A -> (Matrix A))))
(define (diagonal-matrix/zero xs zero)
  (cond [(empty? xs)
         (raise-argument-error 'diagonal-matrix/zero "nonempty List" xs)]
        [else
         ;; Vector copy gives constant-time access to the i-th diagonal element.
         (define vs (list->vector xs))
         (define m (vector-length vs))
         (unsafe-build-array
          ((inst vector Index) m m)
          (λ: ([js : Indexes])
            (define i (unsafe-vector-ref js 0))
            (cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)]
                  [else zero])))]))
|
|
|
|
;; Returns a square matrix with the elements of xs along the diagonal and
;; exact 0 everywhere else (diagonal-matrix/zero with zero = 0).
;; The empty-list check is done here, so the error names diagonal-matrix
;; itself rather than whatever name the helper reports.
(: diagonal-matrix (All (A) ((Listof A) -> (Matrix (U A 0)))))
(define (diagonal-matrix xs)
  (cond [(empty? xs)
         (raise-argument-error 'diagonal-matrix "nonempty List" xs)]
        [else
         (diagonal-matrix/zero xs 0)]))
|
|
|
|
;; ===================================================================================================
|
|
;; Block diagonal matrices
|
|
|
|
;; Worker for block-diagonal-matrix/zero. It places the matrices in `as`, in
;; order, along the diagonal of a result whose row count is the sum of the
;; blocks' row counts and whose column count is the sum of their column counts.
;; Every entry outside a block is `zero`. (Within SOURCE, the public wrapper
;; only calls this with two or more blocks.)
(: block-diagonal-matrix/zero* (All (A) ((Vectorof (Matrix A)) A -> (Matrix A))))
(define (block-diagonal-matrix/zero* as zero)
  (define num (vector-length as))
  ;; Block shapes, accumulated back-to-front and then restored to block order.
  (define-values (rev-ms rev-ns)
    (for/fold: ([rev-ms : (Listof Index) empty]
                [rev-ns : (Listof Index) empty])
               ([a (in-vector as)])
      (define-values (m n) (matrix-shape a))
      (values (cons m rev-ms) (cons n rev-ns))))
  (define ms (reverse rev-ms))
  (define ns (reverse rev-ns))
  (define res-m (assert (apply + ms) index?))
  (define res-n (assert (apply + ns) index?))
  ;; For each result row: the block it belongs to, and its row index inside that block.
  (define row-block ((inst make-vector Index) res-m 0))
  (define row-local ((inst make-vector Index) res-m 0))
  ;; For each result column: the block it belongs to, and its column index inside that block.
  (define col-block ((inst make-vector Index) res-n 0))
  (define col-local ((inst make-vector Index) res-n 0))
  ;; Fill the row tables; row0 is the first result row of the current block.
  (for/fold: ([row0 : Nonnegative-Fixnum 0])
             ([m (in-list ms)]
              [k : Nonnegative-Fixnum (in-range num)])
    (define blk (assert k index?))
    (for: ([i : Nonnegative-Fixnum (in-range m)])
      (vector-set! row-block (unsafe-fx+ row0 i) blk)
      (vector-set! row-local (unsafe-fx+ row0 i) (assert i index?)))
    (unsafe-fx+ row0 m))
  ;; Fill the column tables; col0 is the first result column of the current block.
  (for/fold: ([col0 : Nonnegative-Fixnum 0])
             ([n (in-list ns)]
              [k : Nonnegative-Fixnum (in-range num)])
    (define blk (assert k index?))
    (for: ([j : Nonnegative-Fixnum (in-range n)])
      (vector-set! col-block (unsafe-fx+ col0 j) blk)
      (vector-set! col-local (unsafe-fx+ col0 j) (assert j index?)))
    (unsafe-fx+ col0 n))
  (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as))
  (unsafe-build-array
   ((inst vector Index) res-m res-n)
   (λ: ([ij : Indexes])
     (define i (unsafe-vector-ref ij 0))
     (define j (unsafe-vector-ref ij 1))
     (define b (unsafe-vector-ref row-block i))
     (if (fx= b (vector-ref col-block j))
         ;; (i, j) lies inside block b. The caller's index vector is
         ;; overwritten with block-local coordinates for the block's procedure,
         ;; then (i, j) is written back before returning.
         (let ([block-proc (unsafe-vector-ref procs b)])
           (unsafe-vector-set! ij 0 (unsafe-vector-ref row-local i))
           (unsafe-vector-set! ij 1 (unsafe-vector-ref col-local j))
           (let ([x (block-proc ij)])
             (unsafe-vector-set! ij 0 i)
             (unsafe-vector-set! ij 1 j)
             x))
         zero))))
|
|
|
|
;; Returns a block-diagonal matrix with the matrices in `as` along the diagonal,
;; in order, and `zero` everywhere outside the blocks.
;; - A single block is returned unchanged.
;; - An empty list raises an argument error. The value reported is the list the
;;   caller passed, not the vector built from it, so the message matches the
;;   "nonempty List" expectation.
(: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A))))
(define (block-diagonal-matrix/zero as zero)
  (define blocks (list->vector as))
  (define num (vector-length blocks))
  (cond [(= num 0)
         (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
        [(= num 1)
         (unsafe-vector-ref blocks 0)]
        [else
         (block-diagonal-matrix/zero* blocks zero)]))
|
|
|
|
;; Returns a block-diagonal matrix with the matrices in `as` along the diagonal
;; and exact 0 outside the blocks (block-diagonal-matrix/zero with zero = 0).
;; The empty-list check is done here, so the error names block-diagonal-matrix
;; rather than the /zero helper.
(: block-diagonal-matrix (All (A) ((Listof (Matrix A)) -> (Matrix (U A 0)))))
(define (block-diagonal-matrix as)
  (cond [(empty? as)
         (raise-argument-error 'block-diagonal-matrix "nonempty List" as)]
        [else
         (block-diagonal-matrix/zero as 0)]))
|
|
|
|
;; ===================================================================================================
|
|
;; Special matrices
|
|
|
|
;; Same as (expt x n), but typed so that a Real base with an Integer exponent
;; yields Real. For a real x the result is checked with (assert ... real?).
;; Stop using this when TR correctly derives expt : Real Integer -> Real
(: expt-hack (case-> (Real Integer -> Real)
                     (Number Integer -> Number)))
(define (expt-hack x n)
  (if (real? x)
      (assert (expt x n) real?)
      (expt x n)))
|
|
|
|
;; Builds a Vandermonde-style matrix with one row per element of xs and n
;; columns. It expands the 1-d array of xs along a new axis 1 of length n,
;; combining each x with each column index via expt-hack.
;; Errors: an empty xs is reported as argument 0; an n that is not a positive
;; index is reported as argument 1. xs is checked first.
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Matrix Real))
                              ((Listof Number) Integer -> (Matrix Number))))
(define (vandermonde-matrix xs n)
  (cond
    [(null? xs)
     (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
    [(and (index? n) (positive? n))
     (array-axis-expand (list->array xs) 1 n expt-hack)]
    [else
     (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]))
|