racket/collects/math/private/array/array-broadcast.rkt
Neil Toronto 986e695bd5 Made arrays strict by default; please merge to release
* Added parameter `array-strictness', default #t

* Added `array-default-strict!' and `array-default-strict', which act
  like the functions without "default" in the name when
  `array-strictness' is #t; otherwise they do nothing

* Lots of small changes to existing array functions, mostly to ensure
  computations are done using nonstrict arrays, but return values are
  strict when `array-strictness' is #t

* Added strictness tests

* Added tests to ensure untyped code can use `math/array'

* Rewrote the `array-map' that is exported to untyped code, using untyped Racket

* Rearranged a lot of `math/array' documentation
2013-01-15 13:53:28 -07:00

109 lines
4.5 KiB
Racket

#lang typed/racket
(require racket/fixnum
"array-struct.rkt"
"../unsafe.rkt"
"utils.rkt")
(provide array-broadcasting
array-broadcast
array-shape-broadcast)
;; Parameter selecting the shape-broadcasting rule used by `array-shape-broadcast`
;; (when no explicit mode is passed). Default: #t.
;;   #f           -- shapes must be equal; any mismatch calls the failure thunk
;;   #t           -- normal broadcasting: an axis of length 1 stretches to a positive length
;;   'permissive  -- wraparound broadcasting: each axis takes the larger nonzero length
(define: array-broadcasting : (Parameterof (U #f #t 'permissive)) (make-parameter #t))
;; Returns a lazily-computed array of shape `new-ds` built from `arr`.
;; The axes of `arr` are aligned with the trailing axes of `new-ds`, so `new-ds` may add
;; leading axes. Along each aligned axis, a new index is wrapped into range with
;; `modulo` the old axis length, which repeats `arr`'s rows as needed.
;; Raises an 'array-broadcast error when:
;;   * `new-ds` has fewer axes than `arr`, or
;;   * `arr` is empty (some axis has length 0) but `new-ds` is not. There is no element
;;     to repeat, and the wrapping `unsafe-fxmodulo` would divide by that 0.
(: shift-stretch-axes (All (A) ((Array A) Indexes -> (Array A))))
(define (shift-stretch-axes arr new-ds)
  (define old-ds (array-shape arr))
  (define old-dims (vector-length old-ds))
  (define new-dims (vector-length new-ds))
  ;; Number of leading axes added by `new-ds`; must be nonnegative.
  (define shift
    (let ([shift (- new-dims old-dims)])
      (cond [(index? shift) shift]
            [else (error 'array-broadcast
                         "cannot broadcast to a lower-dimensional shape; given ~e and ~e"
                         arr new-ds)])))
  ;; An empty `arr` has an axis of length 0. If every axis of `new-ds` is positive, the
  ;; element procedure below can be called, and `unsafe-fxmodulo` would receive that 0
  ;; as its divisor (undefined behavior for an unsafe op). Reject it here instead.
  ;; If `new-ds` is also empty, no element is ever requested, so that case is allowed.
  (when (and (zero? (array-size arr))
             (let: all-positive? : Boolean ([k : Nonnegative-Fixnum 0])
               (cond [(k . < . new-dims)
                      (and (positive? (unsafe-vector-ref new-ds k))
                           (all-positive? (+ k 1)))]
                     [else #t])))
    (error 'array-broadcast
           "cannot broadcast an empty array to a nonempty shape; given ~e and ~e"
           arr new-ds))
  ;; `old-js` is a thunk returning a mutable index vector of length `old-dims`, which is
  ;; overwritten on each element access. Presumably it is one vector per thread, going
  ;; by the helper's name (utils.rkt); verify there.
  (define old-js (make-thread-local-indexes old-dims))
  (define old-f (unsafe-array-proc arr))
  (unsafe-build-array
   new-ds
   (λ: ([new-js : Indexes])
     (let ([old-js (old-js)])
       ;; Map each aligned new index back into `arr` by wrapping it modulo the old length.
       (let: loop : A ([k : Nonnegative-Fixnum 0])
         (cond [(k . < . old-dims)
                (define new-jk (unsafe-vector-ref new-js (+ k shift)))
                (define old-dk (unsafe-vector-ref old-ds k))
                (define old-jk (unsafe-fxmodulo new-jk old-dk))
                (unsafe-vector-set! old-js k old-jk)
                (loop (+ k 1))]
               [else (old-f old-js)]))))))
;; Broadcasts `arr` to shape `ds`.
;; If `ds` equals `arr`'s shape, `arr` itself is returned. Otherwise the result is the
;; stretched view from `shift-stretch-axes`. That view is returned directly when `arr` is
;; already strict, or when the view has no more elements than `arr`. In every other case
;; the view is passed through `array-default-strict`.
;; NOTE(review): presumably this avoids recomputing a nonstrict `arr`'s elements
;; repeatedly once they are repeated by the broadcast. `array-default-strict` is defined
;; in array-struct.rkt; confirm its exact strictness behavior there.
(: array-broadcast (All (A) ((Array A) Indexes -> (Array A))))
(define (array-broadcast arr ds)
  (if (equal? ds (array-shape arr))
      arr
      (let ([stretched (shift-stretch-axes arr ds)])
        (cond [(array-strict? arr) stretched]
              [((array-size stretched) . fx<= . (array-size arr)) stretched]
              [else (array-default-strict stretched)]))))
;; Returns a new shape: `n` axes of length 1 followed by the axes of `ds`.
;; Used to give a lower-rank shape the same number of axes as a higher-rank one.
(: shape-insert-axes (Indexes Fixnum -> Indexes))
(define (shape-insert-axes ds n)
  (define: leading-ones : Indexes ((inst make-vector Index) n 1))
  (vector-append leading-ones ds))
;; Permissive broadcast of two shapes `ds1` and `ds2` that both have `dims` axes.
;; Equal axis lengths (including 0 and 0) are kept as they are. Otherwise both lengths
;; must be nonzero, and the result axis takes the larger one; the shorter axis is later
;; repeated by index wraparound when arrays are stretched.
;; Calls `fail` (which never returns) when exactly one of the two lengths is 0, because
;; an empty axis has no rows to repeat.
;; Every pair of shapes accepted by `shape-normal-broadcast` is also accepted here.
(: shape-permissive-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-permissive-broadcast ds1 ds2 dims fail)
  (define: new-ds : Indexes (make-vector dims 0))
  (let loop ([#{k : Nonnegative-Fixnum} 0])
    (cond [(k . < . dims)
           (define dk1 (unsafe-vector-ref ds1 k))
           (define dk2 (unsafe-vector-ref ds2 k))
           (unsafe-vector-set!
            new-ds k
            (cond [(= dk1 dk2) dk1]
                  [(or (= dk1 0) (= dk2 0)) (fail)]
                  [else (fxmax dk1 dk2)]))
           (loop (+ k 1))]
          [else new-ds])))
;; Normal broadcast of two shapes `ds1` and `ds2` that both have `dims` axes.
;; Each result axis is:
;;   * the common length, when the two lengths are equal (this includes 0 and 0);
;;   * the other length, when one length is 1 and the other is positive.
;; Any other combination calls `fail`, which never returns.
(: shape-normal-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-normal-broadcast ds1 ds2 dims fail)
  ;; Combines one pair of axis lengths according to the rules above.
  (: merge-axis (Index Index -> Index))
  (define (merge-axis d1 d2)
    (cond [(= d1 d2) d1]
          [(and (= d1 1) (d2 . > . 0)) d2]
          [(and (= d2 1) (d1 . > . 0)) d1]
          [else (fail)]))
  (define: result : Indexes (make-vector dims 0))
  (let: fill! : Indexes ([i : Nonnegative-Fixnum 0])
    (if (i . < . dims)
        (begin
          (unsafe-vector-set! result i (merge-axis (unsafe-vector-ref ds1 i)
                                                   (unsafe-vector-ref ds2 i)))
          (fill! (+ i 1)))
        result)))
;; Broadcasts two shapes under the given `broadcasting` mode.
;; Equal shapes are returned unchanged (as `ds1`). When `broadcasting` is #f, any
;; mismatch calls `fail`. Otherwise the shape with fewer axes is left-padded with
;; length-1 axes, and the two shapes are combined axis by axis. The permissive rule is
;; used for 'permissive, and the normal rule for #t.
(: shape-broadcast2 (Indexes Indexes (-> Nothing) (U #f #t 'permissive) -> Indexes))
(define (shape-broadcast2 ds1 ds2 fail broadcasting)
  (cond
    [(equal? ds1 ds2) ds1]
    [broadcasting
     (define rank1 (vector-length ds1))
     (define rank2 (vector-length ds2))
     (define gap (- rank2 rank1))
     (let-values ([(lhs rhs rank)
                   (cond [(gap . < . 0) (values ds1 (shape-insert-axes ds2 (- gap)) rank1)]
                         [(gap . > . 0) (values (shape-insert-axes ds1 gap) ds2 rank2)]
                         [else (values ds1 ds2 rank1)])])
       (case broadcasting
         [(permissive) (shape-permissive-broadcast lhs rhs rank fail)]
         [else (shape-normal-broadcast lhs rhs rank fail)]))]
    [else (fail)]))
;; Broadcasts a list of shapes into a single shape by folding `shape-broadcast2` over
;; them from left to right. `broadcasting` defaults to the current value of the
;; `array-broadcasting` parameter.
;; An empty list yields the empty shape #(). An incompatible pair raises an
;; 'array-shape-broadcast error that reports the mode and every input shape.
(: array-shape-broadcast (case-> ((Listof Indexes) -> Indexes)
                                 ((Listof Indexes) (U #f #t 'permissive) -> Indexes)))
(define (array-shape-broadcast dss [broadcasting (array-broadcasting)])
  ;; Failure thunk passed to `shape-broadcast2`. The message is built only when needed.
  (define (fail)
    (define shape-strings (map (λ (ds) (format "~e" ds)) dss))
    (error 'array-shape-broadcast
           "incompatible array shapes (array-broadcasting ~v): ~a"
           broadcasting
           (string-join shape-strings ", ")))
  (if (empty? dss)
      #()
      (for/fold ([acc (first dss)]) ([ds (in-list (rest dss))])
        (shape-broadcast2 acc ds fail broadcasting))))