racket/collects/math/private/array/array-broadcast.rkt

103 lines
4.2 KiB
Racket

#lang typed/racket
(require racket/fixnum
"array-struct.rkt"
"../unsafe.rkt"
"utils.rkt")
(provide array-broadcasting
array-broadcast
array-shape-broadcast)
;; Broadcasting mode used by array-shape-broadcast when no mode is passed explicitly.
;; Values, as interpreted by shape-broadcast2 in this file:
;;   #f           -- shapes must be equal, otherwise broadcasting fails
;;   #t (default) -- normal broadcasting (shape-normal-broadcast)
;;   'permissive  -- permissive broadcasting (shape-permissive-broadcast)
(: array-broadcasting (Parameterof (U #t #f 'permissive)))
(define array-broadcasting
  (make-parameter #t))
;; Builds (via unsafe-build-array) an array of shape `new-ds` whose elements are
;; read from `arr`.
;;
;; The axes of `arr` are aligned with the *trailing* axes of `new-ds`. The first
;; `shift` = (new rank - old rank) entries of each new index vector are ignored.
;; Each remaining index is reduced modulo the corresponding old axis length. As a
;; result, a length-1 old axis is repeated, and a shorter nonzero old axis is
;; repeated cyclically.
;;
;; Raises an error (reported as 'array-broadcast) if:
;;   - `new-ds` has fewer axes than the shape of `arr`, or
;;   - `arr` has no elements (some axis of length 0) while `new-ds` has no
;;     length-0 axis.
;;
;; NOTE(review): the index vector passed to `arr`'s procedure is a scratch vector
;; from make-thread-local-indexes that is overwritten on every element access.
;; Presumably that procedure must not retain it -- confirm against utils.rkt.
(: shift-stretch-axes (All (A) ((Array A) Indexes -> (Array A))))
(define (shift-stretch-axes arr new-ds)
  (define old-ds (array-shape arr))
  (define old-dims (vector-length old-ds))
  (define new-dims (vector-length new-ds))
  ;; Number of leading axes of `new-ds` that have no counterpart in `arr`.
  (define shift
    (let ([shift (- new-dims old-dims)])
      (cond [(index? shift) shift]
            [else (error 'array-broadcast
                         "cannot broadcast to lower-dimensional array; given ~e and ~e"
                         arr new-ds)])))
  ;; True when some axis of `ds` has length 0, i.e. an array of that shape is empty.
  (define: (has-zero-axis? [ds : Indexes]) : Boolean
    (let: loop : Boolean ([k : Nonnegative-Fixnum 0])
      (cond [(k . < . (vector-length ds))
             (or (zero? (unsafe-vector-ref ds k)) (loop (+ k 1)))]
            [else #f])))
  (define old-js (make-thread-local-indexes old-dims))
  (define old-f (unsafe-array-proc arr))
  ;; An empty `arr` has nothing to repeat. On a nonempty result, any element
  ;; access would call `unsafe-fxmodulo` below with a zero axis length, which is
  ;; undefined behavior for an unsafe op. Reject that case up front. An empty
  ;; result is still allowed, since it has no elements to compute.
  (when (and (has-zero-axis? old-ds) (not (has-zero-axis? new-ds)))
    (error 'array-broadcast
           "cannot broadcast an empty array to a nonempty shape; given ~e and ~e"
           arr new-ds))
  (unsafe-build-array
   new-ds
   (λ: ([new-js : Indexes])
     (let ([old-js (old-js)])
       (let: loop : A ([k : Nonnegative-Fixnum 0])
         (cond [(k . < . old-dims)
                ;; Old axis k corresponds to new axis k + shift.
                (define new-jk (unsafe-vector-ref new-js (+ k shift)))
                (define old-dk (unsafe-vector-ref old-ds k))
                ;; Wrap the index: a length-1 axis always maps to 0, and a
                ;; shorter axis is cycled.
                (define old-jk (unsafe-fxmodulo new-jk old-dk))
                (unsafe-vector-set! old-js k old-jk)
                (loop (+ k 1))]
               [else (old-f old-js)]))))))
;; Returns an array of shape `ds` whose elements come from `arr`.
;; If `ds` is already equal to the shape of `arr`, `arr` itself is returned.
;; Otherwise the work is delegated to shift-stretch-axes, which aligns `arr`'s axes
;; with the trailing axes of `ds` and repeats elements along stretched axes.
(: array-broadcast (All (A) ((Array A) Indexes -> (Array A))))
(define (array-broadcast arr ds)
  (cond [(equal? (array-shape arr) ds) arr]
        [else (shift-stretch-axes arr ds)]))
;; Returns a new shape made by prepending `n` axes of length 1 to `ds`.
;; For example, #(2 3) with n = 2 gives #(1 1 2 3).
;; `n` is expected to be nonnegative; the callers in this file pass a positive count.
(: shape-insert-axes (Indexes Fixnum -> Indexes))
(define (shape-insert-axes ds n)
  (define: ones : Indexes (make-vector n 1))
  (vector-append ones ds))
;; Combines two shapes axis by axis under permissive broadcasting.
;; Each result axis is the larger of the two axis lengths. `fail` is called at the
;; first axis where either length is 0.
;; Both `ds1` and `ds2` must have at least `dims` axes: they are read with
;; unsafe-vector-ref at indexes 0 .. dims-1.
;; Returns a freshly allocated shape with `dims` axes.
(: shape-permissive-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-permissive-broadcast ds1 ds2 dims fail)
  (define: result : Indexes (make-vector dims 0))
  (let: fill : Indexes ([i : Nonnegative-Fixnum 0])
    (cond
      [(i . < . dims)
       (let ([a (unsafe-vector-ref ds1 i)]
             [b (unsafe-vector-ref ds2 i)])
         ;; A zero-length axis cannot be repeated to fill another axis.
         (when (or (zero? a) (zero? b)) (fail))
         (unsafe-vector-set! result i (fxmax a b))
         (fill (+ i 1)))]
      [else result])))
;; Combines two shapes axis by axis under normal broadcasting.
;; For each axis:
;;   - equal lengths are kept;
;;   - a length of 1 stretches to the other length, but only if that length is
;;     positive (so 1 vs 0 fails);
;;   - any other combination calls `fail`.
;; Both `ds1` and `ds2` must have at least `dims` axes: they are read with
;; unsafe-vector-ref at indexes 0 .. dims-1.
;; Returns a freshly allocated shape with `dims` axes.
(: shape-normal-broadcast (Indexes Indexes Index (-> Nothing) -> Indexes))
(define (shape-normal-broadcast ds1 ds2 dims fail)
  (define: result : Indexes (make-vector dims 0))
  (let: fill : Indexes ([i : Nonnegative-Fixnum 0])
    (cond
      [(i . < . dims)
       (let ([a (unsafe-vector-ref ds1 i)]
             [b (unsafe-vector-ref ds2 i)])
         (unsafe-vector-set!
          result i
          (cond [(= a b) a]
                [(= a 1) (if (b . > . 0) b (fail))]
                [(= b 1) (if (a . > . 0) a (fail))]
                [else (fail)]))
         (fill (+ i 1)))]
      [else result])))
;; Broadcasts two shapes together under the given `broadcasting` mode.
;;   - Equal shapes are returned unchanged (`ds1` itself), whatever the mode.
;;   - With mode #f, any difference calls `fail`.
;;   - Otherwise the shorter shape is left-padded with length-1 axes to the rank of
;;     the longer one. The padded shapes are then combined by
;;     shape-permissive-broadcast when the mode is 'permissive, and by
;;     shape-normal-broadcast otherwise.
(: shape-broadcast2 (Indexes Indexes (-> Nothing) (U #f #t 'permissive) -> Indexes))
(define (shape-broadcast2 ds1 ds2 fail broadcasting)
  (cond
    [(equal? ds1 ds2) ds1]
    [broadcasting
     (define len1 (vector-length ds1))
     (define len2 (vector-length ds2))
     ;; Bring both shapes to the same rank by prepending 1s to the shorter one.
     (define-values (padded1 padded2 rank)
       (cond [(len1 . < . len2) (values (shape-insert-axes ds1 (- len2 len1)) ds2 len2)]
             [(len2 . < . len1) (values ds1 (shape-insert-axes ds2 (- len1 len2)) len1)]
             [else (values ds1 ds2 len1)]))
     (cond [(eq? broadcasting 'permissive)
            (shape-permissive-broadcast padded1 padded2 rank fail)]
           [else
            (shape-normal-broadcast padded1 padded2 rank fail)])]
    [else (fail)]))
;; Computes the shape that every shape in `dss` broadcasts to. The shapes are
;; combined pairwise from left to right with shape-broadcast2 under `broadcasting`,
;; which defaults to the current value of (array-broadcasting).
;; An empty list yields the zero-axis shape #(). A single shape is returned as is.
;; If any step is incompatible, raises an error naming the mode and listing all
;; shapes in `dss`.
(: array-shape-broadcast (case-> ((Listof Indexes) -> Indexes)
                                 ((Listof Indexes) (U #f #t 'permissive) -> Indexes)))
(define (array-shape-broadcast dss [broadcasting (array-broadcasting)])
  (define (fail)
    (define shapes-text (string-join (map (λ (shape) (format "~e" shape)) dss) ", "))
    (error 'array-shape-broadcast "incompatible array shapes (broadcasting ~v): ~a"
           broadcasting shapes-text))
  (if (null? dss)
      #()
      (for/fold ([acc (car dss)]) ([ds (in-list (cdr dss))])
        (shape-broadcast2 acc ds fail broadcasting))))