More efficient Gaussian elimination using vectors of vectors (non-strict
arrays can't help an inherently sequential algorithm)
This commit is contained in:
parent
099a35881e
commit
3bc4c1ffdc
|
@ -1,9 +1,11 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require racket/list
|
||||
(require racket/fixnum
|
||||
racket/list
|
||||
math/array
|
||||
(only-in typed/racket conjugate)
|
||||
"../unsafe.rkt"
|
||||
"../vector/vector-mutate.rkt"
|
||||
"matrix-types.rkt"
|
||||
"matrix-constructors.rkt"
|
||||
"matrix-conversion.rkt"
|
||||
|
@ -198,71 +200,75 @@
|
|||
|
||||
;;; GAUSS ELIMINATION / ROW ECHELON FORM
|
||||
|
||||
(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A)))
|
||||
(define (unsafe-vector2d-ref vss i j)
|
||||
(unsafe-vector-ref (unsafe-vector-ref vss i) j))
|
||||
|
||||
(: find-partial-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
|
||||
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
|
||||
;; Find the element with maximum magnitude in a column
|
||||
(define (find-partial-pivot rows m i j)
|
||||
(let loop ([#{l : Nonnegative-Fixnum} i] [#{max-current : Real} -inf.0] [#{max-index : Index} i])
|
||||
(cond [(l . fx< . m)
|
||||
(define v (magnitude (unsafe-vector2d-ref rows l j)))
|
||||
(cond [(> v max-current) (loop (fx+ l 1) v l)]
|
||||
[else (loop (fx+ l 1) max-current max-index)])]
|
||||
[else max-index])))
|
||||
|
||||
(: find-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index))
|
||||
((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index))))
|
||||
;; Find a non-zero element in a column
|
||||
(define (find-pivot rows m i j)
|
||||
(let loop ([#{l : Nonnegative-Fixnum} i])
|
||||
(cond [(l . fx>= . m) #f]
|
||||
[(not (zero? (unsafe-vector2d-ref rows l j))) l]
|
||||
[else (loop (fx+ l 1))])))
|
||||
|
||||
(: matrix-gauss-eliminate :
|
||||
(case-> ((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Integer)))
|
||||
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer)))
|
||||
((Matrix Number) -> (Values (Matrix Number) (Listof Integer)))))
|
||||
(define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t])
|
||||
(case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index)))
|
||||
((Matrix Real) Boolean -> (Values (Matrix Real) (Listof Index)))
|
||||
((Matrix Real) Boolean Boolean -> (Values (Matrix Real) (Listof Index)))
|
||||
((Matrix Number) -> (Values (Matrix Number) (Listof Index)))
|
||||
((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Index)))
|
||||
((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Index)))))
|
||||
;; Returns the result of Gaussian elimination and a list of column indexes that had no pivot value
|
||||
;; If `reduced?' is #t, the result is in *reduced* row-echelon form, and is unique (up to
|
||||
;; floating-point error)
|
||||
;; If `partial-pivoting?' is #t, the largest value in each column is used as the pivot
|
||||
(define (matrix-gauss-eliminate M [reduced? #f] [partial-pivoting? #t])
|
||||
(define-values (m n) (matrix-shape M))
|
||||
(: loop : (Integer Integer (Matrix Number) Integer (Listof Integer)
|
||||
-> (Values (Matrix Number) (Listof Integer))))
|
||||
(define (loop i j ; i from 0 to m
|
||||
M
|
||||
k ; count rows without pivot
|
||||
without-pivot)
|
||||
(define rows (matrix->vector* M))
|
||||
(let loop ([#{i : Nonnegative-Fixnum} 0]
|
||||
[#{j : Nonnegative-Fixnum} 0]
|
||||
[#{without-pivot : (Listof Index)} '()])
|
||||
(cond
|
||||
[(or (= i m) (= j n)) (values M without-pivot)]
|
||||
[else
|
||||
; find row to become pivot
|
||||
(define p
|
||||
(if partial-pivoting?
|
||||
; find element with maximal absolute value
|
||||
(let: max-loop : (U False Integer)
|
||||
([l : Integer i] ; i<=l<m
|
||||
[max-current : Real -inf.0]
|
||||
[max-index : Integer i])
|
||||
(cond
|
||||
[(= l m) max-index]
|
||||
[else
|
||||
(let ([v (magnitude (matrix-ref M l j))])
|
||||
(if (> (magnitude (matrix-ref M l j)) max-current)
|
||||
(max-loop (+ l 1) v l)
|
||||
(max-loop (+ l 1) max-current max-index)))]))
|
||||
; find non-zero element in column
|
||||
(let: first-loop : (U False Integer)
|
||||
([l : Integer i]) ; i<=l<m
|
||||
(cond
|
||||
[(= l m) #f]
|
||||
[(not (zero? (matrix-ref M l j))) l]
|
||||
[else (first-loop (+ l 1))]))))
|
||||
[(and (i . fx< . m) (j . fx< . n))
|
||||
;; Find the row with the pivot value
|
||||
(define p (cond [partial-pivoting? (find-partial-pivot rows m i j)]
|
||||
[else (find-pivot rows m i j)]))
|
||||
(define pivot (if p (unsafe-vector2d-ref rows p j) 0))
|
||||
(cond
|
||||
[(or (eq? p #f)
|
||||
(zero? (matrix-ref M p j)))
|
||||
; no pivot found
|
||||
(loop i (+ j 1) M (+ k 1) (cons j without-pivot))]
|
||||
[(or (not p) (zero? pivot)) ; didn't find pivot?
|
||||
(loop i (fx+ j 1) (cons j without-pivot))]
|
||||
[else
|
||||
; swap if neccessary
|
||||
(let* ([M (if (= i p) M (matrix-swap-rows M i p))]
|
||||
; now we now (i,j) is a pivot
|
||||
[M ; maybe scale row
|
||||
(if unitize-pivot-row?
|
||||
(let ([pivot (matrix-ref M i j)])
|
||||
(if (zero? pivot)
|
||||
M
|
||||
(matrix-scale-row M i (/ pivot))))
|
||||
M)])
|
||||
(let ([pivot (matrix-ref M i j)])
|
||||
; remove elements below pivot
|
||||
(let l-loop ([l (+ i 1)] [M M])
|
||||
(if (= l m)
|
||||
(loop (+ i 1) (+ j 1) M k without-pivot)
|
||||
(let ([x_lj (matrix-ref M l j)])
|
||||
(l-loop (+ l 1)
|
||||
(if (zero? x_lj)
|
||||
M
|
||||
(matrix-add-scaled-row M l (- (/ x_lj pivot)) i))))))))])]))
|
||||
(let-values ([(M without) (loop 0 0 M 0 '())])
|
||||
(values M without)))
|
||||
(vector-swap! rows i p) ; swap pivot row with current
|
||||
(let ([pivot (cond [reduced? (vector-scale! (unsafe-vector-ref rows i) (/ pivot))
|
||||
(/ pivot pivot)]
|
||||
[else pivot])])
|
||||
;; Remove elements below pivot by scaling and adding the pivot's row to each row below
|
||||
(let l-loop ([#{l : Nonnegative-Fixnum} (fx+ i 1)])
|
||||
(cond [(l . fx< . m)
|
||||
(define x_lj (unsafe-vector2d-ref rows l j))
|
||||
(unless (zero? x_lj)
|
||||
(vector-scaled-add! (unsafe-vector-ref rows l)
|
||||
(unsafe-vector-ref rows i)
|
||||
(- (/ x_lj pivot))))
|
||||
(l-loop (fx+ l 1))]
|
||||
[else
|
||||
(loop (fx+ i 1) (fx+ j 1) without-pivot)])))])]
|
||||
[else
|
||||
(values (vector*->matrix rows)
|
||||
(reverse without-pivot))])))
|
||||
|
||||
(: matrix-rank : (Matrix Number) -> Integer)
|
||||
(define (matrix-rank M)
|
||||
|
|
47
collects/math/private/vector/vector-mutate.rkt
Normal file
47
collects/math/private/vector/vector-mutate.rkt
Normal file
|
@ -0,0 +1,47 @@
|
|||
#lang typed/racket/base
|
||||
|
||||
(require racket/fixnum
|
||||
math/private/unsafe)
|
||||
|
||||
(provide vector-swap!
|
||||
vector-scale!
|
||||
vector-scaled-add!)
|
||||
|
||||
(: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void)))
|
||||
(define (vector-swap! vs i0 i1)
|
||||
(unless (= i0 i1)
|
||||
(define tmp (unsafe-vector-ref vs i0))
|
||||
(unsafe-vector-set! vs i0 (unsafe-vector-ref vs i1))
|
||||
(unsafe-vector-set! vs i1 tmp)))
|
||||
|
||||
(define-syntax-rule (vector-generic-scale! vs-expr v-expr *)
|
||||
(let* ([vs vs-expr]
|
||||
[v v-expr]
|
||||
[n (vector-length vs)])
|
||||
(let loop ([#{i : Nonnegative-Fixnum} 0])
|
||||
(if (i . fx< . n)
|
||||
(begin (unsafe-vector-set! vs i (* v (unsafe-vector-ref vs i)))
|
||||
(loop (fx+ i 1)))
|
||||
(void)))))
|
||||
|
||||
(: vector-scale! (case-> ((Vectorof Real) Real -> Void)
|
||||
((Vectorof Number) Number -> Void)))
|
||||
(define (vector-scale! vs v)
|
||||
(vector-generic-scale! vs v *))
|
||||
|
||||
(define-syntax-rule (vector-generic-scaled-add! vs0-expr vs1-expr v-expr + *)
|
||||
(let* ([vs0 vs0-expr]
|
||||
[vs1 vs1-expr]
|
||||
[v v-expr]
|
||||
[n (min (vector-length vs0) (vector-length vs1))])
|
||||
(let loop ([#{i : Nonnegative-Fixnum} 0])
|
||||
(if (i . fx< . n)
|
||||
(begin (unsafe-vector-set! vs0 i (+ (unsafe-vector-ref vs0 i)
|
||||
(* (unsafe-vector-ref vs1 i) v)))
|
||||
(loop (fx+ i 1)))
|
||||
(void)))))
|
||||
|
||||
(: vector-scaled-add! (case-> ((Vectorof Real) (Vectorof Real) Real -> Void)
|
||||
((Vectorof Number) (Vectorof Number) Number -> Void)))
|
||||
(define (vector-scaled-add! v0 v1 s)
|
||||
(vector-generic-scaled-add! v0 v1 s + *))
|
|
@ -20,7 +20,7 @@
|
|||
(make-array #(0 1) 0)
|
||||
(make-array #(0 0) 0)
|
||||
(make-array #(1 1 1) 0)))
|
||||
|
||||
#|
|
||||
;; ===================================================================================================
|
||||
;; Literal syntax
|
||||
|
||||
|
@ -666,7 +666,34 @@
|
|||
(check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3]))))
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (matrix-trace a))))
|
||||
|#
|
||||
;; ===================================================================================================
|
||||
;; Gaussian elimination
|
||||
|
||||
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
||||
(define (gauss-eliminate M reduce? partial-pivot?)
|
||||
(let-values ([(M wp) (matrix-gauss-eliminate M reduce? partial-pivot?)])
|
||||
M))
|
||||
|
||||
(check-equal? (gauss-eliminate (matrix [[1 2] [3 4]]) #f #f)
|
||||
(matrix [[1 2] [0 -2]]))
|
||||
|
||||
(check-equal? (gauss-eliminate (matrix [[2 4] [3 4]]) #t #f)
|
||||
(matrix [[1 2] [0 1]]))
|
||||
|
||||
(check-equal? (gauss-eliminate (matrix [[2. 4.] [3. 4.]]) #t #t)
|
||||
(matrix [[1. 1.3333333333333333] [0. 1.]]))
|
||||
|
||||
(check-equal? (gauss-eliminate (matrix [[1 4] [2 4]]) #t #t)
|
||||
(matrix [[1 2] [0 1]]))
|
||||
|
||||
(check-equal? (gauss-eliminate (matrix [[1 2] [2 4]]) #f #t)
|
||||
(matrix [[2 4] [0 0]]))
|
||||
|
||||
(for: ([a (in-list nonmatrices)])
|
||||
(check-exn exn:fail:contract? (λ () (gauss-eliminate a #f #f))))
|
||||
|
||||
#|
|
||||
;; ===================================================================================================
|
||||
;; Tests not yet converted to rackunit
|
||||
|
||||
|
@ -765,27 +792,6 @@
|
|||
[9 10 -11 12]
|
||||
[13 14 15 16]]))
|
||||
5280))
|
||||
(let ()
|
||||
(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number))
|
||||
(define (gauss-eliminate M u? p?)
|
||||
(let-values ([(M wp) (matrix-gauss-eliminate M u? p?)])
|
||||
M))
|
||||
(list 'matrix-gauss-eliminate
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [3 4]])])
|
||||
(gauss-eliminate M #f #f))
|
||||
(list*->matrix '[[1 2] [0 -2]]))
|
||||
(equal? (let ([M (list*->matrix '[[2 4] [3 4]])])
|
||||
(gauss-eliminate M #t #f))
|
||||
(list*->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])])
|
||||
(gauss-eliminate M #t #t))
|
||||
(list*->matrix '[[1. 1.3333333333333333] [0. 1.]]))
|
||||
(equal? (let ([M (list*->matrix '[[1 4] [2 4]])])
|
||||
(gauss-eliminate M #t #t))
|
||||
(list*->matrix '[[1 2] [0 1]]))
|
||||
(equal? (let ([M (list*->matrix '[[1 2] [2 4]])])
|
||||
(gauss-eliminate M #f #t))
|
||||
(list*->matrix '[[2 4] [0 0]]))))
|
||||
(list
|
||||
'matrix-scale-row
|
||||
(equal? (matrix-scale-row (identity-matrix 3) 0 2)
|
||||
|
@ -893,3 +899,4 @@
|
|||
(equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O)
|
||||
(equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2))))))
|
||||
|#
|
||||
|
|
Loading…
Reference in New Issue
Block a user