racket/collects/math/private/matrix/matrix-lu.rkt
Neil Toronto a1aa97c1fd `math/matrix' fixes; please merge to 5.3.2.
* Fixed type of `matrix-expt'

* Made matrix functions respect `array-strictness' parameter (mostly
  wrapping functions with `parameterize' and return values with
  `array-default-strictness'; reindentation makes changes look larger)

* Added strictness tests
(cherry picked from commit f40ad2ca9d)
2013-01-18 15:55:34 -05:00

63 lines
2.5 KiB
Racket

#lang typed/racket/base
(require racket/fixnum
"matrix-types.rkt"
"matrix-conversion.rkt"
"matrix-arithmetic.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"../array/mutable-array.rkt"
"../array/array-struct.rkt")
(provide matrix-lu)
;; Computes an LU factorization of the square matrix M by Gaussian elimination
;; *without* row swaps: L is unit lower triangular (1s on the diagonal, the
;; elimination multipliers below it) and U is the upper-triangular matrix left
;; after elimination, so that M = L*U.
;;
;; (matrix-lu M)       -> (values L U); raises an argument error
;;                        ("LU-decomposable matrix") if a zero pivot is found
;; (matrix-lu M fail)  -> (values L U) on success; if a zero pivot is found on
;;                        the diagonal, returns (values (fail) M), i.e. the
;;                        result of the `fail' thunk and the original M
;;
;; Because rows are never swapped, this fails whenever a diagonal pivot becomes
;; exactly zero during elimination, even for some invertible matrices (e.g. any
;; matrix with 0 in its top-left corner).
;; NOTE(review): non-square M is presumably rejected by `square-matrix-size' — confirm.
(: matrix-lu
   (All (A) (case-> ((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
                    ((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
                    ((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
                    ((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
(define matrix-lu
  (case-lambda
    [(M)  (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
    [(M fail)
     (define m (square-matrix-size M))
     ;; Rows of M as vectors; eliminated in place to become U.
     ;; NOTE(review): assumes `matrix->vector*' returns a fresh copy, since M
     ;; itself is returned unchanged on failure — confirm.
     (define rows (matrix->vector* M))
     (define L
       (parameterize ([array-strictness #f])
         ;; Construct L in a weird way to prove to TR that it has the right type:
         ;; an all-zero mutable matrix with the same shape and element type as M
         (array->mutable-array (matrix-scale M (ann 0 Real)))))
     ;; Going to fill in the lower triangle by banging values into `ys'
     ;; (entry (r, c) of L lives at index r*m + c)
     (define ys (mutable-array-data L))
     (let loop ([#{i : Nonnegative-Fixnum} 0])
       (cond
         [(i . fx< . m)
          ;; Pivot must be on the diagonal
          (define pivot (unsafe-vector2d-ref rows i i))
          (cond
            [(zero? pivot)  (values (fail) M)]
            [else
             ;; Zero out everything below the pivot
             (let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
               (cond
                 [(l . fx< . m)
                  (define x_li (unsafe-vector2d-ref rows l i))
                  (define y_li (/ x_li pivot))
                  (unless (zero? x_li)
                    ;; Fill in lower triangle of L with the multiplier
                    (unsafe-vector-set! ys (+ (* l m) i) y_li)
                    ;; Add row i, scaled, to eliminate entry (l, i)
                    (vector-scaled-add! (unsafe-vector-ref rows l)
                                        (unsafe-vector-ref rows i)
                                        (- y_li))
                    ;; With floating-point entries, x_li - pivot*(x_li/pivot) need
                    ;; not round to exactly zero (e.g. 1.0 over pivot 49.0 leaves
                    ;; ~1.1e-16), so store the eliminated entry explicitly to keep
                    ;; U upper triangular. (- x_li x_li) is a zero with the same
                    ;; exactness as x_li.
                    (unsafe-vector-set! (unsafe-vector-ref rows l) i (- x_li x_li)))
                  (l-loop (fx+ l 1))]
                 [else
                  (loop (fx+ i 1))]))])]
         [else
          ;; L's lower triangle has been filled; now fill the diagonal with 1s
          (for: ([i : Integer  (in-range 0 m)])
            (vector-set! ys (+ (* i m) i) 1))
          (values L (vector*->matrix rows))]))]))