
The fix consists of three parts: 1. Rewriting `inline-matrix*`. The material change is that its expansion now contains only direct applications of `+` and `*`, which TR's optimizer replaces with `unsafe-fl+` and `unsafe-fl*`; this keeps intermediate flonum values from being boxed. 2. Making the types of all functions that operate on (Matrix Number) values more precise, so that TR can prove that matrix operations preserve inexactness. For example, `matrix-conjugate` now has the case (Matrix Flonum) -> (Matrix Flonum), plus three analogous cases for Real, Float-Complex, and Number. 3. Changing the return types of some functions that used to return types like (Matrix (U A 0)). Now that we preserve inexactness, `matrix-upper-triangle` can no longer always return a matrix containing exact zeros; it now accepts an optional `zero` argument of type A.
183 lines
7.2 KiB
Racket
183 lines
7.2 KiB
Racket
#lang typed/racket/base
|
|
|
|
(require racket/performance-hint
|
|
racket/string
|
|
racket/fixnum
|
|
math/base
|
|
"matrix-types.rkt"
|
|
"../unsafe.rkt"
|
|
"../array/array-struct.rkt"
|
|
"../vector/vector-mutate.rkt")
|
|
|
|
(provide (all-defined-out))
|
|
|
|
(: format-matrices/error ((Listof (Array Any)) -> String))
;; Renders every array in `as` with the `~e` format directive and joins the
;; rendered pieces with single spaces (string-join's default separator).
;; Used to build the "given ..." part of shape-mismatch error messages.
(define (format-matrices/error as)
  (: render ((Array Any) -> String))
  (define (render a) (format "~e" a))
  (string-join (map render as)))
|
|
|
|
(: matrix-shapes (Symbol (Matrix Any) (Matrix Any) * -> (Values Index Index)))
;; Returns the two dimensions of `arr` (as reported by matrix-shape) after
;; checking that every matrix in `brrs` has exactly the same dimensions.
;; On a mismatch, raises an error attributed to `name` that lists all of the
;; matrices involved.
(define (matrix-shapes name arr . brrs)
  (define-values (rows cols) (matrix-shape arr))
  (: same-shape? ((Matrix Any) -> Boolean))
  (define (same-shape? other)
    (define-values (other-rows other-cols) (matrix-shape other))
    (and (= other-rows rows) (= other-cols cols)))
  (unless (andmap same-shape? brrs)
    (error name
           "matrices must have the same shape; given ~a"
           (format-matrices/error (cons arr brrs))))
  (values rows cols))
|
|
|
|
(: matrix-multiply-shape ((Matrix Any) (Matrix Any) -> (Values Index Index Index)))
;; Checks that `arr` and `brr` can be multiplied (the column size of `arr`
;; equals the row size of `brr`) and returns three values: the first
;; dimension of `arr`, the shared inner dimension, and the second dimension
;; of `brr`. A mismatch is reported under the name 'matrix-multiply.
(define (matrix-multiply-shape arr brr)
  (define-values (a-rows a-cols) (matrix-shape arr))
  (define-values (b-rows b-cols) (matrix-shape brr))
  (when (not (= a-cols b-rows))
    (error 'matrix-multiply
           "1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
           arr brr))
  (values a-rows a-cols b-cols))
|
|
|
|
(: ensure-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `matrix?`; otherwise raises an
;; argument error on behalf of `name` with expected contract "matrix?".
(define (ensure-matrix name a)
  (unless (matrix? a)
    (raise-argument-error name "matrix?" a))
  a)
|
|
|
|
(: ensure-row-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `row-matrix?`; otherwise raises an
;; argument error on behalf of `name` with expected contract "row-matrix?".
(define (ensure-row-matrix name a)
  (unless (row-matrix? a)
    (raise-argument-error name "row-matrix?" a))
  a)
|
|
|
|
(: ensure-col-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `col-matrix?`; otherwise raises an
;; argument error on behalf of `name` with expected contract "col-matrix?".
(define (ensure-col-matrix name a)
  (unless (col-matrix? a)
    (raise-argument-error name "col-matrix?" a))
  a)
|
|
|
|
(: sort/key (All (A B) (case-> ((Listof A) (B B -> Boolean) (A -> B) -> (Listof A))
                               ((Listof A) (B B -> Boolean) (A -> B) Boolean -> (Listof A)))))
;; Positional wrapper around `sort`: orders `xs` by comparing the keys that
;; `key-of` extracts, using `less-than?`. The optional fourth argument is
;; forwarded as #:cache-keys? (default #f).
;; Sometimes necessary because TR can't do inference with keyword arguments
;; yet, so `sort` is instantiated explicitly here once.
(define (sort/key xs less-than? key-of [cache? #f])
  ((inst sort A B) xs less-than? #:key key-of #:cache-keys? cache?))
|
|
|
|
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
;; Returns element `j` of row `i` in a vector of vectors. Both lookups use
;; unsafe-vector-ref (presumably unchecked), so callers must pass in-range
;; indices.
(define (unsafe-vector2d-ref vss i j)
  (define row (unsafe-vector-ref vss i))
  (unsafe-vector-ref row j))
|
|
|
|
;; Returns #t unless `x` is strictly negative. Because only negative values
;; are rejected, this accepts +nan.0 (and -0.0).
(define nonnegative?
  (λ: ([x : Real]) (not (negative? x))))
|
|
|
|
;; Returns #t when `num` is finite and not NaN. A real number is tested with
;; `rational?` directly; a non-real number needs both its real and imaginary
;; parts to be rational.
(define number-rational?
  (λ: ([num : Number])
    (if (real? num)
        (rational? num)
        (and (rational? (real-part num))
             (rational? (imag-part num))))))
|
|
|
|
(: find-partial-pivot
   (case-> ((Vectorof (Vectorof Flonum)) Index Index Index -> (Values Index Flonum))
           ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
           ((Vectorof (Vectorof Float-Complex)) Index Index Index -> (Values Index Float-Complex))
           ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Partial pivot search: among rows i .. m-1 of column j, find the element
;; with the largest magnitude and return its row index together with the
;; element. A later row replaces the current best only when its magnitude is
;; strictly greater, so ties keep the earliest row.
(define (find-partial-pivot rows m i j)
  (define first-entry (unsafe-vector2d-ref rows i j))
  (let search ([#{r : Nonnegative-Fixnum} (fx+ i 1)]
               [#{best-row : Index} i]
               [best first-entry]
               [best-mag (magnitude first-entry)])
    (if (r . fx< . m)
        (let* ([entry (unsafe-vector2d-ref rows r j)]
               [entry-mag (magnitude entry)])
          (if (entry-mag . > . best-mag)
              (search (fx+ r 1) r entry entry-mag)
              (search (fx+ r 1) best-row best best-mag)))
        (values best-row best))))
|
|
|
|
(: find-first-pivot
   (case-> ((Vectorof (Vectorof Flonum)) Index Index Index -> (Values Index Flonum))
           ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
           ((Vectorof (Vectorof Float-Complex)) Index Index Index -> (Values Index Float-Complex))
           ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Scans column j from row i down to row m-1 and returns the first row whose
;; element has magnitude > 0, together with that element. If no such row
;; exists, returns row i and its (zero-magnitude or NaN) element.
(define (find-first-pivot rows m i j)
  (define entry-i (unsafe-vector2d-ref rows i j))
  (cond
    [((magnitude entry-i) . > . 0) (values i entry-i)]
    [else
     (let scan ([#{r : Nonnegative-Fixnum} (fx+ i 1)])
       (if (r . fx< . m)
           (let ([entry (unsafe-vector2d-ref rows r j)])
             (if ((magnitude entry) . > . 0)
                 (values r entry)
                 (scan (fx+ r 1))))
           (values i entry-i)))]))
|
|
|
|
(: elim-rows!
   (case-> ((Vectorof (Vectorof Flonum)) Index Index Index Flonum Nonnegative-Fixnum -> Void)
           ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void)
           ((Vectorof (Vectorof Float-Complex)) Index Index Index Float-Complex Nonnegative-Fixnum
                                                -> Void)
           ((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void)))
;; Eliminates column j from every row r in [start, m) except the pivot row i.
;; For each such row whose column-j element x is nonzero, it:
;;  - calls vector-scaled-add! with the pivot row, the factor -(x / pivot),
;;    and start index j (presumably adding the scaled pivot row to row r from
;;    column j onward; see vector-mutate.rkt);
;;  - overwrites that row's column-j element with (- x x).
;; The rows of `rows` are mutated in place.
(define (elim-rows! rows m i j pivot start)
  (define pivot-row (unsafe-vector-ref rows i))
  (let sweep ([#{r : Nonnegative-Fixnum} start])
    (when (r . fx< . m)
      (when (not (r . fx= . i))
        (define target (unsafe-vector-ref rows r))
        (define x (unsafe-vector-ref target j))
        (unless (= x 0)
          (vector-scaled-add! target pivot-row (* -1 (/ x pivot)) j)
          ;; Make sure the element below the pivot is zero: (- x x) gives a
          ;; zero with x's own exactness (for finite x) instead of relying on
          ;; the scaled add to cancel exactly.
          (unsafe-vector-set! target j (- x x))))
      (sweep (fx+ r 1)))))
|
|
|
|
(begin-encourage-inline

  (: call/ns (All (A) ((-> (Matrix A)) -> (Matrix A))))
  ;; Calls `produce` with the `array-strictness` parameter set to #f, then
  ;; passes the matrix it returns through array-default-strict and returns
  ;; that result.
  (define (call/ns produce)
    (define result
      (parameterize ([array-strictness #f])
        (produce)))
    (array-default-strict result))

  ) ; begin-encourage-inline
|
|
|
|
(: make-thread-local-box (All (A) (A -> (-> (Boxof A)))))
;; Returns a thunk that yields a per-thread box. The first time the thunk is
;; called in a given thread, it creates (box contents), stores it in a thread
;; cell, and returns it. Later calls in the same thread return that same box.
;; The cell is created with #f as its initial value, so every thread starts
;; without a box.
(define (make-thread-local-box contents)
  (: cell (Thread-Cellof (U #f (Boxof A))))
  (define cell (make-thread-cell #f))
  (λ ()
    (let ([existing (thread-cell-ref cell)])
      (if existing
          existing
          (let: ([fresh : (Boxof A) (box contents)])
            (thread-cell-set! cell fresh)
            fresh)))))
|
|
|
|
(: one (case-> (Flonum -> Nonnegative-Flonum)
               (Real -> (U 1 Nonnegative-Flonum))
               (Float-Complex -> Nonnegative-Flonum)
               (Number -> (U 1 Nonnegative-Flonum))))
;; Returns a real multiplicative identity matching the argument's exactness:
;; 1.0 for flonums and float-complex numbers, exact 1 for everything else.
(define (one x)
  (cond [(flonum? x) 1.0]
        [(float-complex? x) 1.0]
        [else 1]))
|
|
|
|
(: zero (case-> (Flonum -> Flonum-Positive-Zero)
                (Real -> (U 0 Flonum-Positive-Zero))
                (Float-Complex -> Flonum-Positive-Zero)
                (Number -> (U 0 Flonum-Positive-Zero))))
;; Returns a real additive identity matching the argument's exactness:
;; 0.0 for flonums and float-complex numbers, exact 0 for everything else.
(define (zero x)
  (cond [(flonum? x) 0.0]
        [(float-complex? x) 0.0]
        [else 0]))
|
|
|
|
(: one* (case-> (Flonum -> Nonnegative-Flonum)
                (Real -> (U 1 Nonnegative-Flonum))
                (Float-Complex -> Float-Complex)
                (Number -> (U 1 Nonnegative-Flonum Float-Complex))))
;; Multiplicative identity in the argument's own representation: 1.0 for
;; flonums, 1.0+0.0i for float-complex numbers, exact 1 for everything else.
(define (one* x)
  (cond [(flonum? x) 1.0]
        [(float-complex? x) 1.0+0.0i]
        [else 1]))
|
|
|
|
(: zero* (case-> (Flonum -> Flonum-Positive-Zero)
                 (Real -> (U 0 Flonum-Positive-Zero))
                 (Float-Complex -> Float-Complex)
                 (Number -> (U 0 Flonum-Positive-Zero Float-Complex))))
;; Additive identity in the argument's own representation: 0.0 for flonums,
;; 0.0+0.0i for float-complex numbers, exact 0 for everything else.
(define (zero* x)
  (cond [(flonum? x) 0.0]
        [(float-complex? x) 0.0+0.0i]
        [else 0]))
|