racket/collects/math/private/matrix/utils.rkt
Neil Toronto 5981535686 Made Gaussian elimination faster (half the scaled additions) and more
accurate (always produces zeros in the lower half)
2013-01-01 18:19:43 -07:00

116 lines
4.7 KiB
Racket

#lang typed/racket/base
(require racket/string
racket/fixnum
"matrix-types.rkt"
"../unsafe.rkt"
"../array/array-struct.rkt"
"../vector/vector-mutate.rkt")
(provide (all-defined-out))
(: format-matrices/error ((Listof (Array Any)) -> String))
;; Renders every array in `as` with the "~e" format directive and joins the
;; resulting strings with a single space, for use inside error messages.
(define (format-matrices/error as)
  (string-join
   (for/list : (Listof String) ([a (in-list as)])
     (format "~e" a))
   " "))
(: matrix-shapes (Symbol (Matrix Any) (Matrix Any) * -> (Values Index Index)))
;; Returns the two values that `matrix-shape` reports for `arr`, after checking
;; that every matrix in `brrs` reports the same pair.  On a mismatch, raises an
;; error attributed to `name` that lists all of the given matrices.
(define (matrix-shapes name arr . brrs)
  (define-values (m n) (matrix-shape arr))
  ;; Does `brr` have exactly the shape of the first matrix?
  (: same-shape? ((Matrix Any) -> Boolean))
  (define (same-shape? brr)
    (define-values (bm bn) (matrix-shape brr))
    (and (= bm m) (= bn n)))
  (unless (andmap same-shape? brrs)
    (error name
           "matrices must have the same shape; given ~a"
           (format-matrices/error (cons arr brrs))))
  (values m n))
(: matrix-multiply-shape ((Matrix Any) (Matrix Any) -> (Values Index Index Index)))
;; Checks that `arr` and `brr` can be multiplied: the column size of `arr` must
;; equal the row size of `brr`.  On success returns three values: the row size
;; of `arr`, the shared inner size, and the column size of `brr`.  Otherwise
;; raises an error attributed to 'matrix-multiply.
(define (matrix-multiply-shape arr brr)
  (define-values (a-rows a-cols) (matrix-shape arr))
  (define-values (b-rows b-cols) (matrix-shape brr))
  (cond
    [(= a-cols b-rows) (values a-rows a-cols b-cols)]
    [else
     (error 'matrix-multiply
            "1st argument column size and 2nd argument row size are not equal; given ~e and ~e"
            arr brr)]))
(: ensure-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `matrix?`; otherwise raises an
;; argument error reported on behalf of `name`.
(define (ensure-matrix name a)
  (unless (matrix? a)
    (raise-argument-error name "matrix?" a))
  a)
(: ensure-row-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `row-matrix?`; otherwise raises an
;; argument error reported on behalf of `name`.
(define (ensure-row-matrix name a)
  (unless (row-matrix? a)
    (raise-argument-error name "row-matrix?" a))
  a)
(: ensure-col-matrix (All (A) Symbol (Array A) -> (Array A)))
;; Returns `a` unchanged when it satisfies `col-matrix?`; otherwise raises an
;; argument error reported on behalf of `name`.
(define (ensure-col-matrix name a)
  (unless (col-matrix? a)
    (raise-argument-error name "col-matrix?" a))
  a)
;; Positional-argument front end for `sort`: sorts `lst` with the ordering
;; `lt?` applied to the keys extracted by `key`.  The optional fourth argument
;; is forwarded as `#:cache-keys?` and defaults to #f.
(: sort/key (All (A B) (case-> ((Listof A) (B B -> Boolean) (A -> B) -> (Listof A))
((Listof A) (B B -> Boolean) (A -> B) Boolean -> (Listof A)))))
;; Sometimes necessary because Typed Racket cannot yet infer type arguments
;; for applications that use keyword arguments; `sort` is therefore
;; instantiated explicitly at A and B before the keyword call.
(define (sort/key lst lt? key [cache-keys? #f])
((inst sort A B) lst lt? #:key key #:cache-keys? cache-keys?))
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
;; Fetches element `j` of the `i`th inner vector of `vss`, using
;; `unsafe-vector-ref` for both lookups.
(define (unsafe-vector2d-ref vss i j)
  (define inner (unsafe-vector-ref vss i))
  (unsafe-vector-ref inner j))
;; True when `x` is not less than zero.
;; Note: because the test is "not (x < 0)", +nan.0 is accepted as well.
(define (nonnegative? [x : Real])
  (not (< x 0)))
;; True when `z` is rational: a real number must itself satisfy `rational?`;
;; any other number must have both a rational real part and a rational
;; imaginary part.
(define (number-rational? [z : Number])
  (if (real? z)
      (rational? z)
      (and (rational? (real-part z))
           (rational? (imag-part z)))))
(: find-partial-pivot
   (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
           ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Find the element with maximum magnitude in a column.
;; Examines column `j` of `rows`, starting at row `i` and continuing through
;; every row index below `m`.  Returns two values: the row index of the chosen
;; element and the element itself.  A later row replaces the current choice
;; only when its magnitude is strictly greater, so ties keep the earliest row.
(define (find-partial-pivot rows m i j)
  (define first-candidate (unsafe-vector2d-ref rows i j))
  (let scan ([#{r : Nonnegative-Fixnum} (fx+ i 1)]
             [#{best-row : Index} i]
             [best first-candidate]
             [best-mag (magnitude first-candidate)])
    (if (r . fx< . m)
        (let* ([candidate (unsafe-vector2d-ref rows r j)]
               [candidate-mag (magnitude candidate)])
          (if (candidate-mag . > . best-mag)
              (scan (fx+ r 1) r candidate candidate-mag)
              (scan (fx+ r 1) best-row best best-mag)))
        (values best-row best))))
(: find-first-pivot
   (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (Values Index Real))
           ((Vectorof (Vectorof Number)) Index Index Index -> (Values Index Number))))
;; Find the first nonzero element in a column.
;; Looks down column `j` of `rows`, beginning at row `i`, for the first element
;; whose magnitude is greater than zero, and returns its row index together
;; with the element.  Rows after `i` are examined only while their index is
;; below `m`.  When no such element exists, the result is `i` and the element
;; stored at row `i`.
(define (find-first-pivot rows m i j)
  (define x_ij (unsafe-vector2d-ref rows i j))
  (cond
    [((magnitude x_ij) . > . 0) (values i x_ij)]
    [else
     (let search ([#{r : Nonnegative-Fixnum} (fx+ i 1)])
       (if (r . fx< . m)
           (let ([x_rj (unsafe-vector2d-ref rows r j)])
             (cond [((magnitude x_rj) . > . 0) (values r x_rj)]
                   [else (search (fx+ r 1))]))
           ;; Ran off the end of the column: fall back to row i.
           (values i x_ij)))]))
(: elim-rows!
   (case-> ((Vectorof (Vectorof Real)) Index Index Index Real Nonnegative-Fixnum -> Void)
           ((Vectorof (Vectorof Number)) Index Index Index Number Nonnegative-Fixnum -> Void)))
;; Eliminates column `j` from every row with index in [start, m) other than
;; row `i`, mutating `rows` in place.  For each such row whose entry in column
;; `j` is nonzero, `vector-scaled-add!` is called with that row, row `i`, the
;; scale -(entry / pivot) and the index `j`; the entry is then overwritten
;; with zero.
;; NOTE(review): presumably `vector-scaled-add!` adds scale * row_i into the
;; target row starting at index j -- confirm in vector-mutate.rkt.
(define (elim-rows! rows m i j pivot start)
  (define pivot-row (unsafe-vector-ref rows i))
  (let next-row ([#{r : Nonnegative-Fixnum} start])
    (cond
      [(r . fx< . m)
       (unless (r . fx= . i)
         (let* ([target-row (unsafe-vector-ref rows r)]
                [x_rj (unsafe-vector-ref target-row j)])
           (unless (zero? x_rj)
             (vector-scaled-add! target-row pivot-row (- (/ x_rj pivot)) j)
             ;; Force the entry in the pivot column to zero; it is computed as
             ;; x - x rather than written as a literal (presumably so the zero
             ;; matches the entry's numeric type).
             (unsafe-vector-set! target-row j (- x_rj x_rj)))))
       (next-row (fx+ r 1))]
      [else (void)])))