261 lines
9.0 KiB
Racket
261 lines
9.0 KiB
Racket
#lang racket
|
|
(provide for/matrix
|
|
for*/matrix
|
|
in-row
|
|
in-column)
|
|
|
|
(require math/array
|
|
(except-in math/matrix in-row in-column))
|
|
|
|
;;; COMPREHENSIONS
|
|
|
|
; (for/matrix m n (clause ...) . defs+exprs)
|
|
; Return an m x n matrix with elements from the last expr.
|
|
; The first n values produced becomes the first row.
|
|
; The next n values becomes the second row and so on.
|
|
; The bindings in clauses run in parallel.
|
|
(define-syntax (for/matrix stx)
|
|
(syntax-case stx ()
|
|
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
|
(syntax/loc stx
|
|
(let ([m m-expr] [n n-expr])
|
|
(define flat-vector
|
|
(for/vector #:length (* m n)
|
|
(clause ...) . defs+exprs))
|
|
; TODO (efficiency): Use a flat-vector->array instead
|
|
(flat-vector->matrix m n flat-vector)))]))
|
|
|
|
; (for*/matrix m n (clause ...) . defs+exprs)
|
|
; Return an m x n matrix with elements from the last expr.
|
|
; The first n values produced becomes the first row.
|
|
; The next n values becomes the second row and so on.
|
|
; The bindings in clauses run nested.
|
|
; (for*/matrix m n #:column (clause ...) . defs+exprs)
|
|
; Return an m x n matrix with elements from the last expr.
|
|
; The first m values produced becomes the first column.
|
|
; The next m values becomes the second column and so on.
|
|
; The bindings in clauses run nested.
|
|
(define-syntax (for*/matrix stx)
|
|
(syntax-case stx ()
|
|
[(_ m-expr n-expr #:column (clause ...) . defs+exprs)
|
|
(syntax/loc stx
|
|
(let* ([m m-expr]
|
|
[n n-expr]
|
|
[v (make-vector (* m n) 0)]
|
|
[w (for*/vector #:length (* m n) (clause ...) . defs+exprs)])
|
|
(for* ([i (in-range m)] [j (in-range n)])
|
|
(vector-set! v (+ (* i n) j)
|
|
(vector-ref w (+ (* j m) i))))
|
|
(flat-vector->matrix m n v)))]
|
|
[(_ m-expr n-expr (clause ...) . defs+exprs)
|
|
(syntax/loc stx
|
|
(let ([m m-expr] [n n-expr])
|
|
(flat-vector->matrix
|
|
m n (for*/vector #:length (* m n) (clause ...) . defs+exprs))))]))
|
|
|
|
; TODO: The following is uncommented until matrix+ can be imported.
|
|
|
|
; (for/matrix-sum (clause ...) . defs+exprs)
|
|
; Return the matrix sum of all matrices produced by the last expr.
|
|
; The bindings in clauses are parallel.
|
|
|
|
;(define-syntax (for/matrix-sum stx)
|
|
; (syntax-case stx ()
|
|
; [(_ (clause ...) . defs+exprs)
|
|
; (syntax/loc stx
|
|
; (let ([ms (for/list (clause ...) . defs+exprs)])
|
|
; (foldl matrix+ (first ms) (rest ms))))]))
|
|
;
|
|
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
|
; (for/matrix-sum ([i 3]) M))
|
|
; (flat-vector->matrix 2 2 #(3 6 9 12)))
|
|
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
|
; (for/matrix-sum ([i 2] [j 2]) M))
|
|
; (flat-vector->matrix 2 2 #(2 4 6 8)))
|
|
|
|
; (for*/matrix-sum (clause ...) . defs+exprs)
|
|
; Return the matrix sum of all matrices produced by the last expr.
|
|
; The bindings in clauses are in nested.
|
|
|
|
;(define-syntax (for*/matrix-sum stx)
|
|
; (syntax-case stx ()
|
|
; [(_ (clause ...) . defs+exprs)
|
|
; (syntax/loc stx
|
|
; (let ([ms (for*/list (clause ...) . defs+exprs)])
|
|
; (foldl matrix+ (first ms) (rest ms))))]))
|
|
;
|
|
;(equal? (let ([M (flat-vector->matrix 2 2 #(1 2 3 4))])
|
|
; (for*/matrix-sum ([i 2] [j 2]) M))
|
|
; (flat-vector->matrix 2 2 #(4 8 12 16)))
|
|
|
|
|
|
;;;
|
|
;;; SEQUENCES
|
|
;;;
|
|
|
|
(define (in-row/proc M r)
|
|
(define-values (m n) (matrix-dimensions M))
|
|
(make-do-sequence
|
|
(λ ()
|
|
(values
|
|
; pos->element
|
|
(λ (j) (matrix-ref M r j))
|
|
; next-pos
|
|
(λ (j) (+ j 1))
|
|
; initial-pos
|
|
0
|
|
; continue-with-pos?
|
|
(λ (j) (< j n))
|
|
#f #f ))))
|
|
|
|
; (in-row M i]
|
|
; Returns a sequence of all elements of row i,
|
|
; that is xi0, xi1, xi2, ...
|
|
(define-sequence-syntax in-row
|
|
(λ () #'in-row/proc)
|
|
(λ (stx)
|
|
(syntax-case stx ()
|
|
[[(x) (_ M-expr r-expr)]
|
|
#'((x)
|
|
(:do-in
|
|
([(M r n d)
|
|
(let ([M1 M-expr])
|
|
(define-values (rd cd) (matrix-dimensions M1))
|
|
(values M1 r-expr rd
|
|
(mutable-array-data
|
|
(array->mutable-array M1))))])
|
|
(begin
|
|
(unless (array-matrix? M)
|
|
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
(unless (integer? r)
|
|
(raise-type-error 'in-row "expected row number, got ~a" r))
|
|
(unless (and (integer? r) (and (<= 0 r ) (< r n)))
|
|
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
([j 0])
|
|
(< j n)
|
|
([(x) (vector-ref d (+ (* r n) j))])
|
|
#true
|
|
#true
|
|
[(+ j 1)]))]
|
|
[[(i x) (_ M-expr r-expr)]
|
|
#'((i x)
|
|
(:do-in
|
|
([(M r n d)
|
|
(let ([M1 M-expr])
|
|
(define-values (rd cd) (matrix-dimensions M1))
|
|
(values M1 r-expr rd
|
|
(mutable-array-data
|
|
(array->mutable-array M1))))])
|
|
(begin
|
|
(unless (array-matrix? M)
|
|
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
(unless (integer? r)
|
|
(raise-type-error 'in-row "expected row number, got ~a" r)))
|
|
([j 0])
|
|
(< j n)
|
|
([(x) (vector-ref d (+ (* r n) j))]
|
|
[(i) j])
|
|
#true
|
|
#true
|
|
[(+ j 1)]))]
|
|
[[_ clause] (raise-syntax-error
|
|
'in-row "expected (in-row <matrix> <row>)" #'clause #'clause)])))
|
|
|
|
; (in-column M j]
|
|
; Returns a sequence of all elements of column j,
|
|
; that is x0j, x1j, x2j, ...
|
|
|
|
|
|
(define (in-column/proc M s)
|
|
(define-values (m n) (matrix-dimensions M))
|
|
(make-do-sequence
|
|
(λ ()
|
|
(values
|
|
; pos->element
|
|
(λ (i) (matrix-ref M i s))
|
|
; next-pos
|
|
(λ (i) (+ i 1))
|
|
; initial-pos
|
|
0
|
|
; continue-with-pos?
|
|
(λ (i) (< i m))
|
|
#f #f ))))
|
|
|
|
(define-sequence-syntax in-column
|
|
(λ () #'in-column/proc)
|
|
(λ (stx)
|
|
(syntax-case stx ()
|
|
[[(x) (_ M-expr s-expr)]
|
|
#'((x)
|
|
(:do-in
|
|
([(M s n m d)
|
|
(let ([M1 M-expr])
|
|
(define-values (rd cd) (matrix-dimensions M1))
|
|
(values M1 s-expr rd cd
|
|
(mutable-array-data
|
|
(array->mutable-array M1))))])
|
|
(begin
|
|
(unless (array-matrix? M)
|
|
(raise-type-error 'in-row "expected matrix, got ~a" M))
|
|
(unless (integer? s)
|
|
(raise-type-error 'in-row "expected col number, got ~a" s))
|
|
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
(raise-type-error 'in-col "expected col number, got ~a" s)))
|
|
([j 0])
|
|
(< j m)
|
|
([(x) (vector-ref d (+ (* j n) s))])
|
|
#true
|
|
#true
|
|
[(+ j 1)]))]
|
|
[[(i x) (_ M-expr s-expr)]
|
|
#'((x)
|
|
(:do-in
|
|
([(M s n m d)
|
|
(let ([M1 M-expr])
|
|
(define-values (rd cd) (matrix-dimensions M1))
|
|
(values M1 s-expr rd cd
|
|
(mutable-array-data
|
|
(array->mutable-array M1))))])
|
|
(begin
|
|
(unless (array-matrix? M)
|
|
(raise-type-error 'in-column "expected matrix, got ~a" M))
|
|
(unless (integer? s)
|
|
(raise-type-error 'in-column "expected col number, got ~a" s))
|
|
(unless (and (integer? s) (and (<= 0 s ) (< s m)))
|
|
(raise-type-error 'in-column "expected col number, got ~a" s)))
|
|
([j 0])
|
|
(< j m)
|
|
([(x) (vector-ref d (+ (* j n) s))]
|
|
[(i) j])
|
|
#true
|
|
#true
|
|
[(+ j 1)]))]
|
|
[[_ clause] (raise-syntax-error
|
|
'in-column "expected (in-column <matrix> <column>)" #'clause #'clause)])))
|
|
|
|
(module* test #f
|
|
(require (except-in math/matrix in-row in-column)
|
|
rackunit)
|
|
; "matrix-sequences.rkt"
|
|
; These work in racket not in typed racket
|
|
(check-equal? (matrix->list (for*/matrix 2 3 ([i 2] [j 3]) (+ i j)))
|
|
'[[0 1 2] [1 2 3]])
|
|
(check-equal? (matrix->list (for*/matrix 2 3 #:column ([i 2] [j 3]) (+ i j)))
|
|
'[[0 2 2] [1 1 3]])
|
|
(check-equal? (matrix->list (for*/matrix 2 2 #:column ([i 4]) i))
|
|
'[[0 2] [1 3]])
|
|
(check-equal? (matrix->list (for/matrix 2 2 ([i 4]) i))
|
|
'[[0 1] [2 3]])
|
|
(check-equal? (matrix->list (for/matrix 2 3 ([i 6] [j (in-range 6 12)]) (+ i j)))
|
|
'[[6 8 10] [12 14 16]])
|
|
(check-equal? (for/list ([x (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
|
'(3 4))
|
|
(check-equal? (for/list ([(i x) (in-row (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
|
|
(list i x))
|
|
'((0 3) (1 4)))
|
|
(check-equal? (for/list ([x (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)]) x)
|
|
'(2 4))
|
|
(check-equal? (for/list ([(i x) (in-column (flat-vector->matrix 2 2 #(1 2 3 4)) 1)])
|
|
(list i x))
|
|
'((0 2) (1 4))))
|