racket/collects/math/private/matrix/matrix-basic.rkt
Neil Toronto f42cc6f14a Fixed major performance issue with matrix arithmetic; please merge to 5.3.2
The fix consists of three parts:

1. Rewriting `inline-matrix*'. The material change here is that the
   expansion now contains only direct applications of `+' and `*'.
   TR's optimizer replaces them with `unsafe-fl+' and `unsafe-fl*',
   which keeps intermediate flonum values from being boxed.

2. Making the types of all functions that operate on (Matrix Number)
   values more precise. Now TR can prove that matrix operations preserve
   inexactness. For example, matrix-conjugate : (Matrix Flonum) ->
   (Matrix Flonum) and three other cases for Real, Float-Complex, and
   Number.

3. Changing the return types of some functions that used to return
   things like (Matrix (U A 0)). Now that we worry about preserving
   inexactness, we can't have `matrix-upper-triangle' always return a
   matrix that contains exact zeros. It now accepts an optional `zero'
   argument of type A.
2013-01-21 22:04:04 -07:00

480 lines
20 KiB
Racket

#lang typed/racket/base
(require racket/list
racket/fixnum
math/flonum
math/base
"matrix-types.rkt"
"matrix-arithmetic.rkt"
"matrix-constructors.rkt"
"matrix-conversion.rkt"
"utils.rkt"
"../unsafe.rkt"
"../array/array-struct.rkt"
"../array/array-indexing.rkt"
"../array/array-sequence.rkt"
"../array/array-transform.rkt"
"../array/array-fold.rkt"
"../array/array-pointwise.rkt"
"../array/array-convert.rkt"
"../array/utils.rkt"
"../vector/vector-mutate.rkt")
(provide
;; Extraction
matrix-ref
submatrix
matrix-row
matrix-col
matrix-rows
matrix-cols
matrix-diagonal
matrix-upper-triangle
matrix-lower-triangle
;; Embiggenment
matrix-augment
matrix-stack
;; Inner product space
matrix-1norm
matrix-2norm
matrix-inf-norm
matrix-norm
matrix-dot
matrix-cos-angle
matrix-angle
matrix-normalize
;; Simple operators
matrix-transpose
matrix-conjugate
matrix-hermitian
matrix-trace
;; Row/column operators
matrix-map-rows
matrix-map-cols
matrix-normalize-rows
matrix-normalize-cols
;; Predicates
matrix-rows-orthogonal?
matrix-cols-orthogonal?)
;; ===================================================================================================
;; Extraction
;; matrix-ref : returns the entry of matrix `a` at row `i`, column `j`.
;; Raises an argument error (argument position 1 for `i`, 2 for `j`) when an index is
;; negative or not less than the corresponding dimension of `a`.
(: matrix-ref (All (A) (Matrix A) Integer Integer -> A))
(define (matrix-ref a i j)
  (define-values (num-rows num-cols) (matrix-shape a))
  (cond [(or (< i 0) (>= i num-rows))
         (raise-argument-error 'matrix-ref (format "Index < ~a" num-rows) 1 a i j)]
        [(or (< j 0) (>= j num-cols))
         (raise-argument-error 'matrix-ref (format "Index < ~a" num-cols) 2 a i j)]
        [else
         ;; Both indexes were range-checked above, so the unchecked accessor is safe.
         (unsafe-array-ref a ((inst vector Index) i j))]))
;; submatrix : slices `a` with one row specification and one column specification.
;; Each specification is either a Slice or a sequence of integer indexes, and both are
;; handed to `array-slice-ref`. `a` is first passed through `ensure-matrix`, which
;; presumably raises when `a` is not a matrix.
(: submatrix (All (A) (Matrix A) (U Slice (Sequenceof Integer)) (U Slice (Sequenceof Integer))
                  -> (Array A)))
(define (submatrix a row-range col-range)
  (let ([checked (ensure-matrix 'submatrix a)])
    (array-slice-ref checked (list row-range col-range))))
;; matrix-row : returns row `i` of `a` as a 1-by-n matrix.
;; Raises an argument error when `i` is negative or not less than the row count.
(: matrix-row (All (A) (Matrix A) Integer -> (Matrix A)))
(define (matrix-row a i)
  (define-values (m n) (matrix-shape a))
  (if (or (< i 0) (>= i m))
      (raise-argument-error 'matrix-row (format "Index < ~a" m) 1 a i)
      (let ([get (unsafe-array-proc a)])
        (array-default-strict
         (unsafe-build-array
          ((inst vector Index) 1 n)
          (λ: ([idx : Indexes])
            ;; The result has a single row, so idx[0] is always 0. Point it at row `i`
            ;; of the source for the lookup, then put the 0 back.
            (unsafe-vector-set! idx 0 i)
            (let ([v (get idx)])
              (unsafe-vector-set! idx 0 0)
              v)))))))
;; matrix-col : returns column `j` of `a` as an m-by-1 matrix.
;; Raises an argument error, attributed to `matrix-col`, when `j` is negative or not
;; less than the column count.
(: matrix-col (All (A) (Matrix A) Integer -> (Matrix A)))
(define (matrix-col a j)
  (define-values (m n) (matrix-shape a))
  (cond [(or (j . < . 0) (j . >= . n))
         ;; Report the error under this function's own name so the message points at
         ;; the caller's actual mistake.
         (raise-argument-error 'matrix-col (format "Index < ~a" n) 1 a j)]
        [else
         (define proc (unsafe-array-proc a))
         (array-default-strict
          (unsafe-build-array
           ((inst vector Index) m 1)
           (λ: ([ij : Indexes])
             ;; The result has a single column, so ij[1] is always 0. Point it at
             ;; column `j` of the source for the lookup, then put the 0 back.
             (unsafe-vector-set! ij 1 j)
             (define res (proc ij))
             (unsafe-vector-set! ij 1 0)
             res)))]))
;; matrix-rows : splits `a` into a list of its rows, each a 1-by-n matrix.
;; The split is done with array strictness turned off. Each resulting row is then
;; passed through `array-default-strict`.
(: matrix-rows (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-rows a)
  (define lazy-rows
    (parameterize ([array-strictness #f])
      ;; Insert a unit axis at position 1, then split along axis 0 (one piece per row).
      (array->array-list (array-axis-insert (ensure-matrix 'matrix-rows a) 1) 0)))
  (map (λ: ([row : (Matrix A)]) (array-default-strict row)) lazy-rows))
;; matrix-cols : splits `a` into a list of its columns, each an m-by-1 matrix.
;; The split is done with array strictness turned off. Each resulting column is then
;; passed through `array-default-strict`.
(: matrix-cols (All (A) (Matrix A) -> (Listof (Matrix A))))
(define (matrix-cols a)
  (define lazy-cols
    (parameterize ([array-strictness #f])
      ;; Insert a unit axis at position 2, then split along axis 1 (one piece per column).
      (array->array-list (array-axis-insert (ensure-matrix 'matrix-cols a) 2) 1)))
  (map (λ: ([col : (Matrix A)]) (array-default-strict col)) lazy-cols))
;; matrix-diagonal : returns the main diagonal of `a` as a one-dimensional array of
;; length min(rows, cols). Element k of the result is entry (k, k) of `a`.
(: matrix-diagonal (All (A) ((Matrix A) -> (Array A))))
(define (matrix-diagonal a)
  (define-values (m n) (matrix-shape a))
  (define get (unsafe-array-proc a))
  (array-default-strict
   (unsafe-build-array
    ((inst vector Index) (fxmin m n))
    (λ: ([ks : Indexes])
      (let ([k (unsafe-vector-ref ks 0)])
        (get ((inst vector Index) k k)))))))
;; matrix-upper-triangle : returns a matrix with the same shape as `M`. Entries on or
;; above the main diagonal (row <= col) are kept and all others are replaced by `zero`.
;; `zero` defaults to exact 0. Pass an inexact zero such as 0.0 to avoid mixing exact
;; zeros into an inexact matrix.
(: matrix-upper-triangle (All (A) (case-> ((Matrix A) -> (Matrix (U A 0)))
                                          ((Matrix A) A -> (Matrix A)))))
(define matrix-upper-triangle
  (case-lambda
    [(M) (matrix-upper-triangle M 0)]
    [(M zero)
     (define-values (m n) (matrix-shape M))
     (define get (unsafe-array-proc M))
     (array-default-strict
      (unsafe-build-array
       ((inst vector Index) m n)
       (λ: ([ij : Indexes])
         (cond [(fx<= (unsafe-vector-ref ij 0) (unsafe-vector-ref ij 1)) (get ij)]
               [else zero]))))]))
;; matrix-lower-triangle : returns a matrix with the same shape as `M`. Entries on or
;; below the main diagonal (row >= col) are kept and all others are replaced by `zero`.
;; `zero` defaults to exact 0. Pass an inexact zero such as 0.0 to avoid mixing exact
;; zeros into an inexact matrix.
(: matrix-lower-triangle (All (A) (case-> ((Matrix A) -> (Matrix (U A 0)))
                                          ((Matrix A) A -> (Matrix A)))))
(define matrix-lower-triangle
  (case-lambda
    [(M) (matrix-lower-triangle M 0)]
    [(M zero)
     (define-values (m n) (matrix-shape M))
     (define get (unsafe-array-proc M))
     (array-default-strict
      (unsafe-build-array
       ((inst vector Index) m n)
       (λ: ([ij : Indexes])
         (cond [(fx>= (unsafe-vector-ref ij 0) (unsafe-vector-ref ij 1)) (get ij)]
               [else zero]))))]))
;; ===================================================================================================
;; Embiggenment (this is a perfectly cromulent word)
;; matrix-augment : joins a nonempty list of matrices side by side (appending along
;; axis 1, the column axis).
;; Raises an argument error for an empty list. Raises an error listing the matrices
;; when they do not all have the same number of rows.
(: matrix-augment (All (A) (Listof (Matrix A)) -> (Matrix A)))
(define (matrix-augment as)
  (cond
    [(empty? as) (raise-argument-error 'matrix-augment "nonempty List" as)]
    [else
     (define rows (matrix-num-rows (first as)))
     (if (andmap (λ: ([b : (Matrix A)]) (= rows (matrix-num-rows b))) (rest as))
         (array-append* as 1)
         (error 'matrix-augment
                "matrices must have the same number of rows; given ~a"
                (format-matrices/error as)))]))
;; matrix-stack : stacks a nonempty list of matrices top to bottom (appending along
;; axis 0, the row axis).
;; Raises an argument error for an empty list. Raises an error listing the matrices
;; when they do not all have the same number of columns.
(: matrix-stack (All (A) (Listof (Matrix A)) -> (Matrix A)))
(define (matrix-stack as)
  (cond
    [(empty? as) (raise-argument-error 'matrix-stack "nonempty List" as)]
    [else
     (define cols (matrix-num-cols (first as)))
     (if (andmap (λ: ([b : (Matrix A)]) (= cols (matrix-num-cols b))) (rest as))
         (array-append* as 0)
         (error 'matrix-stack
                "matrices must have the same number of columns; given ~a"
                (format-matrices/error as)))]))
;; ===================================================================================================
;; Inner product space (entrywise norm)
;; nonstupid-magnitude : absolute value of a real number, or magnitude of a complex one.
;; Real inputs go through `abs` rather than `magnitude`, presumably so the precise
;; result types in the annotation below (e.g. Flonum -> Nonnegative-Flonum) check.
(: nonstupid-magnitude (case-> (Flonum -> Nonnegative-Flonum)
                               (Real -> Nonnegative-Real)
                               (Float-Complex -> Nonnegative-Flonum)
                               (Number -> Nonnegative-Real)))
(define (nonstupid-magnitude x)
  (cond [(real? x) (abs x)]
        [else (magnitude x)]))
;; matrix-1norm : entrywise 1-norm of `M`, i.e. the sum of the magnitudes of all its
;; entries. The mapped array is built with strictness off, so it is not materialized
;; before being summed.
(: matrix-1norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum)
                        ((Matrix Real) -> Nonnegative-Real)
                        ((Matrix Float-Complex) -> Nonnegative-Flonum)
                        ((Matrix Number) -> Nonnegative-Real)))
(define (matrix-1norm M)
  (parameterize ([array-strictness #f])
    (let ([mags (inline-array-map nonstupid-magnitude M)])
      (array-all-sum mags))))
;; matrix-2norm : entrywise 2-norm (Frobenius norm) of `M`, sqrt(sum |x|^2).
;; Every magnitude is divided by the largest magnitude before squaring, and the result
;; is rescaled afterwards, to avoid intermediate overflow and underflow. If the largest
;; magnitude is zero, infinite or NaN, it is returned as-is.
(: matrix-2norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum)
                        ((Matrix Real) -> Nonnegative-Real)
                        ((Matrix Float-Complex) -> Nonnegative-Flonum)
                        ((Matrix Number) -> Nonnegative-Real)))
(define (matrix-2norm M)
  (parameterize ([array-strictness #f])
    (let* ([mags (array-strict (inline-array-map nonstupid-magnitude M))]
           [biggest (array-all-max mags)])
      (if (and (rational? biggest) (positive? biggest))
          (* biggest
             (sqrt (array-all-sum (inline-array-map (λ (x) (sqr (/ x biggest))) mags))))
          biggest))))
;; matrix-inf-norm : entrywise infinity-norm of `M`, i.e. the largest magnitude of any
;; of its entries.
(: matrix-inf-norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum)
                           ((Matrix Real) -> Nonnegative-Real)
                           ((Matrix Float-Complex) -> Nonnegative-Flonum)
                           ((Matrix Number) -> Nonnegative-Real)))
(define (matrix-inf-norm M)
  (parameterize ([array-strictness #f])
    (let ([mags (inline-array-map nonstupid-magnitude M)])
      (array-all-max mags))))
;; matrix-p-norm : entrywise p-norm of `M`, (sum |x|^p)^(1/p).
;; As in the 2-norm, magnitudes are divided by the largest magnitude before being raised
;; to the p-th power, and the result is rescaled afterwards. If the largest magnitude is
;; zero, infinite or NaN, it is returned as-is.
(: matrix-p-norm (case-> ((Matrix Flonum) Positive-Real -> Nonnegative-Flonum)
                         ((Matrix Real) Positive-Real -> Nonnegative-Real)
                         ((Matrix Float-Complex) Positive-Real -> Nonnegative-Flonum)
                         ((Matrix Number) Positive-Real -> Nonnegative-Real)))
(define (matrix-p-norm M p)
  (parameterize ([array-strictness #f])
    (let* ([mags (array-strict (inline-array-map nonstupid-magnitude M))]
           [biggest (array-all-max mags)])
      (if (and (rational? biggest) (positive? biggest))
          (* biggest
             (expt (array-all-sum
                    (inline-array-map (λ (x) (expt (abs (/ x biggest)) p)) mags))
                   (/ p)))
          biggest))))
;; matrix-norm : entrywise p-norm of matrix `a`, with `p` defaulting to 2.
;; p = 1, 2 and +inf.0 are dispatched to dedicated implementations. Any other p > 1 uses
;; the general p-norm. Raises an argument error when `a` is not a matrix, or when `p` is
;; not one of those cases (e.g. p < 1 or NaN).
(: matrix-norm (case-> ((Matrix Flonum) -> Nonnegative-Flonum)
                       ((Matrix Flonum) Real -> Nonnegative-Flonum)
                       ((Matrix Real) -> Nonnegative-Real)
                       ((Matrix Real) Real -> Nonnegative-Real)
                       ((Matrix Float-Complex) -> Nonnegative-Flonum)
                       ((Matrix Float-Complex) Real -> Nonnegative-Flonum)
                       ((Matrix Number) -> Nonnegative-Real)
                       ((Matrix Number) Real -> Nonnegative-Real)))
(define (matrix-norm a [p 2])
  (cond [(not (matrix? a)) (raise-argument-error 'matrix-norm "matrix?" 0 a p)]
        [(= p 1) (matrix-1norm a)]
        [(= p 2) (matrix-2norm a)]
        [(= p +inf.0) (matrix-inf-norm a)]
        [(> p 1) (matrix-p-norm a p)]
        [else (raise-argument-error 'matrix-norm "Real >= 1" 1 a p)]))
;; matrix-dot : Frobenius inner product.
;;  (matrix-dot a)   -> sum over entries of |a_ij|^2, a nonnegative real number.
;;  (matrix-dot a b) -> sum over entries of a_ij * conj(b_ij); `a` and `b` must have the
;;                      same shape (checked by `matrix-shapes`).
(: matrix-dot (case-> ((Matrix Flonum) -> Nonnegative-Flonum)
                      ((Matrix Flonum) (Matrix Flonum) -> Flonum)
                      ((Matrix Real) -> Nonnegative-Real)
                      ((Matrix Real) (Matrix Real) -> Real)
                      ((Matrix Float-Complex) -> Nonnegative-Flonum)
                      ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex)
                      ((Matrix Number) -> Nonnegative-Real)
                      ((Matrix Number) (Matrix Number) -> Number)))
(define matrix-dot
  (case-lambda
    [(a)
     (parameterize ([array-strictness #f])
       (assert
        (array-all-sum
         (inline-array-map
          ;; |x|^2 as a real number. For real x, conj(x) = x, so this is just x*x.
          ;; For complex x, x*conj(x) has a zero imaginary part mathematically, but with
          ;; inexact components Racket still returns a complex value such as 5.0+0.0i.
          ;; That value is not `real?` and would fail the assertion below, so keep only
          ;; its real part.
          (λ (x) (if (real? x)
                     (* x x)
                     (real-part (* x (conjugate x)))))
          (ensure-matrix 'matrix-dot a)))
        (make-predicate Nonnegative-Real)))]
    [(a b)
     (define-values (m n) (matrix-shapes 'matrix-dot a b))
     (define aproc (unsafe-array-proc a))
     (define bproc (unsafe-array-proc b))
     (parameterize ([array-strictness #f])
       (array-all-sum
        (unsafe-build-array
         ((inst vector Index) m n)
         (λ: ([js : Indexes])
           (* (aproc js) (conjugate (bproc js)))))))]))
;; matrix-cos-angle : cosine of the angle between `M` and `N`, computed as their
;; Frobenius inner product divided by the product of their 2-norms.
;; If either matrix has zero norm, the division raises (exact zero) or produces a
;; non-finite flonum (inexact zero).
(: matrix-cos-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum)
                            ((Matrix Real) (Matrix Real) -> Real)
                            ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex)
                            ((Matrix Number) (Matrix Number) -> Number)))
(define (matrix-cos-angle M N)
  (let ([inner (matrix-dot M N)]
        [scale (* (matrix-2norm M) (matrix-2norm N))])
    (/ inner scale)))
;; matrix-angle : angle between `M` and `N`, the arccosine of `matrix-cos-angle`.
;; NOTE(review): for real inputs, rounding can push the cosine slightly outside [-1, 1].
;; Racket's `acos` then returns a complex number; consider clamping.
(: matrix-angle (case-> ((Matrix Flonum) (Matrix Flonum) -> Flonum)
                        ((Matrix Real) (Matrix Real) -> Real)
                        ((Matrix Float-Complex) (Matrix Float-Complex) -> Float-Complex)
                        ((Matrix Number) (Matrix Number) -> Number)))
(define (matrix-angle M N)
  (let ([c (matrix-cos-angle M N)])
    (acos c)))
;; matrix-normalize : scales `M` so that its entrywise p-norm (default p = 2) is 1.
;; When the norm is an exact zero, the `fail` thunk is called and its result returned.
;; The default `fail` raises an argument error. An inexact zero norm does NOT trigger
;; `fail`; scaling by its reciprocal then yields non-finite entries.
(: matrix-normalize
   (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum))
                    ((Matrix Flonum) Real -> (Matrix Flonum))
                    ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum)))
                    ((Matrix Real) -> (Matrix Real))
                    ((Matrix Real) Real -> (Matrix Real))
                    ((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
                    ((Matrix Float-Complex) -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex)))
                    ((Matrix Number) -> (Matrix Number))
                    ((Matrix Number) Real -> (Matrix Number))
                    ((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize
  (case-lambda
    [(M) (matrix-normalize M 2)]
    [(M p)
     (matrix-normalize
      M p
      (λ () (raise-argument-error 'matrix-normalize "nonzero matrix?" 0 M p)))]
    [(M p fail)
     ;; M is read twice (by matrix-norm, then by matrix-scale). Making it strict first
     ;; presumably avoids recomputing a lazy array.
     (array-strict! M)
     (let ([norm (matrix-norm M p)])
       (if (and (zero? norm) (exact? norm))
           (fail)
           (matrix-scale M (/ norm))))]))
;; ===================================================================================================
;; Operators
;; matrix-transpose : swaps the two axes of `a`, so an m-by-n matrix becomes n-by-m.
(: matrix-transpose (All (A) (Matrix A) -> (Matrix A)))
(define (matrix-transpose a)
  (let ([checked (ensure-matrix 'matrix-transpose a)])
    (array-axis-swap checked 0 1)))
;; matrix-conjugate : entrywise complex conjugate of `a`. The case-> type records that
;; the element type (Flonum, Real, Float-Complex or Number) is preserved.
(: matrix-conjugate (case-> ((Matrix Flonum) -> (Matrix Flonum))
                            ((Matrix Real) -> (Matrix Real))
                            ((Matrix Float-Complex) -> (Matrix Float-Complex))
                            ((Matrix Number) -> (Matrix Number))))
(define (matrix-conjugate a)
  (let ([checked (ensure-matrix 'matrix-conjugate a)])
    (array-conjugate checked)))
;; matrix-hermitian : conjugate transpose of `a`.
;; The conjugate and the axis swap are composed with strictness off. Only the final
;; result goes through `array-default-strict`.
(: matrix-hermitian (case-> ((Matrix Flonum) -> (Matrix Flonum))
                            ((Matrix Real) -> (Matrix Real))
                            ((Matrix Float-Complex) -> (Matrix Float-Complex))
                            ((Matrix Number) -> (Matrix Number))))
(define (matrix-hermitian a)
  (define lazy-result
    (parameterize ([array-strictness #f])
      (array-axis-swap (array-conjugate (ensure-matrix 'matrix-hermitian a)) 0 1)))
  (array-default-strict lazy-result))
;; matrix-trace : sum of the main-diagonal entries of a square matrix.
;; Raises an argument error when `a` is not a square matrix.
(: matrix-trace (case-> ((Matrix Flonum) -> Flonum)
                        ((Matrix Real) -> Real)
                        ((Matrix Float-Complex) -> Float-Complex)
                        ((Matrix Number) -> Number)))
(define (matrix-trace a)
  (if (square-matrix? a)
      (parameterize ([array-strictness #f])
        (array-all-sum (matrix-diagonal a)))
      (raise-argument-error 'matrix-trace "square-matrix?" a)))
;; ===================================================================================================
;; Row/column operations
;; matrix-map-rows : applies `f` to each row of `M` (as a 1-by-n matrix) and stacks
;; the results top to bottom.
;; In the three-argument form, `f` may return #f. Processing then stops at the first #f
;; and the result of calling `fail` is returned instead.
(: matrix-map-rows
   (All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B))
                        (((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F)
                                                           -> (U F (Matrix B))))))
(define matrix-map-rows
  (case-lambda
    [(f M)
     (let ([rows (matrix-rows M)])
       (matrix-stack (map f rows)))]
    [(f M fail)
     (define rows (matrix-rows M))
     (define first-out (f (first rows)))
     (if first-out
         (let loop ([todo (rest rows)] [done (list first-out)])
           (if (empty? todo)
               (matrix-stack (reverse done))
               (let ([out (f (first todo))])
                 (if out
                     (loop (rest todo) (cons out done))
                     (fail)))))
         (fail))]))
;; matrix-map-cols : applies `f` to each column of `M` (as an m-by-1 matrix) and joins
;; the results side by side.
;; In the three-argument form, `f` may return #f. Processing then stops at the first #f
;; and the result of calling `fail` is returned instead.
(: matrix-map-cols
   (All (A B F) (case-> (((Matrix A) -> (Matrix B)) (Matrix A) -> (Matrix B))
                        (((Matrix A) -> (U #f (Matrix B))) (Matrix A) (-> F)
                                                           -> (U F (Matrix B))))))
(define matrix-map-cols
  (case-lambda
    [(f M)
     (let ([cols (matrix-cols M)])
       (matrix-augment (map f cols)))]
    [(f M fail)
     (define cols (matrix-cols M))
     (define first-out (f (first cols)))
     (if first-out
         (let loop ([todo (rest cols)] [done (list first-out)])
           (if (empty? todo)
               (matrix-augment (reverse done))
               (let ([out (f (first todo))])
                 (if out
                     (loop (rest todo) (cons out done))
                     (fail)))))
         (fail))]))
;; make-matrix-normalize : returns a one-argument normalizer for the p-norm.
;; Where `matrix-normalize` would call its failure thunk (exact zero norm), the
;; normalizer returns #f instead.
(: make-matrix-normalize (Real -> (case-> ((Matrix Flonum) -> (U #f (Matrix Flonum)))
                                          ((Matrix Real) -> (U #f (Matrix Real)))
                                          ((Matrix Float-Complex) -> (U #f (Matrix Float-Complex)))
                                          ((Matrix Number) -> (U #f (Matrix Number))))))
(define (make-matrix-normalize p)
  (λ (M)
    (matrix-normalize M p (λ () #f))))
;; matrix-normalize-rows : normalizes every row of `M` to unit p-norm (default p = 2).
;; If any row has an exact zero norm, returns the result of calling `fail`. The default
;; `fail` raises an argument error.
(: matrix-normalize-rows
   (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum))
                    ((Matrix Flonum) Real -> (Matrix Flonum))
                    ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum)))
                    ((Matrix Real) -> (Matrix Real))
                    ((Matrix Real) Real -> (Matrix Real))
                    ((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
                    ((Matrix Float-Complex) -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex)))
                    ((Matrix Number) -> (Matrix Number))
                    ((Matrix Number) Real -> (Matrix Number))
                    ((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize-rows
  (case-lambda
    [(M) (matrix-normalize-rows M 2)]
    [(M p)
     (matrix-normalize-rows
      M p
      (λ () (raise-argument-error 'matrix-normalize-rows "matrix? with nonzero rows" 0 M p)))]
    [(M p fail)
     (let ([normalize-one (make-matrix-normalize p)])
       (matrix-map-rows normalize-one M fail))]))
;; matrix-normalize-cols : normalizes every column of `M` to unit p-norm (default p = 2).
;; If any column has an exact zero norm, returns the result of calling `fail`. The
;; default `fail` raises an argument error.
(: matrix-normalize-cols
   (All (A) (case-> ((Matrix Flonum) -> (Matrix Flonum))
                    ((Matrix Flonum) Real -> (Matrix Flonum))
                    ((Matrix Flonum) Real (-> A) -> (U A (Matrix Flonum)))
                    ((Matrix Real) -> (Matrix Real))
                    ((Matrix Real) Real -> (Matrix Real))
                    ((Matrix Real) Real (-> A) -> (U A (Matrix Real)))
                    ((Matrix Float-Complex) -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real -> (Matrix Float-Complex))
                    ((Matrix Float-Complex) Real (-> A) -> (U A (Matrix Float-Complex)))
                    ((Matrix Number) -> (Matrix Number))
                    ((Matrix Number) Real -> (Matrix Number))
                    ((Matrix Number) Real (-> A) -> (U A (Matrix Number))))))
(define matrix-normalize-cols
  (case-lambda
    [(M) (matrix-normalize-cols M 2)]
    [(M p)
     (matrix-normalize-cols
      M p
      (λ () (raise-argument-error 'matrix-normalize-cols "matrix? with nonzero columns" 0 M p)))]
    [(M p fail)
     (let ([normalize-one (make-matrix-normalize p)])
       (matrix-map-cols normalize-one M fail))]))
;; ===================================================================================================
;; pairwise-orthogonal? : #t when every distinct pair of matrices in `Ms` has a
;; cosine-angle magnitude that is not >= `eps`. Returns #f at the first pair that fails.
;; A list with fewer than two matrices is trivially orthogonal.
(: pairwise-orthogonal? ((Listof (Matrix Number)) Nonnegative-Real -> Boolean))
(define (pairwise-orthogonal? Ms eps)
  (define vs (list->vector Ms))
  (define len (vector-length vs))
  (let/ec: escape : Boolean
    (for*: ([a (in-range len)]
            [b (in-range (fx+ a 1) len)])
      (define cos-ab (matrix-cos-angle (unsafe-vector-ref vs a) (unsafe-vector-ref vs b)))
      (when (>= (magnitude cos-ab) eps) (escape #f)))
    #t))
;; matrix-rows-orthogonal? : #t when the rows of `M` are pairwise orthogonal within
;; tolerance `eps` (default 10 * epsilon.0).
;; Raises an argument error when `eps` is negative.
(: matrix-rows-orthogonal? (case-> ((Matrix Number) -> Boolean)
                                   ((Matrix Number) Real -> Boolean)))
(define (matrix-rows-orthogonal? M [eps (* 10 epsilon.0)])
  (if (negative? eps)
      (raise-argument-error 'matrix-rows-orthogonal? "Nonnegative-Real" 1 M eps)
      (parameterize ([array-strictness #f])
        (pairwise-orthogonal? (matrix-rows M) eps))))
;; matrix-cols-orthogonal? : #t when the columns of `M` are pairwise orthogonal within
;; tolerance `eps` (default 10 * epsilon.0).
;; Raises an argument error when `eps` is negative.
(: matrix-cols-orthogonal? (case-> ((Matrix Number) -> Boolean)
                                   ((Matrix Number) Real -> Boolean)))
(define (matrix-cols-orthogonal? M [eps (* 10 epsilon.0)])
  (if (negative? eps)
      (raise-argument-error 'matrix-cols-orthogonal? "Nonnegative-Real" 1 M eps)
      (parameterize ([array-strictness #f])
        (pairwise-orthogonal? (matrix-cols M) eps))))