racket/collects/math/private/matrix/matrix-comprehension.rkt
Neil Toronto 8d5a069d41 Moar `math/matrix' review/refactoring
* Split "matrix-constructors.rkt" into three parts:
 * "matrix-constructors.rkt"
 * "matrix-conversion.rkt"
 * "matrix-syntax.rkt"

* Made `matrix-map' automatically inline (it's dirt simple)

* Renamed a few things, changed some type signatures

* Fixed error in `matrix-dot' caught by testing (it was broadcasting)

* Rewrote matrix comprehensions in terms of array comprehensions

* Removed `in-column' and `in-row' (can use `in-array', `matrix-col' and
  `matrix-row')

* Tons of new rackunit tests: only "matrix-2d.rkt" and
  "matrix-operations.rkt" are left (though the latter is large)
2012-12-20 17:32:16 -07:00

64 lines
2.4 KiB
Racket

#lang racket/base
(require (for-syntax racket/base
syntax/parse)
math/array)
(provide for/matrix:
for*/matrix:
for/matrix
for*/matrix)
(module typed-defs typed/racket/base
  (require (for-syntax racket/base
                       syntax/parse)
           math/array)

  (provide (all-defined-out))

  ;; ensure-matrix-dims : validates a pair of matrix dimensions.
  ;;   name - symbol reported as the complaining form if validation fails
  ;;   m, n - candidate row and column counts (any value)
  ;; Returns `m` and `n` as two values when each satisfies `index?` and is
  ;; nonzero.  Otherwise raises an argument error via `raise-argument-error`,
  ;; naming position 0 for a bad `m` and position 1 for a bad `n`; `m` is
  ;; checked first.
  (: ensure-matrix-dims (Symbol Any Any -> (Values Positive-Index Positive-Index)))
  (define (ensure-matrix-dims name m n)
    (cond [(or (not (index? m)) (zero? m)) (raise-argument-error name "Positive-Index" 0 m n)]
          [(or (not (index? n)) (zero? n)) (raise-argument-error name "Positive-Index" 1 m n)]
          [else (values m n)]))

  ;; base-for/matrix: : shared expander behind `for/matrix:` and `for*/matrix:`.
  ;;   name      - identifier quoted and passed to `ensure-matrix-dims` for
  ;;               error reporting
  ;;   for/array - the array comprehension form that does the actual looping
  ;;   m-expr, n-expr - dimension expressions, ascribed type Integer
  ;;   #:fill fill-expr - optional, forwarded to the array comprehension
  ;;   (clause ...)     - iteration clauses, forwarded unchanged
  ;;   : A              - optional body type annotation, forwarded unchanged
  ;;   body ...+        - loop body, forwarded unchanged
  (define-syntax (base-for/matrix: stx)
    (syntax-parse stx #:literals (:)
      [(_ name:id for/array:id
          m-expr:expr n-expr:expr
          (~optional (~seq #:fill fill-expr:expr))
          (clause ...)
          (~optional (~seq : A:expr))
          body:expr ...+)
       (with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())]
                     [(maybe-type ...) (if (attribute A) #'(: A) #'())])
         (syntax/loc stx
           ;; Evaluate each dimension expression exactly once, validate it,
           ;; and build the shape from the validated values.  Splicing
           ;; `m-expr`/`n-expr` into the shape a second time would re-run
           ;; any side effects they have and bypass the checked results.
           (let-values ([(m n) (ensure-matrix-dims 'name
                                                   (ann m-expr Integer)
                                                   (ann n-expr Integer))])
             (for/array #:shape (vector m n) maybe-fill ... (clause ...) maybe-type ...
               body ...))))]))

  ;; for/matrix: delegates to `base-for/matrix:` with `for/array:` as the
  ;; underlying comprehension.
  (define-syntax-rule (for/matrix: e ...) (base-for/matrix: for/matrix: for/array: e ...))

  ;; for*/matrix: delegates to `base-for/matrix:` with `for*/array:` as the
  ;; underlying comprehension.
  (define-syntax-rule (for*/matrix: e ...) (base-for/matrix: for*/matrix: for*/array: e ...))
  )
(require (submod "." typed-defs))
;; base-for/matrix : shared expander behind `for/matrix` and `for*/matrix`.
;;   name      - identifier quoted and passed to `ensure-matrix-dims` for
;;               error reporting
;;   for/array - the array comprehension form that does the actual looping
;;   m-expr, n-expr - dimension expressions
;;   #:fill fill-expr - optional, forwarded to the array comprehension
;;   (clause ...)     - iteration clauses, forwarded unchanged
;;   body ...+        - loop body, forwarded unchanged
(define-syntax (base-for/matrix stx)
  (syntax-parse stx
    [(_ name:id for/array:id
        m-expr:expr n-expr:expr
        (~optional (~seq #:fill fill-expr:expr))
        (clause ...)
        body:expr ...+)
     (with-syntax ([(maybe-fill ...) (if (attribute fill-expr) #'(#:fill fill-expr) #'())])
       (syntax/loc stx
         ;; Evaluate each dimension expression exactly once, validate it, and
         ;; build the shape from the validated values.  Splicing
         ;; `m-expr`/`n-expr` into the shape a second time would re-run any
         ;; side effects they have and bypass the checked results.
         (let-values ([(m n) (ensure-matrix-dims 'name m-expr n-expr)])
           (for/array #:shape (vector m n) maybe-fill ... (clause ...)
             body ...))))]))
;; for/matrix : forwards all of its subforms to `base-for/matrix`, supplying
;; `for/matrix` as the name used in dimension-error reports and `for/array`
;; as the underlying comprehension form.
(define-syntax-rule (for/matrix e ...) (base-for/matrix for/matrix for/array e ...))
;; for*/matrix : forwards all of its subforms to `base-for/matrix`, supplying
;; `for*/matrix` as the name used in dimension-error reports and `for*/array`
;; as the underlying comprehension form.
(define-syntax-rule (for*/matrix e ...) (base-for/matrix for*/matrix for*/array e ...))