diff --git a/collects/math/private/matrix/matrix-operations.rkt b/collects/math/private/matrix/matrix-operations.rkt index 1b757e52da..b3ead31c02 100644 --- a/collects/math/private/matrix/matrix-operations.rkt +++ b/collects/math/private/matrix/matrix-operations.rkt @@ -1,9 +1,11 @@ #lang typed/racket/base -(require racket/list +(require racket/fixnum + racket/list math/array (only-in typed/racket conjugate) "../unsafe.rkt" + "../vector/vector-mutate.rkt" "matrix-types.rkt" "matrix-constructors.rkt" "matrix-conversion.rkt" @@ -198,71 +200,75 @@ ;;; GAUSS ELIMINATION / ROW ECHELON FORM +(: unsafe-vector2d-ref (All (A) ((Vectorof (Vectorof A)) Index Index -> A))) +(define (unsafe-vector2d-ref vss i j) + (unsafe-vector-ref (unsafe-vector-ref vss i) j)) + +(: find-partial-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index)) + ((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index)))) +;; Find the element with maximum magnitude in a column +(define (find-partial-pivot rows m i j) + (let loop ([#{l : Nonnegative-Fixnum} i] [#{max-current : Real} -inf.0] [#{max-index : Index} i]) + (cond [(l . fx< . m) + (define v (magnitude (unsafe-vector2d-ref rows l j))) + (cond [(> v max-current) (loop (fx+ l 1) v l)] + [else (loop (fx+ l 1) max-current max-index)])] + [else max-index]))) + +(: find-pivot (case-> ((Vectorof (Vectorof Real)) Index Index Index -> (U #f Index)) + ((Vectorof (Vectorof Number)) Index Index Index -> (U #f Index)))) +;; Find a non-zero element in a column +(define (find-pivot rows m i j) + (let loop ([#{l : Nonnegative-Fixnum} i]) + (cond [(l . fx>= . m) #f] + [(not (zero? (unsafe-vector2d-ref rows l j))) l] + [else (loop (fx+ l 1))]))) + (: matrix-gauss-eliminate : - (case-> ((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Integer))) - ((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Integer))) - ((Matrix Number) -> (Values (Matrix Number) (Listof Integer))))) -(define (matrix-gauss-eliminate M [unitize-pivot-row? #f] [partial-pivoting? #t]) + (case-> ((Matrix Real) -> (Values (Matrix Real) (Listof Index))) + ((Matrix Real) Boolean -> (Values (Matrix Real) (Listof Index))) + ((Matrix Real) Boolean Boolean -> (Values (Matrix Real) (Listof Index))) + ((Matrix Number) -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Boolean -> (Values (Matrix Number) (Listof Index))) + ((Matrix Number) Boolean Boolean -> (Values (Matrix Number) (Listof Index))))) +;; Returns the result of Gaussian elimination and a list of column indexes that had no pivot value +;; If `reduced?' is #t, the result is in *reduced* row-echelon form, and is unique (up to +;; floating-point error) +;; If `partial-pivoting?' is #t, the largest value in each column is used as the pivot +(define (matrix-gauss-eliminate M [reduced? #f] [partial-pivoting? #t]) (define-values (m n) (matrix-shape M)) - (: loop : (Integer Integer (Matrix Number) Integer (Listof Integer) - -> (Values (Matrix Number) (Listof Integer)))) - (define (loop i j ; i from 0 to m - M - k ; count rows without pivot - without-pivot) + (define rows (matrix->vector* M)) + (let loop ([#{i : Nonnegative-Fixnum} 0] + [#{j : Nonnegative-Fixnum} 0] + [#{without-pivot : (Listof Index)} '()]) (cond - [(or (= i m) (= j n)) (values M without-pivot)] - [else - ; find row to become pivot - (define p - (if partial-pivoting? - ; find element with maximal absolute value - (let: max-loop : (U False Integer) - ([l : Integer i] ; i<=l (magnitude (matrix-ref M l j)) max-current) - (max-loop (+ l 1) v l) - (max-loop (+ l 1) max-current max-index)))])) - ; find non-zero element in column - (let: first-loop : (U False Integer) - ([l : Integer i]) ; i<=lmatrix rows) + (reverse without-pivot))]))) (: matrix-rank : (Matrix Number) -> Integer) (define (matrix-rank M) diff --git a/collects/math/private/vector/vector-mutate.rkt b/collects/math/private/vector/vector-mutate.rkt new file mode 100644 index 0000000000..4ca10c35e2 --- /dev/null +++ b/collects/math/private/vector/vector-mutate.rkt @@ -0,0 +1,47 @@ +#lang typed/racket/base + +(require racket/fixnum + math/private/unsafe) + +(provide vector-swap! + vector-scale! + vector-scaled-add!) + +(: vector-swap! (All (A) ((Vectorof A) Integer Integer -> Void))) +(define (vector-swap! vs i0 i1) + (unless (= i0 i1) + (define tmp (unsafe-vector-ref vs i0)) + (unsafe-vector-set! vs i0 (unsafe-vector-ref vs i1)) + (unsafe-vector-set! vs i1 tmp))) + +(define-syntax-rule (vector-generic-scale! vs-expr v-expr *) + (let* ([vs vs-expr] + [v v-expr] + [n (vector-length vs)]) + (let loop ([#{i : Nonnegative-Fixnum} 0]) + (if (i . fx< . n) + (begin (unsafe-vector-set! vs i (* v (unsafe-vector-ref vs i))) + (loop (fx+ i 1))) + (void))))) + +(: vector-scale! (case-> ((Vectorof Real) Real -> Void) + ((Vectorof Number) Number -> Void))) +(define (vector-scale! vs v) + (vector-generic-scale! vs v *)) + +(define-syntax-rule (vector-generic-scaled-add! vs0-expr vs1-expr v-expr + *) + (let* ([vs0 vs0-expr] + [vs1 vs1-expr] + [v v-expr] + [n (min (vector-length vs0) (vector-length vs1))]) + (let loop ([#{i : Nonnegative-Fixnum} 0]) + (if (i . fx< . n) + (begin (unsafe-vector-set! vs0 i (+ (unsafe-vector-ref vs0 i) + (* (unsafe-vector-ref vs1 i) v))) + (loop (fx+ i 1))) + (void))))) + +(: vector-scaled-add! (case-> ((Vectorof Real) (Vectorof Real) Real -> Void) + ((Vectorof Number) (Vectorof Number) Number -> Void))) +(define (vector-scaled-add! v0 v1 s) + (vector-generic-scaled-add! v0 v1 s + *)) diff --git a/collects/math/tests/matrix-tests.rkt b/collects/math/tests/matrix-tests.rkt index 96dc77d514..f8d7716574 100644 --- a/collects/math/tests/matrix-tests.rkt +++ b/collects/math/tests/matrix-tests.rkt @@ -20,7 +20,7 @@ (make-array #(0 1) 0) (make-array #(0 0) 0) (make-array #(1 1 1) 0))) - +#| ;; =================================================================================================== ;; Literal syntax @@ -666,7 +666,34 @@ (check-exn exn:fail:contract? (λ () (matrix-trace (col-matrix [1 2 3])))) (for: ([a (in-list nonmatrices)]) (check-exn exn:fail:contract? (λ () (matrix-trace a)))) +|# +;; =================================================================================================== +;; Gaussian elimination +(: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number)) +(define (gauss-eliminate M reduce? partial-pivot?) + (let-values ([(M wp) (matrix-gauss-eliminate M reduce? partial-pivot?)]) + M)) + +(check-equal? (gauss-eliminate (matrix [[1 2] [3 4]]) #f #f) + (matrix [[1 2] [0 -2]])) + +(check-equal? (gauss-eliminate (matrix [[2 4] [3 4]]) #t #f) + (matrix [[1 2] [0 1]])) + +(check-equal? (gauss-eliminate (matrix [[2. 4.] [3. 4.]]) #t #t) + (matrix [[1. 1.3333333333333333] [0. 1.]])) + +(check-equal? (gauss-eliminate (matrix [[1 4] [2 4]]) #t #t) + (matrix [[1 2] [0 1]])) + +(check-equal? (gauss-eliminate (matrix [[1 2] [2 4]]) #f #t) + (matrix [[2 4] [0 0]])) + +(for: ([a (in-list nonmatrices)]) + (check-exn exn:fail:contract? (λ () (gauss-eliminate a #f #f)))) + +#| ;; =================================================================================================== ;; Tests not yet converted to rackunit @@ -765,27 +792,6 @@ [9 10 -11 12] [13 14 15 16]])) 5280)) - (let () - (: gauss-eliminate : (Matrix Number) Boolean Boolean -> (Matrix Number)) - (define (gauss-eliminate M u? p?) - (let-values ([(M wp) (matrix-gauss-eliminate M u? p?)]) - M)) - (list 'matrix-gauss-eliminate - (equal? (let ([M (list*->matrix '[[1 2] [3 4]])]) - (gauss-eliminate M #f #f)) - (list*->matrix '[[1 2] [0 -2]])) - (equal? (let ([M (list*->matrix '[[2 4] [3 4]])]) - (gauss-eliminate M #t #f)) - (list*->matrix '[[1 2] [0 1]])) - (equal? (let ([M (list*->matrix '[[2. 4.] [3. 4.]])]) - (gauss-eliminate M #t #t)) - (list*->matrix '[[1. 1.3333333333333333] [0. 1.]])) - (equal? (let ([M (list*->matrix '[[1 4] [2 4]])]) - (gauss-eliminate M #t #t)) - (list*->matrix '[[1 2] [0 1]])) - (equal? (let ([M (list*->matrix '[[1 2] [2 4]])]) - (gauss-eliminate M #f #t)) - (list*->matrix '[[2 4] [0 0]])))) (list 'matrix-scale-row (equal? (matrix-scale-row (identity-matrix 3) 0 2) @@ -893,3 +899,4 @@ (equal? (matrix* (matrix-2d-orthogonal-projection 1 0) e2) O) (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e1) O) (equal? (matrix* (matrix-2d-orthogonal-projection 0 1) e2) e2)))))) +|#