#lang typed/racket/base (require racket/fixnum racket/list racket/vector "matrix-types.rkt" "../unsafe.rkt" "../array/array-struct.rkt" "../array/array-constructors.rkt" "../array/array-unfold.rkt" "../array/utils.rkt") (provide identity-matrix make-matrix build-matrix diagonal-matrix/zero diagonal-matrix block-diagonal-matrix/zero block-diagonal-matrix vandermonde-matrix) ;; =================================================================================================== ;; Basic constructors (: identity-matrix (Integer -> (Matrix (U 0 1)))) (define (identity-matrix m) (diagonal-array 2 m 1 0)) (: make-matrix (All (A) (Integer Integer A -> (Matrix A)))) (define (make-matrix m n x) (make-array (vector m n) x)) (: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A)))) (define (build-matrix m n proc) (cond [(or (not (index? m)) (= m 0)) (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)] [(or (not (index? n)) (= n 0)) (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)] [else (unsafe-build-array ((inst vector Index) m n) (λ: ([js : Indexes]) (proc (unsafe-vector-ref js 0) (unsafe-vector-ref js 1))))])) ;; =================================================================================================== ;; Diagonal matrices (: diagonal-matrix/zero (All (A) ((Listof A) A -> (Matrix A)))) (define (diagonal-matrix/zero xs zero) (cond [(empty? xs) (raise-argument-error 'diagonal-matrix "nonempty List" xs)] [else (define vs (list->vector xs)) (define m (vector-length vs)) (unsafe-build-array ((inst vector Index) m m) (λ: ([js : Indexes]) (define i (unsafe-vector-ref js 0)) (cond [(= i (unsafe-vector-ref js 1)) (unsafe-vector-ref vs i)] [else zero])))])) (: diagonal-matrix (All (A) ((Listof A) -> (Matrix (U A 0))))) (define (diagonal-matrix xs) (diagonal-matrix/zero xs 0)) ;; =================================================================================================== ;; Block diagonal matrices (: block-diagonal-matrix/zero* (All (A) (Vectorof (Matrix A)) A -> (Matrix A))) (define (block-diagonal-matrix/zero* as zero) (define num (vector-length as)) (define-values (ms ns) (let-values ([(ms ns) (for/fold: ([ms : (Listof Index) empty] [ns : (Listof Index) empty] ) ([a (in-vector as)]) (define-values (m n) (matrix-shape a)) (values (cons m ms) (cons n ns)))]) (values (reverse ms) (reverse ns)))) (define res-m (assert (apply + ms) index?)) (define res-n (assert (apply + ns) index?)) (define vs ((inst make-vector Index) res-m 0)) (define hs ((inst make-vector Index) res-n 0)) (define is ((inst make-vector Index) res-m 0)) (define js ((inst make-vector Index) res-n 0)) (define-values (_res-i _res-j) (for/fold: ([res-i : Nonnegative-Fixnum 0] [res-j : Nonnegative-Fixnum 0] ) ([m (in-list ms)] [n (in-list ns)] [k : Nonnegative-Fixnum (in-range num)]) (let ([k (assert k index?)]) (for: ([i : Nonnegative-Fixnum (in-range m)]) (vector-set! vs (unsafe-fx+ res-i i) k) (vector-set! is (unsafe-fx+ res-i i) (assert i index?))) (for: ([j : Nonnegative-Fixnum (in-range n)]) (vector-set! hs (unsafe-fx+ res-j j) k) (vector-set! js (unsafe-fx+ res-j j) (assert j index?)))) (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n)))) (define procs (vector-map (λ: ([a : (Matrix A)]) (unsafe-array-proc a)) as)) (unsafe-build-array ((inst vector Index) res-m res-n) (λ: ([ij : Indexes]) (define i (unsafe-vector-ref ij 0)) (define j (unsafe-vector-ref ij 1)) (define v (unsafe-vector-ref vs i)) (cond [(fx= v (vector-ref hs j)) (define proc (unsafe-vector-ref procs v)) (define iv (unsafe-vector-ref is i)) (define jv (unsafe-vector-ref js j)) (unsafe-vector-set! ij 0 iv) (unsafe-vector-set! ij 1 jv) (define res (proc ij)) (unsafe-vector-set! ij 0 i) (unsafe-vector-set! ij 1 j) res] [else zero])))) (: block-diagonal-matrix/zero (All (A) ((Listof (Matrix A)) A -> (Matrix A)))) (define (block-diagonal-matrix/zero as zero) (let ([as (list->vector as)]) (define num (vector-length as)) (cond [(= num 0) (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)] [(= num 1) (unsafe-vector-ref as 0)] [else (block-diagonal-matrix/zero* as zero)]))) (: block-diagonal-matrix (All (A) ((Listof (Matrix A)) -> (Matrix (U A 0))))) (define (block-diagonal-matrix as) (block-diagonal-matrix/zero as 0)) ;; =================================================================================================== ;; Special matrices (: expt-hack (case-> (Real Integer -> Real) (Number Integer -> Number))) ;; Stop using this when TR correctly derives expt : Real Integer -> Real (define (expt-hack x n) (cond [(real? x) (assert (expt x n) real?)] [else (expt x n)])) (: vandermonde-matrix (case-> ((Listof Real) Integer -> (Matrix Real)) ((Listof Number) Integer -> (Matrix Number)))) (define (vandermonde-matrix xs n) (cond [(empty? xs) (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)] [(or (not (index? n)) (zero? n)) (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)] [else (array-axis-expand (list->array xs) 1 n expt-hack)]))