
The fix consists of three parts:

1. Rewriting `inline-matrix*'. The material change is that the expansion now contains only direct applications of `+' and `*'. Typed Racket's (TR's) optimizer replaces them with `unsafe-fl+' and `unsafe-fl*', which keeps intermediate flonum values from being boxed.

2. Making the types of all functions that operate on (Matrix Number) values more precise, so that TR can prove that matrix operations preserve inexactness. For example, `matrix-conjugate' now has the type (Matrix Flonum) -> (Matrix Flonum), plus three analogous cases for Real, Float-Complex, and Number.

3. Changing the return types of some functions that used to return types like (Matrix (U A 0)). Now that inexactness must be preserved, `matrix-upper-triangle' can no longer always return a matrix containing exact zeros; it now accepts an optional `zero' argument of type A.
174 lines
7.0 KiB
Racket
174 lines
7.0 KiB
Racket
#lang typed/racket/base
|
|
|
|
(require racket/fixnum
|
|
racket/flonum
|
|
racket/list
|
|
racket/vector
|
|
math/base
|
|
"matrix-types.rkt"
|
|
"../unsafe.rkt"
|
|
"../array/array-struct.rkt"
|
|
"../array/array-constructors.rkt"
|
|
"../array/array-unfold.rkt"
|
|
"../array/utils.rkt")
|
|
|
|
(provide identity-matrix
|
|
make-matrix
|
|
build-matrix
|
|
diagonal-matrix/zero
|
|
diagonal-matrix
|
|
block-diagonal-matrix/zero
|
|
block-diagonal-matrix
|
|
vandermonde-matrix)
|
|
|
|
;; ===================================================================================================
|
|
;; Basic constructors
|
|
|
|
;; identity-matrix : Integer [A] [A] -> (Matrix ...)
;; Returns the size-by-size identity matrix as a 2-axis diagonal array. The diagonal holds
;; `on' (exact 1 when omitted) and every other entry holds `off' (exact 0 when omitted).
;; Passing both values (e.g. 1.0 and 0.0) yields a matrix with no exact entries.
;; No argument checking happens here; it is left entirely to `diagonal-array'.
(: identity-matrix (All (A) (case-> (Integer -> (Matrix (U 1 0)))
                                    (Integer A -> (Matrix (U A 0)))
                                    (Integer A A -> (Matrix A)))))
(define identity-matrix
  (case-lambda
    ;; exact 1 on the diagonal, exact 0 elsewhere
    [(size) (diagonal-array 2 size 1 0)]
    ;; caller-supplied diagonal value, exact 0 elsewhere
    [(size on) (diagonal-array 2 size on 0)]
    ;; caller-supplied diagonal and off-diagonal values
    [(size on off) (diagonal-array 2 size on off)]))
|
|
|
|
;; make-matrix : Integer Integer A -> (Matrix A)
;; Returns an m-by-n matrix in which every entry is `x'.
;; Both dimensions must be positive indexes. This is the same check `build-matrix' performs,
;; so a zero-sized (non-matrix) array is never returned.
;; Raises an argument error naming the offending position (0 for m, 1 for n) otherwise.
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
(define (make-matrix m n x)
  (cond [(or (not (index? m)) (= m 0))
         (raise-argument-error 'make-matrix "Positive-Index" 0 m n x)]
        [(or (not (index? n)) (= n 0))
         (raise-argument-error 'make-matrix "Positive-Index" 1 m n x)]
        [else
         ;; m and n are narrowed to Index by the failed checks above
         (make-array ((inst vector Index) m n) x)]))
|
|
|
|
;; build-matrix : Integer Integer (Index Index -> A) -> (Matrix A)
;; Returns an m-by-n matrix whose entry at row i, column j is (proc i j).
;; Raises an argument error unless both m and n are positive indexes. The error names
;; position 0 for m and position 1 for n.
;; The built array is passed through `array-default-strict' before being returned.
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
(define (build-matrix m n proc)
  (cond
    [(not (and (index? m) (positive? m)))
     (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)]
    [(not (and (index? n) (positive? n)))
     (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)]
    [else
     ;; Both dimensions are known to be Index values here
     (define shape ((inst vector Index) m n))
     (array-default-strict
      (unsafe-build-array
       shape
       (λ: ([ij : Indexes])
         ;; unpack the 2-element index vector into row and column
         (define row (unsafe-vector-ref ij 0))
         (define col (unsafe-vector-ref ij 1))
         (proc row col))))]))
|
|
|
|
;; ===================================================================================================
|
|
;; Diagonal matrices
|
|
|
|
;; diagonal-matrix/zero : (Listof A) A -> (Matrix A)
;; Returns the square matrix whose diagonal holds the elements of `xs', in order. Every
;; off-diagonal entry is `zero'. The side length equals (length xs).
;; Raises an argument error, reported under the name 'diagonal-matrix, when `xs' is empty.
;; Unlike `build-matrix', the result is built with `unsafe-build-simple-array' and is not
;; wrapped in `array-default-strict'.
(: diagonal-matrix/zero (All (A) ((Listof A) A -> (Matrix A))))
(define (diagonal-matrix/zero xs zero)
  (if (null? xs)
      (raise-argument-error 'diagonal-matrix "nonempty List" xs)
      (let* ([diag (list->vector xs)]
             [size (vector-length diag)])
        (unsafe-build-simple-array
         ((inst vector Index) size size)
         (λ: ([ij : Indexes])
           (let ([i (unsafe-vector-ref ij 0)]
                 [j (unsafe-vector-ref ij 1)])
             ;; diagonal entries come from `xs'; everything else is `zero'
             (if (= i j)
                 (unsafe-vector-ref diag i)
                 zero)))))))
|
|
|
|
;; diagonal-matrix : (Listof A) [A] -> (Matrix ...)
;; Returns the square matrix with the elements of `xs' on its diagonal.
;; Off-diagonal entries are the optional second argument, or exact 0 when it is omitted.
;; Delegates to `diagonal-matrix/zero', which raises an argument error for an empty list.
(: diagonal-matrix (All (A) (case-> ((Listof A) -> (Matrix (U A 0)))
                                    ((Listof A) A -> (Matrix A)))))
(define diagonal-matrix
  (case-lambda
    [(diag) (diagonal-matrix/zero diag 0)]
    [(diag off) (diagonal-matrix/zero diag off)]))
|
|
|
|
;; ===================================================================================================
|
|
;; Block diagonal matrices
|
|
|
|
;; block-diagonal-matrix/zero* : (Vectorof (Matrix A)) A -> (Matrix A)
;; Builds the block-diagonal matrix whose diagonal blocks are the matrices in `as', in order.
;; The result has (sum of block row counts) rows and (sum of block column counts) columns.
;; Entry (i, j) is read from block k when both row i and column j fall inside block k;
;; every other entry is `zero'.
;; Raises (via `assert') if either total dimension does not fit in an Index.
;; In this file it is only called by `block-diagonal-matrix/zero', with two or more blocks.
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Matrix A)) A -> (Matrix A)))
(define (block-diagonal-matrix/zero* as zero)
  (define num (vector-length as))
  ;; ms / ns: the row / column count of each block, in block order
  (define-values (ms ns)
    (let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty]
                                      [ns : (Listof Index) empty]
                                      ) ([a (in-vector as)])
                            (define-values (m n) (matrix-shape a))
                            (values (cons m ms) (cons n ns)))])
      ;; the fold accumulated in reverse block order
      (values (reverse ms) (reverse ns))))
  ;; total dimensions of the result
  (define res-m (assert (apply + ms) index?))
  (define res-n (assert (apply + ns) index?))
  ;; Per-row / per-column lookup tables for the result:
  ;;   vs[i] = number of the block that owns result row i
  ;;   is[i] = row index of result row i within that block
  ;;   hs[j] = number of the block that owns result column j
  ;;   js[j] = column index of result column j within that block
  (define vs ((inst make-vector Index) res-m 0))
  (define hs ((inst make-vector Index) res-n 0))
  (define is ((inst make-vector Index) res-m 0))
  (define js ((inst make-vector Index) res-n 0))
  ;; Walk the blocks in order. res-i / res-j are the result row / column where block k starts.
  ;; The final offsets are not needed.
  (define-values (_res-i _res-j)
    (for/fold: ([res-i : Nonnegative-Fixnum 0]
                [res-j : Nonnegative-Fixnum 0]
                ) ([m (in-list ms)]
                   [n (in-list ns)]
                   [k : Nonnegative-Fixnum (in-range num)])
      (let ([k (assert k index?)])
        (for: ([i : Nonnegative-Fixnum (in-range m)])
          (vector-set! vs (unsafe-fx+ res-i i) k)
          (vector-set! is (unsafe-fx+ res-i i) (assert i index?)))
        (for: ([j : Nonnegative-Fixnum (in-range n)])
          (vector-set! hs (unsafe-fx+ res-j j) k)
          (vector-set! js (unsafe-fx+ res-j j) (assert j index?))))
      (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
  ;; each block's element procedure, extracted once up front
  (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as))
  (array-default-strict
   (unsafe-build-array
    ((inst vector Index) res-m res-n)
    (λ: ([ij : Indexes])
      (define i (unsafe-vector-ref ij 0))
      (define j (unsafe-vector-ref ij 1))
      (define v (unsafe-vector-ref vs i))
      ;; inside a diagonal block only when row i and column j belong to the same block
      (cond [(fx= v (vector-ref hs j))
             (define proc (unsafe-vector-ref procs v))
             (define iv (unsafe-vector-ref is i))
             (define jv (unsafe-vector-ref js j))
             ;; Temporarily overwrite the incoming index vector with the block-local
             ;; indexes and call the block's procedure on it. The original (i, j) is then
             ;; written back before returning. Presumably this avoids allocating a fresh
             ;; index vector per entry.
             ;; NOTE(review): if `proc' raises, `ij' is left holding the block-local indexes.
             (unsafe-vector-set! ij 0 iv)
             (unsafe-vector-set! ij 1 jv)
             (define res (proc ij))
             (unsafe-vector-set! ij 0 i)
             (unsafe-vector-set! ij 1 j)
             res]
            [else
             zero])))))
|
|
|
|
;; block-diagonal-matrix/zero : (Listof (Matrix A)) A -> (Matrix A)
;; Returns the block-diagonal matrix whose diagonal blocks are the matrices in `as', in order.
;; Every entry outside the blocks is `zero'.
;; A single block is returned as-is; two or more are assembled by `block-diagonal-matrix/zero*'.
;; Raises an argument error if `as' is empty.
(: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A))))
(define (block-diagonal-matrix/zero as zero)
  (cond [(empty? as)
         ;; Test the list itself, not a converted vector, so the error message shows the
         ;; argument exactly as the caller passed it ('() rather than '#()).
         (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
        [(empty? (cdr as))
         ;; exactly one block: nothing to assemble
         (car as)]
        [else
         (block-diagonal-matrix/zero* (list->vector as) zero)]))
|
|
|
|
;; block-diagonal-matrix : (Listof (Matrix A)) [A] -> (Matrix ...)
;; Returns the block-diagonal matrix built from the matrices in `as'.
;; Entries outside the blocks are the optional second argument, or exact 0 when it is omitted.
;; Delegates to `block-diagonal-matrix/zero', which raises an argument error for an empty list.
(: block-diagonal-matrix (All (A) (case-> ((Listof (Matrix A)) -> (Matrix (U A 0)))
                                          ((Listof (Matrix A)) A -> (Matrix A)))))
(define block-diagonal-matrix
  (case-lambda
    [(blocks) (block-diagonal-matrix/zero blocks 0)]
    [(blocks off) (block-diagonal-matrix/zero blocks off)]))
|
|
|
|
;; ===================================================================================================
|
|
;; Special matrices
|
|
|
|
;; sane-expt : raises `base' to the natural-number power `k'. The result stays in the same
;; numeric class as `base'.
;; Plain `expt' returns exact 1 whenever the exponent is exact 0, even for an inexact base.
;; That would put exact values into an otherwise inexact matrix. This wrapper avoids it:
;;   Flonum base         -> `flexpt', so the result is always a flonum
;;   other Real base     -> `expt' (see the note on `real-part' below)
;;   Float-Complex base  -> `expt', then coerced back with `number->float-complex'
;;   any other Number    -> plain `expt'
(: sane-expt (case-> (Flonum Index -> Flonum)
                     (Real Index -> Real)
                     (Float-Complex Index -> Float-Complex)
                     (Number Index -> Number)))
(define (sane-expt base k)
  (cond
    [(flonum? base)
     (flexpt base (real->double-flonum k))]
    [(real? base)
     ;; `real-part' is the identity on reals; it only satisfies the type checker.
     ;; remove `real-part' when expt : Real Index -> Real
     (real-part (expt base k))]
    [(float-complex? base)
     (number->float-complex (expt base k))]
    [else
     (expt base k)]))
|
|
|
|
;; vandermonde-matrix : (Listof Number) Integer -> (Matrix Number)
;; Returns the m-by-n Vandermonde matrix for the m values in `xs'. The entry at row i,
;; column j is xs[i] raised to the power j, for j = 0 .. n-1.
;; Powers are computed with `sane-expt', so e.g. flonum inputs produce a flonum matrix
;; (no exact 1s in column 0).
;; Raises an argument error if `xs' is empty (position 0) or `n' is not a positive index
;; (position 1).
(: vandermonde-matrix (case-> ((Listof Flonum) Integer -> (Matrix Flonum))
                              ((Listof Real) Integer -> (Matrix Real))
                              ((Listof Float-Complex) Integer -> (Matrix Float-Complex))
                              ((Listof Number) Integer -> (Matrix Number))))
(define (vandermonde-matrix xs n)
  (cond
    [(empty? xs)
     (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)]
    [(not (and (index? n) (positive? n)))
     (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)]
    [else
     ;; Turn `xs' into a one-axis array, then add axis 1 of length n.
     ;; Each new entry is (sane-expt x j) for source element x and new index j.
     (define column (list->array xs))
     (array-axis-expand column 1 n sane-expt)]))
|