
* Finally added `array-axis-expand' as a dual for `array-axis-reduce' in order to implement `vandermonde-matrix' elegantly. * Better, shorter matrix multiply; reworked all matrix arithmetic. * Split "matrix-operations.rkt" into at least 5 parts: "matrix-operations.rkt", "matrix-basic.rkt", "matrix-comprehension.rkt", "matrix-sequences.rkt", and "matrix-column.rkt". * Added "matrix-constructors.rkt". * Added `matrix', `row-matrix', and `col-matrix' macros. * A lot of other little changes. Currently, `in-row' and `in-column' are broken. I intend to implement them in a way that makes them work in both untyped and Typed Racket.
377 lines
16 KiB
Racket
377 lines
16 KiB
Racket
#lang racket/base
|
|
|
|
(provide
|
|
;; Constructors
|
|
identity-matrix
|
|
make-matrix
|
|
build-matrix
|
|
diagonal-matrix/zero
|
|
diagonal-matrix
|
|
block-diagonal-matrix/zero
|
|
block-diagonal-matrix
|
|
vandermonde-matrix
|
|
;; Basic conversion
|
|
list->matrix
|
|
matrix->list
|
|
vector->matrix
|
|
matrix->vector
|
|
->row-matrix
|
|
->col-matrix
|
|
;; Nested conversion
|
|
list*->matrix
|
|
matrix->list*
|
|
vector*->matrix
|
|
matrix->vector*
|
|
;; Syntax
|
|
matrix
|
|
row-matrix
|
|
col-matrix)
|
|
|
|
(module typed-defs typed/racket/base
|
|
(require racket/fixnum
|
|
racket/list
|
|
racket/vector
|
|
math/array
|
|
"../array/utils.rkt"
|
|
"matrix-types.rkt"
|
|
"utils.rkt"
|
|
"../unsafe.rkt")
|
|
|
|
(provide (all-defined-out))
|
|
|
|
;; =================================================================================================
|
|
;; Constructors
|
|
|
|
;; identity-matrix : builds a matrix whose entries are drawn from {0, 1}.
;; Delegates to `diagonal-array' with 2 axes of length `m', passing 1 and 0 as
;; the two fill values (presumably on-diagonal and off-diagonal -- confirm
;; against `diagonal-array').
(: identity-matrix (Integer -> (Matrix (U 0 1))))
(define (identity-matrix m)
  (diagonal-array 2
                  m
                  1
                  0))
|
|
|
|
;; make-matrix : builds an array of shape #(m n) in which every entry is `x'.
;; No validation of `m' or `n' happens here; `make-array' receives them as given.
(: make-matrix (All (A) (Integer Integer A -> (Matrix A))))
(define (make-matrix m n x)
  (define shape (vector m n))
  (make-array shape x))
|
|
|
|
;; build-matrix : builds an m-by-n matrix whose (i, j) entry is (proc i j).
;; Raises an argument error naming position 0 when `m' is not a nonzero Index,
;; and position 1 when `n' is not a nonzero Index (checked in that order).
(: build-matrix (All (A) (Integer Integer (Index Index -> A) -> (Matrix A))))
(define (build-matrix m n proc)
  (if (or (not (index? m)) (= m 0))
      (raise-argument-error 'build-matrix "Positive-Index" 0 m n proc)
      (if (or (not (index? n)) (= n 0))
          (raise-argument-error 'build-matrix "Positive-Index" 1 m n proc)
          (unsafe-build-array
           ((inst vector Index) m n)
           ;; Unpack the two-element index vector into row and column
           (λ: ([ij : Indexes])
             (let ([row (unsafe-vector-ref ij 0)]
                   [col (unsafe-vector-ref ij 1)])
               (proc row col)))))))
|
|
|
|
;; diagonal-matrix/zero : turns a one-dimensional array `a' of length m into an
;; m-by-m array whose (i, i) entry is element i of `a' and whose other entries
;; are `zero'. Raises an argument error when `a' does not have exactly one axis.
(: diagonal-matrix/zero (All (A) (Array A) A -> (Array A)))
(define (diagonal-matrix/zero a zero)
  (define shape (array-shape a))
  (unless (= 1 (vector-length shape))
    (raise-argument-error 'diagonal-matrix "Array with one dimension" a))
  (define len (unsafe-vector-ref shape 0))
  (define get (unsafe-array-proc a))
  (unsafe-build-array
   ((inst vector Index) len len)
   (λ: ([ij : Indexes])
     (define row (unsafe-vector-ref ij 0))
     (if (= row (unsafe-vector-ref ij 1))
         ;; On the diagonal: look the element up with a fresh 1-element index
         (get ((inst vector Index) row))
         zero))))
|
|
|
|
;; diagonal-matrix : numeric convenience wrapper around `diagonal-matrix/zero'
;; that fills the off-diagonal entries with the literal 0.
(: diagonal-matrix (case-> ((Array Real) -> (Array Real))
                           ((Array Number) -> (Array Number))))
(define (diagonal-matrix a)
  (let ([off-diagonal 0])
    (diagonal-matrix/zero a off-diagonal)))
|
|
|
|
;; block-diagonal-matrix/zero* : places the matrices in the vector `as' along
;; the diagonal of one larger array; every entry outside the blocks is `zero'.
;;
;; Four lookup tables are precomputed, indexed by result row or result column:
;;   row-block[i] / col-block[j] : which block owns result row i / column j
;;   row-local[i] / col-local[j] : the row / column index inside that block
;; An entry (i, j) lies inside a block exactly when row-block[i] = col-block[j].
(: block-diagonal-matrix/zero* (All (A) (Vectorof (Array A)) A -> (Array A)))
(define (block-diagonal-matrix/zero* as zero)
  (define num (vector-length as))
  ;; Block shapes are consed up back to front, then flipped into block order
  (define-values (rev-ms rev-ns)
    (for/fold: ([acc-ms : (Listof Index) empty]
                [acc-ns : (Listof Index) empty])
               ([a (in-vector as)])
      (define-values (m n) (matrix-shape a))
      (values (cons m acc-ms) (cons n acc-ns))))
  (define ms (reverse rev-ms))
  (define ns (reverse rev-ns))
  ;; Result shape is the sum of the block heights by the sum of the block widths
  (define res-m (assert (apply + ms) index?))
  (define res-n (assert (apply + ns) index?))
  (define row-block ((inst make-vector Index) res-m 0))
  (define col-block ((inst make-vector Index) res-n 0))
  (define row-local ((inst make-vector Index) res-m 0))
  (define col-local ((inst make-vector Index) res-n 0))
  ;; Fill the tables block by block; the fold carries the running row/column
  ;; offsets, whose final values are not needed afterwards.
  (define-values (_res-i _res-j)
    (for/fold: ([res-i : Nonnegative-Fixnum 0]
                [res-j : Nonnegative-Fixnum 0])
               ([m (in-list ms)]
                [n (in-list ns)]
                [k : Nonnegative-Fixnum (in-range num)])
      (let ([blk (assert k index?)])
        (for: ([i : Nonnegative-Fixnum (in-range m)])
          (vector-set! row-block (unsafe-fx+ res-i i) blk)
          (vector-set! row-local (unsafe-fx+ res-i i) (assert i index?)))
        (for: ([j : Nonnegative-Fixnum (in-range n)])
          (vector-set! col-block (unsafe-fx+ res-j j) blk)
          (vector-set! col-local (unsafe-fx+ res-j j) (assert j index?))))
      (values (unsafe-fx+ res-i m) (unsafe-fx+ res-j n))))
  (define procs (vector-map (λ: ([a : (Array A)]) (unsafe-array-proc a)) as))
  (unsafe-build-array
   ((inst vector Index) res-m res-n)
   (λ: ([ij : Indexes])
     (define i (unsafe-vector-ref ij 0))
     (define j (unsafe-vector-ref ij 1))
     (define blk (unsafe-vector-ref row-block i))
     (if (fx= blk (vector-ref col-block j))
         (let ([block-proc (unsafe-vector-ref procs blk)]
               [local-i (unsafe-vector-ref row-local i)]
               [local-j (unsafe-vector-ref col-local j)])
           ;; Reuse `ij' as the block-local index, then put the original
           ;; coordinates back once the block's element has been fetched.
           (unsafe-vector-set! ij 0 local-i)
           (unsafe-vector-set! ij 1 local-j)
           (begin0 (block-proc ij)
             (unsafe-vector-set! ij 0 i)
             (unsafe-vector-set! ij 1 j)))
         zero))))
|
|
|
|
;; block-diagonal-matrix/zero : arranges the arrays in the list `as' along the
;; diagonal of one array, filling everything outside the blocks with `zero'.
;;   * empty list    -> argument error (reported against the list that was passed)
;;   * one array     -> that array is returned unchanged
;;   * several       -> delegated to `block-diagonal-matrix/zero*'
;; The list is inspected directly, so the error shows the caller's list rather
;; than a converted vector, and the vector is only built when it is needed.
(: block-diagonal-matrix/zero (All (A) (Listof (Array A)) A -> (Array A)))
(define (block-diagonal-matrix/zero as zero)
  (cond [(empty? as)
         (raise-argument-error 'block-diagonal-matrix/zero "nonempty List" as)]
        [(empty? (rest as))
         (first as)]
        [else
         (block-diagonal-matrix/zero* (list->vector as) zero)]))
|
|
|
|
;; block-diagonal-matrix : numeric convenience wrapper around
;; `block-diagonal-matrix/zero' that uses the literal 0 outside the blocks.
(: block-diagonal-matrix (case-> ((Listof (Array Real)) -> (Array Real))
                                 ((Listof (Array Number)) -> (Array Number))))
(define (block-diagonal-matrix as)
  (let ([fill 0])
    (block-diagonal-matrix/zero as fill)))
|
|
|
|
;; expt-hack : `expt' with a result type that follows the base: Real in, Real
;; out; Number in, Number out. For a real base the result is asserted `real?'.
(: expt-hack (case-> (Real Integer -> Real)
                     (Number Integer -> Number)))
;; Stop using this when TR correctly derives expt : Real Integer -> Real
(define (expt-hack x n)
  (if (real? x)
      (assert (expt x n) real?)
      (expt x n)))
|
|
|
|
;; vandermonde-matrix : expands the list `xs' along a new axis 1 of length `n',
;; combining each element with its position on that axis via `expt-hack'.
;; Raises an argument error (position 0) for an empty `xs', and (position 1)
;; when `n' is not a nonzero Index.
(: vandermonde-matrix (case-> ((Listof Real) Integer -> (Array Real))
                              ((Listof Number) Integer -> (Array Number))))
(define (vandermonde-matrix xs n)
  (if (empty? xs)
      (raise-argument-error 'vandermonde-matrix "nonempty List" 0 xs n)
      (if (or (not (index? n)) (zero? n))
          (raise-argument-error 'vandermonde-matrix "Positive-Index" 1 xs n)
          (array-axis-expand (list->array xs) 1 n expt-hack))))
|
|
|
|
;; =================================================================================================
|
|
;; Flat conversion
|
|
|
|
;; list->matrix : reshapes the flat list `xs' into an array of shape #(m n).
;; Raises an argument error naming position 0 / 1 when `m' / `n' is not a
;; nonzero Index; `m' is checked first.
(: list->matrix (All (A) (Integer Integer (Listof A) -> (Array A))))
(define (list->matrix m n xs)
  (if (or (not (index? m)) (= m 0))
      (raise-argument-error 'list->matrix "Positive-Index" 0 m n xs)
      (if (or (not (index? n)) (= n 0))
          (raise-argument-error 'list->matrix "Positive-Index" 1 m n xs)
          (list->array (vector m n) xs))))
|
|
|
|
;; matrix->list : passes `a' through `ensure-matrix' (tagged with this
;; function's name) and converts the result to a flat list.
(: matrix->list (All (A) ((Array A) -> (Listof A))))
(define (matrix->list a)
  (let ([checked (ensure-matrix 'matrix->list a)])
    (array->list checked)))
|
|
|
|
;; vector->matrix : reshapes the flat vector `v' into a mutable array of shape
;; #(m n). Raises an argument error naming position 0 / 1 when `m' / `n' is not
;; a nonzero Index; `m' is checked first.
(: vector->matrix (All (A) (Integer Integer (Vectorof A) -> (Mutable-Array A))))
(define (vector->matrix m n v)
  (if (or (not (index? m)) (= m 0))
      (raise-argument-error 'vector->matrix "Positive-Index" 0 m n v)
      (if (or (not (index? n)) (= n 0))
          (raise-argument-error 'vector->matrix "Positive-Index" 1 m n v)
          (vector->array (vector m n) v))))
|
|
|
|
;; matrix->vector : passes `a' through `ensure-matrix' (tagged with this
;; function's name) and converts the result to a flat vector.
(: matrix->vector (All (A) ((Array A) -> (Vectorof A))))
(define (matrix->vector a)
  (let ([checked (ensure-matrix 'matrix->vector a)])
    (array->vector checked)))
|
|
|
|
;; list->row-matrix : lays the elements of `xs' out as a 1-by-(length xs) array.
;; Raises an argument error when `xs' is empty.
(: list->row-matrix (All (A) ((Listof A) -> (Array A))))
(define (list->row-matrix xs)
  (if (empty? xs)
      (raise-argument-error 'list->row-matrix "nonempty List" xs)
      (let ([cols (length xs)])
        (list->array ((inst vector Index) 1 cols) xs))))
|
|
|
|
;; list->col-matrix : lays the elements of `xs' out as a (length xs)-by-1 array.
;; Raises an argument error when `xs' is empty.
(: list->col-matrix (All (A) ((Listof A) -> (Array A))))
(define (list->col-matrix xs)
  (if (empty? xs)
      (raise-argument-error 'list->col-matrix "nonempty List" xs)
      (let ([rows (length xs)])
        (list->array ((inst vector Index) rows 1) xs))))
|
|
|
|
;; vector->row-matrix : wraps the vector `xs' as a mutable 1-by-n array, where
;; n is the vector's length. Raises an argument error when `xs' is empty.
(: vector->row-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->row-matrix xs)
  (define cols (vector-length xs))
  (if (zero? cols)
      (raise-argument-error 'vector->row-matrix "nonempty Vector" xs)
      (vector->array ((inst vector Index) 1 cols) xs)))
|
|
|
|
;; vector->col-matrix : wraps the vector `xs' as a mutable n-by-1 array, where
;; n is the vector's length. Raises an argument error when `xs' is empty.
(: vector->col-matrix (All (A) ((Vectorof A) -> (Mutable-Array A))))
(define (vector->col-matrix xs)
  (define rows (vector-length xs))
  (if (zero? rows)
      (raise-argument-error 'vector->col-matrix "nonempty Vector" xs)
      (vector->array ((inst vector Index) rows 1) xs)))
|
|
|
|
;; find-nontrivial-axis : scans the shape vector `ds' from axis 0 upward and
;; returns (values k dk) for the first axis k whose length dk exceeds 1.
;; Returns (values 0 0) when no axis is longer than 1.
(: find-nontrivial-axis ((Vectorof Index) -> (Values Index Index)))
(define (find-nontrivial-axis ds)
  (define axis-count (vector-length ds))
  (let: scan : (Values Index Index) ([axis : Nonnegative-Fixnum 0])
    (if (< axis axis-count)
        (let ([len (unsafe-vector-ref ds axis)])
          (if (> len 1)
              (values axis len)
              (scan (fx+ axis 1))))
        (values 0 0))))
|
|
|
|
;; array->row-matrix : views `arr' as a 1-by-n array.
;;   * zero-size array                    -> argument error
;;   * already satisfies `row-matrix?'    -> returned unchanged
;;   * every axis has length 1            -> 1-by-1 array holding the single element
;;   * all axes but one have length 1     -> 1-by-n array running along that axis
;;   * anything else                      -> argument error
(: array->row-matrix (All (A) ((Array A) -> (Array A))))
(define (array->row-matrix arr)
  (define (fail)
    (raise-argument-error 'array->row-matrix "nonempty Array with one axis of length >= 1" arr))
  (define shape (array-shape arr))
  (define rank (vector-length shape))
  ;; Number of axes whose length is exactly 1
  (define unit-axes (vector-count (λ: ([d : Index]) (= d 1)) shape))
  (cond
    [(zero? (array-size arr)) (fail)]
    [(row-matrix? arr) arr]
    [(= unit-axes rank)
     ;; Single element: always read it at the all-zeros index
     (define: origin : (Vectorof Index) (make-vector rank 0))
     (define get (unsafe-array-proc arr))
     (unsafe-build-array ((inst vector Index) 1 1)
                         (λ: ([ij : Indexes]) (get origin)))]
    [(= unit-axes (- rank 1))
     (define-values (axis len) (find-nontrivial-axis shape))
     (define get-src-index (make-thread-local-indexes rank))
     (define get (unsafe-array-proc arr))
     (unsafe-build-array ((inst vector Index) 1 len)
                         (λ: ([ij : Indexes])
                           ;; Copy the column index into the one long axis
                           (let ([src (get-src-index)])
                             (unsafe-vector-set! src axis (unsafe-vector-ref ij 1))
                             (get src))))]
    [else (fail)]))
|
|
|
|
;; array->col-matrix : views `arr' as an m-by-1 array.
;;   * zero-size array                    -> argument error
;;   * already satisfies `col-matrix?'    -> returned unchanged
;;   * every axis has length 1            -> 1-by-1 array holding the single element
;;   * all axes but one have length 1     -> m-by-1 array running along that axis
;;   * anything else                      -> argument error
(: array->col-matrix (All (A) ((Array A) -> (Array A))))
(define (array->col-matrix arr)
  (define (fail)
    (raise-argument-error 'array->col-matrix "nonempty Array with one axis of length >= 1" arr))
  (define shape (array-shape arr))
  (define rank (vector-length shape))
  ;; Number of axes whose length is exactly 1
  (define unit-axes (vector-count (λ: ([d : Index]) (= d 1)) shape))
  (cond
    [(zero? (array-size arr)) (fail)]
    [(col-matrix? arr) arr]
    [(= unit-axes rank)
     ;; Single element: always read it at the all-zeros index
     (define: origin : (Vectorof Index) (make-vector rank 0))
     (define get (unsafe-array-proc arr))
     (unsafe-build-array ((inst vector Index) 1 1)
                         (λ: ([ij : Indexes]) (get origin)))]
    [(= unit-axes (- rank 1))
     (define-values (axis len) (find-nontrivial-axis shape))
     (define get-src-index (make-thread-local-indexes rank))
     (define get (unsafe-array-proc arr))
     (unsafe-build-array ((inst vector Index) len 1)
                         (λ: ([ij : Indexes])
                           ;; Copy the row index into the one long axis
                           (let ([src (get-src-index)])
                             (unsafe-vector-set! src axis (unsafe-vector-ref ij 0))
                             (get src))))]
    [else (fail)]))
|
|
|
|
;; ->row-matrix : dispatches on the kind of `xs': lists go to `list->row-matrix',
;; arrays to `array->row-matrix', and everything else (vectors, per the type)
;; to `vector->row-matrix'.
(: ->row-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(define (->row-matrix xs)
  (if (list? xs)
      (list->row-matrix xs)
      (if (array? xs)
          (array->row-matrix xs)
          (vector->row-matrix xs))))
|
|
|
|
;; ->col-matrix : dispatches on the kind of `xs': lists go to `list->col-matrix',
;; arrays to `array->col-matrix', and everything else (vectors, per the type)
;; to `vector->col-matrix'.
(: ->col-matrix (All (A) ((U (Listof A) (Vectorof A) (Array A)) -> (Array A))))
(define (->col-matrix xs)
  (if (list? xs)
      (list->col-matrix xs)
      (if (array? xs)
          (array->col-matrix xs)
          (vector->col-matrix xs))))
|
|
|
|
;; =================================================================================================
|
|
;; Nested conversion
|
|
|
|
;; list*-shape : returns (values rows cols) for the nested list `xss' when it
;; has at least one row, the first row is nonempty, and every other row has the
;; same length as the first. Otherwise calls the thunk `fail'.
(: list*-shape (All (A) (Listof (Listof A)) (-> Nothing) -> (Values Positive-Index Positive-Index)))
(define (list*-shape xss fail)
  (define rows (length xss))
  (if (> rows 0)
      (let ([cols (length (first xss))])
        (if (and (> cols 0)
                 (andmap (λ: ([row : (Listof A)]) (= cols (length row)))
                         (rest xss)))
            (values rows cols)
            (fail)))
      (fail)))
|
|
|
|
;; vector*-shape : returns (values m n) for the nested vector `xss' when it has
;; at least one row, row 0 is nonempty, and every other row has the same length
;; as row 0. Otherwise calls the thunk `fail'.
;; The row lengths are compared in place; no intermediate vector of lengths is
;; allocated.
(: vector*-shape (All (A) (Vectorof (Vectorof A)) (-> Nothing)
                          -> (Values Positive-Index Positive-Index)))
(define (vector*-shape xss fail)
  (define m (vector-length xss))
  (cond [(m . > . 0)
         (define n (vector-length (unsafe-vector-ref xss 0)))
         (cond [(and (n . > . 0)
                     ;; Rows 1 .. m-1 must all match the length of row 0;
                     ;; stop at the first mismatch.
                     (let: loop : Boolean ([i : Nonnegative-Fixnum 1])
                       (cond [(i . fx< . m)
                              (if (= n (vector-length (unsafe-vector-ref xss i)))
                                  (loop (fx+ i 1))
                                  #f)]
                             [else #t])))
                (values m n)]
               [else (fail)])]
        [else (fail)]))
|
|
|
|
;; list*->matrix : converts a list of equal-length, nonempty rows into a matrix.
;; The shape comes from `list*-shape'; the rows are concatenated in order and
;; handed to `list->array'. Raises an argument error when the shape check fails.
(: list*->matrix (All (A) (Listof (Listof A)) -> (Matrix A)))
(define (list*->matrix xss)
  (define (fail)
    (raise-argument-error
     'list*->matrix
     "nested lists with rectangular shape and at least one matrix element"
     xss))
  (define-values (rows cols) (list*-shape xss fail))
  (define shape ((inst vector Index) rows cols))
  (list->array shape (apply append xss)))
|
|
|
|
;; matrix->list* : converts a matrix to a list of lists by first collapsing
;; axis 1 into lists with `array->list-array', then listing the result.
;; Raises an argument error when `a' fails `matrix?'.
(: matrix->list* (All (A) (Matrix A) -> (Listof (Listof A))))
(define (matrix->list* a)
  (if (matrix? a)
      (let ([row-lists (array->list-array a 1)])
        (array->list row-lists))
      (raise-argument-error 'matrix->list* "matrix?" a)))
|
|
|
|
;; vector*->matrix : converts a vector of equal-length, nonempty row vectors
;; into a mutable matrix. The shape comes from `vector*-shape'; the rows are
;; appended in order and handed to `vector->matrix'. Raises an argument error
;; when the shape check fails.
(: vector*->matrix (All (A) (Vectorof (Vectorof A)) -> (Mutable-Array A)))
(define (vector*->matrix xss)
  (define (fail)
    (raise-argument-error
     'vector*->matrix
     "nested vectors with rectangular shape and at least one matrix element"
     xss))
  (define-values (rows cols) (vector*-shape xss fail))
  (vector->matrix rows
                  cols
                  (apply vector-append (vector->list xss))))
|
|
|
|
;; matrix->vector* : converts a matrix to a vector of vectors by reducing
;; axis 1 with `build-vector', then converting the result to a vector.
;; Raises an argument error when `a' fails `matrix?'.
(: matrix->vector* : (All (A) (Matrix A) -> (Vectorof (Vectorof A))))
(define (matrix->vector* a)
  (if (matrix? a)
      (array->vector
       ((inst array-axis-reduce A (Vectorof A)) a 1 build-vector))
      (raise-argument-error 'matrix->vector* "matrix?" a)))
|
|
) ; module
|
|
|
|
(require (for-syntax racket/base
|
|
syntax/parse)
|
|
(only-in typed/racket/base :)
|
|
math/array
|
|
(submod "." typed-defs))
|
|
|
|
;; (matrix [[x ...] ...])  /  (matrix [[x ...] ...] : T)
;; Expands to an `array' form over a vector of row vectors, optionally with a
;; `: T' annotation. At least one row is required and no row may be empty;
;; an empty row or an empty matrix is reported as a syntax error.
(define-syntax (matrix stx)
  (syntax-parse stx #:literals (:)
    ;; One or more rows, each with at least one element
    [(_ [[elt more ...] ...+])
     (syntax/loc stx (array #[#[elt more ...] ...]))]
    [(_ [[elt more ...] ...+] : T)
     (syntax/loc stx (array #[#[elt more ...] ...] : T))]
    ;; Some row is empty: point the error at that row
    [(_ [before ... (~and [] empty-row) after ...] (~optional (~seq : T)))
     (raise-syntax-error 'matrix "given empty row" stx #'empty-row)]
    ;; No rows at all
    [(_ (~and [] empty-body) (~optional (~seq : T)))
     (raise-syntax-error 'matrix "given empty matrix" stx #'empty-body)]))
|
|
|
|
;; (row-matrix [x ...])  /  (row-matrix [x ...] : T)
;; Expands to an `array' form over a single row vector holding the given
;; elements, optionally with a `: T' annotation. An empty row is a syntax error.
(define-syntax (row-matrix stx)
  (syntax-parse stx #:literals (:)
    [(_ [elt ...+])
     (syntax/loc stx (array #[#[elt ...]]))]
    [(_ [elt ...+] : T)
     (syntax/loc stx (array #[#[elt ...]] : T))]
    [(_ (~and [] empty-row) (~optional (~seq : T)))
     (raise-syntax-error 'row-matrix "given empty row" stx #'empty-row)]))
|
|
|
|
;; (col-matrix [x ...])  /  (col-matrix [x ...] : T)
;; Expands to an `array' form over a vector of one-element row vectors, one per
;; given element, optionally with a `: T' annotation.
;; An empty column is a syntax error, reported under the name `col-matrix'.
(define-syntax (col-matrix stx)
  (syntax-parse stx #:literals (:)
    [(_ [x xs ...]) (syntax/loc stx (array #[#[x] #[xs] ...]))]
    [(_ [x xs ...] : T) (syntax/loc stx (array #[#[x] #[xs] ...] : T))]
    [(_ (~and [] c) (~optional (~seq : T)))
     (raise-syntax-error 'col-matrix "given empty column" stx #'c)]))
|