More efficient Gaussian elimination using vectors of vectors (non-strict

arrays can't help an inherently sequential algorithm)
This commit is contained in:
Neil Toronto 2012-12-21 13:47:08 -07:00
parent 099a35881e
commit 3bc4c1ffdc
3 changed files with 143 additions and 83 deletions

View File

@ -1,9 +1,11 @@
#lang typed/racket/base
(require racket/list
(require racket/fixnum
racket/list
math/array
(only-in typed/racket conjugate)
"../unsafe.rkt"
"../vector/vector-mutate.rkt"
"matrix-types.rkt"
"matrix-constructors.rkt"
"matrix-conversion.rkt"
@ -198,71 +200,75 @@
;;; GAUSS ELIMINATION / ROW ECHELON FORM
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
(define (unsafe-vector2d-ref vss i j)
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
(: find-partial-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
;; Find the element with maximum magnitude in a column
(define (find-partial-pivot rows m i j)
(let loop ([#{l : Nonnegative-Fixnum} i] [#{max-current : Real} -inf.0] [#{max-index : Index} i])
(cond [(l . fx< . m)
(define v (magnitude (unsafe-vector2d-ref rows l j)))
(cond [(> v max-current) (loop (fx+ l 1) v l)]
[else (loop (fx+ l 1) max-current max-index)])]
[else max-index])))
(: find-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
;; Find a non-zero element in a column
(define (find-pivot rows m i j)
(let loop ([#{l : Nonnegative-Fixnum} i])
(cond [(l . fx>= . m) #f]
[(not (zero? (unsafe-vector2d-ref rows l j))) l]
[else (loop (fx+ l 1))])))
(: matrix-gauss-eliminate :
(case-> ((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
(define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
(case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Boolean -> (Values (Matrix Real) (Listof Index)))
((Matrix Real) Boolean Boolean -> (Values (Matrix Real) (Listof Index)))
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Index)))
((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Index)))))
;; Returns the result of Gaussian elimination and a list of column indexes that had no pivot value
;; If `reduced?' is #t, the result is in *reduced* row-echelon form, and is unique (up to
;; floating-point error)
;; If `partial-pivoting?' is #t, the largest value in each column is used as the pivot
(define (matrix-gauss-eliminate M [reduced? #f] [partial-pivoting? #t])
(define-values (m n) (matrix-shape M))
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
-> (Values (Matrix Number) (Listof Integer))))
(define (loop i j ; i from 0 to m
M
k ; count rows without pivot
without-pivot)
(define rows (matrix->vector* M))
(let loop ([#{i : Nonnegative-Fixnum} 0]
[#{j : Nonnegative-Fixnum} 0]
[#{without-pivot : (Listof Index)} '()])
(cond
[(or (= i m) (= j n)) (values M without-pivot)]
[else
; find row to become pivot
(define p
(if partial-pivoting?
; find element with maximal absolute value
(let: max-loop : (U False Integer)
([l : Integer i] ; i<=l<m
[max-current : Real -inf.0]
[max-index : Integer i])
(cond
[(= l m) max-index]
[else
(let ([v (magnitude (matrix-ref M l j))])
(if (> (magnitude (matrix-ref M l j)) max-current)
(max-loop (+ l 1) v l)
(max-loop (+ l 1) max-current max-index)))]))
; find non-zero element in column
(let: first-loop : (U False Integer)
([l : Integer i]) ; i<=l<m
(cond
[(= l m) #f]
[(not (zero? (matrix-ref M l j))) l]
[else (first-loop (+ l 1))]))))
[(and (i . fx< . m) (j . fx< . n))
;; Find the row with the pivot value
(define p (cond [partial-pivoting? (find-partial-pivot rows m i j)]
[else (find-pivot rows m i j)]))
(define pivot (if p (unsafe-vector2d-ref rows p j) 0))
(cond
[(or (eq? p #f)
(zero? (matrix-ref M p j)))
; no pivot found
(loop i (+ j 1) M (+ k 1) (cons j without-pivot))]
[(or (not p) (zero? pivot)) ; didn't find pivot?
(loop i (fx+ j 1) (cons j without-pivot))]
[else
; swap if neccessary
(let* ([M (if (= i p) M (matrix-swap-rows M i p))]
; now we now (i,j) is a pivot
[M ; maybe scale row
(if unitize-pivot-row?
(let ([pivot (matrix-ref M i j)])
(if (zero? pivot)
M
(matrix-scale-row M i (/ pivot))))
M)])
(let ([pivot (matrix-ref M i j)])
; remove elements below pivot
(let l-loop ([l (+ i 1)] [M M])
(if (= l m)
(loop (+ i 1) (+ j 1) M k without-pivot)
(let ([x_lj (matrix-ref M l j)])
(l-loop (+ l 1)
(if (zero? x_lj)
M
(matrix-add-scaled-row M l (- (/ x_lj pivot)) i))))))))])]))
(let-values ([(M without) (loop 0 0 M 0 '())])
(values M without)))
(vector-swap! rows i p) ; swap pivot row with current
(let ([pivot (cond [reduced? (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
(/ pivot pivot)]
[else pivot])])
;; Remove elements below pivot by scaling and adding the pivot's row to each row below
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
(cond [(l . fx< . m)
(define x_lj (unsafe-vector2d-ref rows l j))
(unless (zero? x_lj)
(vector-scaled-add! (unsafe-vector-ref rows l)
(unsafe-vector-ref rows i)
(- (/ x_lj pivot))))
(l-loop (fx+ l 1))]
[else
(loop (fx+ i 1) (fx+ j 1) without-pivot)])))])]
[else
(values (vector*->matrix rows)
(reverse without-pivot))])))
(: matrix-rank : (Matrix Number) -> Integer)
(define (matrix-rank M)

View File

@ -0,0 +1,47 @@
#lang typed/racket/base
(require racket/fixnum
math/private/unsafe)
(provide vector-swap!
vector-scale!
vector-scaled-add!)
(: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void)))
(define (vector-swap! vs i0 i1)
(unless (= i0 i1)
(define tmp (unsafe-vector-ref vs i0))
(unsafe-vector-set! vs i0 (unsafe-vector-ref vs i1))
(unsafe-vector-set! vs i1 tmp)))
(define-syntax-rule (vector-generic-scale! vs-expr v-expr *)
(let* ([vs vs-expr]
[v v-expr]
[n (vector-length vs)])
(let loop ([#{i : Nonnegative-Fixnum} 0])
(if (i . fx< . n)
(begin (unsafe-vector-set! vs i (* v (unsafe-vector-ref vs i)))
(loop (fx+ i 1)))
(void)))))
(: vector-scale! (case-> ((Vectorof Real) Real -> Void)
((Vectorof Number) Number -> Void)))
(define (vector-scale! vs v)
(vector-generic-scale! vs v *))
(define-syntax-rule (vector-generic-scaled-add! vs0-expr vs1-expr v-expr + *)
(let* ([vs0 vs0-expr]
[vs1 vs1-expr]
[v v-expr]
[n (min (vector-length vs0) (vector-length vs1))])
(let loop ([#{i : Nonnegative-Fixnum} 0])
(if (i . fx< . n)
(begin (unsafe-vector-set! vs0 i (+ (unsafe-vector-ref vs0 i)
(* (unsafe-vector-ref vs1 i) v)))
(loop (fx+ i 1)))
(void)))))
(: vector-scaled-add! (case-> ((Vectorof Real) (Vectorof Real) Real -> Void)
((Vectorof Number) (Vectorof Number) Number -> Void)))
(define (vector-scaled-add! v0 v1 s)
(vector-generic-scaled-add! v0 v1 s + *))

View File

@ -20,7 +20,7 @@
(make-array #(0 1) 0)
(make-array #(0 0) 0)
(make-array #(1 1 1) 0)))
#|
;; ===================================================================================================
;; Literal syntax
@ -666,7 +666,34 @@
(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3]))))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
|#
;; ===================================================================================================
;; Gaussian elimination
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
(define (gauss-eliminate M reduce? partial-pivot?)
(let-values ([(M wp) (matrix-gauss-eliminate M reduce? partial-pivot?)])
M))
(check-equal? (gauss-eliminate (matrix [[1 2] [3 4]]) #f #f)
(matrix [[1 2] [0 -2]]))
(check-equal? (gauss-eliminate (matrix [[2 4] [3 4]]) #t #f)
(matrix [[1 2] [0 1]]))
(check-equal? (gauss-eliminate (matrix [[2. 4.] [3. 4.]]) #t #t)
(matrix [[1. 1.3333333333333333] [0. 1.]]))
(check-equal? (gauss-eliminate (matrix [[1 4] [2 4]]) #t #t)
(matrix [[1 2] [0 1]]))
(check-equal? (gauss-eliminate (matrix [[1 2] [2 4]]) #f #t)
(matrix [[2 4] [0 0]]))
(for: ([a (in-list nonmatrices)])
(check-exn exn:fail:contract? (λ () (gauss-eliminate a #f #f))))
#|
;; ===================================================================================================
;; Tests not yet converted to rackunit
@ -765,27 +792,6 @@
[9 10 -11 12]
[13 14 15 16]]))
5280))
(let ()
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
(define (gauss-eliminate M u? p?)
(let-values ([(M wp) (matrix-gauss-eliminate M u? p?)])
M))
(list 'matrix-gauss-eliminate
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])])
(gauss-eliminate M #f #f))
(list*->matrix '[[1 2] [0 -2]]))
(equal? (let ([M (list*->matrix '[[2 4] [3 4]])])
(gauss-eliminate M #t #f))
(list*->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])])
(gauss-eliminate M #t #t))
(list*->matrix '[[1. 1.3333333333333333] [0. 1.]]))
(equal? (let ([M (list*->matrix '[[1 4] [2 4]])])
(gauss-eliminate M #t #t))
(list*->matrix '[[1 2] [0 1]]))
(equal? (let ([M (list*->matrix '[[1 2] [2 4]])])
(gauss-eliminate M #f #t))
(list*->matrix '[[2 4] [0 0]]))))
(list
'matrix-scale-row
(equal? (matrix-scale-row (identity-matrix 3) 0 2)
@ -893,3 +899,4 @@
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O)
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O)
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2))))))
|#