racket/collects/math/private/matrix/matrix-lu.rkt
Neil Toronto f42cc6f14a Fixed major performance issue with matrix arithmetic; please merge to 5.3.2
The fix consists of three parts:

1. Rewriting `inline-matrix*'. The material change here is that the
   expansion now contains only direct applications of `+' and `*'.
   TR's optimizer replaces them with `unsafe-fx+' and `unsafe-fx*',
   which keeps intermediate flonum values from being boxed.

2. Making the types of all functions that operate on (Matrix Number)
   values more precise. Now TR can prove that matrix operations preserve
   inexactness. For example, matrix-conjugate : (Matrix Flonum) ->
   (Matrix Flonum) and three other cases for Real, Float-Complex, and
   Number.

3. Changing the return types of some functions that used to return
   things like (Matrix (U A 0)). Now that we worry about preserving
   inexactness, we can't have `matrix-upper-triangle' always return a
   matrix that contains exact zeros. It now accepts an optional `zero'
   argument of type A.
2013-01-21 22:04:04 -07:00

71 lines
3.1 KiB
Racket

#lang typed/racket/base
(require racket/fixnum
"matrix-types.rkt"
"matrix-conversion.rkt"
"matrix-arithmetic.rkt"
"utils.rkt"
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"../array/mutable-array.rkt"
"../array/array-struct.rkt"
"../array/array-pointwise.rkt")
(provide matrix-lu)
;; An LU factorization exists iff Gaussian elimination can be done without row swaps.
(: matrix-lu
(All (A) (case-> ((Matrix Flonum) -> (Values (Matrix Flonum) (Matrix Flonum)))
((Matrix Flonum) (-> A) -> (Values (U A (Matrix Flonum)) (Matrix Flonum)))
((Matrix Real) -> (Values (Matrix Real) (Matrix Real)))
((Matrix Real) (-> A) -> (Values (U A (Matrix Real)) (Matrix Real)))
((Matrix Float-Complex) -> (Values (Matrix Float-Complex)
(Matrix Float-Complex)))
((Matrix Float-Complex) (-> A) -> (Values (U A (Matrix Float-Complex))
(Matrix Float-Complex)))
((Matrix Number) -> (Values (Matrix Number) (Matrix Number)))
((Matrix Number) (-> A) -> (Values (U A (Matrix Number)) (Matrix Number))))))
(define matrix-lu
(case-lambda
[(M) (matrix-lu M (λ () (raise-argument-error 'matrix-lu "LU-decomposable matrix" M)))]
[(M fail)
(define m (square-matrix-size M))
(define rows (matrix->vector* M))
(define L
(parameterize ([array-strictness #f])
;; Construct L in a weird way to prove to TR that it has the right type
(array->mutable-array (inline-array-map zero* M))))
;; Going to fill in the lower triangle by banging values into `ys'
(define ys (mutable-array-data L))
(let loop ([#{i : Nonnegative-Fixnum} 0])
(cond
[(i . fx< . m)
;; Pivot must be on the diagonal
(define pivot (unsafe-vector2d-ref rows i i))
(cond
[(zero? pivot) (values (fail) M)]
[else
;; Zero out everything below the pivot
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond
[(l . fx< . m)
(define x_li (unsafe-vector2d-ref rows l i))
(define y_li (/ x_li pivot))
(unless (zero? x_li)
;; Fill in lower triangle of L
(unsafe-vector-set! ys (+ (* l m) i) y_li)
;; Add row i, scaled
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(* -1 y_li)))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1))]))])]
[else
;; L's lower triangle has been filled; now fill the diagonal with 1s
(for: ([i : Integer (in-range 0 m)])
(define j (+ (* i m) i))
(vector-set! ys j (one* (vector-ref ys j))))
(values L (vector*->matrix rows))]))]))