racket/collects/math/private/array/flarray-pointwise.rkt
Neil Toronto 5a43f2c6bc Finished array documentation!
Cleaned up other docs in preparation for alpha-testing announcement

Created `math/utils' module for stuff that doesn't go anywhere else (e.g.
FFT scaling convention, max-math-threads parameters)

Reduced the number of macros that expand to applications of `array-map'

Added `flvector-sum', defined `flsum' in terms of it

Reduced the number of pointwise `flvector', `flarray' and `fcarray' operations

Reworked `inline-build-flvector' and `inline-flvector-map' to be faster and
expand to less code in both typed and untyped Racket

Redefined conversions like `list->flvector' in terms of for loops (can do
it now that TR has working `for/flvector:', etc.)
2012-11-29 15:45:17 -07:00

139 lines
5.2 KiB
Racket

#lang typed/racket/base
(require racket/flonum
(for-syntax racket/base)
"../../flonum.rkt"
"array-struct.rkt"
"array-broadcast.rkt"
"array-pointwise.rkt"
"mutable-array.rkt"
"flarray-struct.rkt"
"utils.rkt")
(provide
;; Mapping
inline-flarray-map
flarray-map
;; Pointwise operations
flarray-scale
flarray-sqr
flarray-sqrt
flarray-abs
flarray+
flarray*
flarray-
flarray/
flarray-min
flarray-max)
;; ===================================================================================================
;; Mapping
;; (inline-flarray-map f arr-expr ...)
;; Applies `f' pointwise to zero or more FlArrays, producing a new FlArray.
;; `f' is substituted into the expansion as-is (it is never bound to a variable),
;; presumably so callers can pass an unannotated `λ' that Typed Racket checks in
;; place -- TODO confirm.
;; NOTE(review): in the 2+ array case the `f' expression appears twice in the
;; expansion (once per `cond' branch), so a large `f' doubles the expanded code.
(define-syntax (inline-flarray-map stx)
(syntax-case stx ()
;; No arrays: an array with shape #() (no axes) holding the single value (f).
[(_ f) (syntax/loc stx (unsafe-flarray #() (flvector (f))))]
;; One array: map over its backing flvector; the shape is reused unchanged.
[(_ f arr-expr)
(syntax/loc stx
(let: ([arr : FlArray arr-expr])
(unsafe-flarray (array-shape arr) (inline-flvector-map f (flarray-data arr)))))]
;; Two or more arrays.
[(_ f arr-expr arr-exprs ...)
;; Fresh identifiers, one set per extra array: the bound array, its shape,
;; and its element procedure (the latter used only on the broadcast path).
(with-syntax ([(arrs ...) (generate-temporaries #'(arr-exprs ...))]
[(dss ...) (generate-temporaries #'(arr-exprs ...))]
[(procs ...) (generate-temporaries #'(arr-exprs ...))])
(syntax/loc stx
(let: ([arr : FlArray arr-expr]
[arrs : FlArray arr-exprs] ...)
(define ds (array-shape arr))
(define dss (array-shape arrs)) ...
;; Fast path: every other shape is `equal?' to the first array's shape,
;; so the backing flvectors can be mapped element-for-element.
(cond [(and (equal? ds dss) ...)
(unsafe-flarray
ds (inline-flvector-map f (flarray-data arr) (flarray-data arrs) ...))]
;; Otherwise broadcast every array to the shape computed by
;; `array-shape-broadcast', build an array whose value at `js' is `f'
;; applied to each broadcast array's value at `js', and convert the
;; result with `array->flarray'. Here the `f' expression is evaluated
;; on every call of the per-index `λ:'.
[else
(define new-ds (array-shape-broadcast (list ds dss ...)))
(define proc (unsafe-array-proc (array-broadcast arr new-ds)))
(define procs (unsafe-array-proc (array-broadcast arrs new-ds))) ...
(array->flarray
(unsafe-build-array new-ds (λ: ([js : Indexes])
(f (proc js) (procs js) ...))))]))))]))
;; flarray-map : first-class counterpart of `inline-flarray-map'.
;; Applies `f' pointwise to zero or more FlArrays and returns a new FlArray.
;; The 0-, 1- and 2-array clauses delegate to `inline-flarray-map'; the 2-array
;; clause therefore gets its equal-shapes flvector fast path. Three or more
;; arrays always take the broadcasting path below.
(: flarray-map (case-> ((-> Float) -> FlArray)
((Float -> Float) FlArray -> FlArray)
((Float Float Float * -> Float) FlArray FlArray FlArray * -> FlArray)))
(define flarray-map
(case-lambda:
[([f : (-> Float)])
(inline-flarray-map f)]
[([f : (Float -> Float)] [arr : FlArray])
(inline-flarray-map f arr)]
;; The declared type for 2+ arrays is (Float Float Float * -> Float); a function
;; of that type is also usable as (Float Float -> Float), which this clause takes.
[([f : (Float Float -> Float)] [arr0 : FlArray] [arr1 : FlArray])
(inline-flarray-map f arr0 arr1)]
[([f : (Float Float Float * -> Float)] [arr0 : FlArray] [arr1 : FlArray] . [arrs : FlArray *])
;; Common shape of all arguments (no equal-shapes fast path in this clause).
(define ds (array-shape arr0))
(define dss (map (λ: ([arr : FlArray]) (array-shape arr)) (cons arr1 arrs)))
(define new-ds (array-shape-broadcast (list* ds dss)))
;; Shadow the FlArray arguments with their broadcast (Array Float) versions.
(let: ([arr0 : (Array Float) (array-broadcast arr0 new-ds)]
[arr1 : (Array Float) (array-broadcast arr1 new-ds)]
[arrs : (Listof (Array Float))
(map (λ: ([arr : FlArray]) (array-broadcast arr new-ds)) arrs)])
(define proc0 (unsafe-array-proc arr0))
(define proc1 (unsafe-array-proc arr1))
(define procs (map (λ: ([arr : (Array Float)]) (unsafe-array-proc arr)) arrs))
;; Per index: gather one value from each array and `apply' f. The two fixed
;; arguments come first so the call matches f's (Float Float Float * -> Float) type.
(array->flarray
(unsafe-build-array new-ds (λ: ([js : Indexes])
(apply f (proc0 js) (proc1 js)
(map (λ: ([proc : (Indexes -> Float)]) (proc js))
procs))))))]))
;; ===================================================================================================
;; Pointwise operations
;; (lift-flvector1 f) expands to a one-argument procedure on FlArrays: it runs
;; the flvector operation `f' on the argument's backing data and wraps the
;; result with the argument's shape. `f' only ever appears in operator
;; position, so it may itself be a macro (as in `flarray-scale').
(define-syntax lift-flvector1
  (syntax-rules ()
    [(_ f)
     (λ (a)
       (let ([ds (array-shape a)]
             [xs (flarray-data a)])
         (unsafe-flarray ds (f xs))))]))
;; (lift-flvector2 f array-f) expands to a two-argument procedure on FlArrays.
;; When both shapes are `equal?', the flvector operation `f' runs directly on
;; the backing data. Otherwise the generic array operation `array-f' is used
;; and its result is converted with `array->flarray'.
(define-syntax lift-flvector2
  (syntax-rules ()
    [(_ f array-f)
     (λ (a1 a2)
       (let ([ds1 (array-shape a1)]
             [ds2 (array-shape a2)])
         (if (equal? ds1 ds2)
             (unsafe-flarray ds1 (f (flarray-data a1) (flarray-data a2)))
             (array->flarray (array-f a1 a2)))))]))
;; Runs `flvector-scale' with factor `y' on the backing data of `arr'; the
;; result keeps the shape of `arr'.
(: flarray-scale (FlArray Float -> FlArray))
(define (flarray-scale arr y)
  (unsafe-flarray (array-shape arr)
                  (flvector-scale (flarray-data arr) y)))
;; Unary pointwise operations: each one runs the named flvector operation on
;; the backing data and keeps the argument's shape.
(: flarray-sqr (FlArray -> FlArray))
(define (flarray-sqr arr) ((lift-flvector1 flvector-sqr) arr))

(: flarray-sqrt (FlArray -> FlArray))
(define (flarray-sqrt arr) ((lift-flvector1 flvector-sqrt) arr))

(: flarray-abs (FlArray -> FlArray))
(define (flarray-abs arr) ((lift-flvector1 flvector-abs) arr))
;; Binary pointwise sum and product. Equal shapes run the flvector operation
;; directly; other shapes go through `array+' / `array*' and are converted
;; back to a FlArray.
(: flarray+ (FlArray FlArray -> FlArray))
(define (flarray+ xs ys) ((lift-flvector2 flvector+ array+) xs ys))

(: flarray* (FlArray FlArray -> FlArray))
(define (flarray* xs ys) ((lift-flvector2 flvector* array*) xs ys))
;; One argument: runs unary `flvector-' on the backing data.
;; Two arguments: pointwise `flvector-' for equal shapes, otherwise `array-'
;; converted back to a FlArray.
(: flarray- (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
(define flarray-
  (case-lambda:
    [([xs : FlArray])
     ((lift-flvector1 flvector-) xs)]
    [([xs : FlArray] [ys : FlArray])
     ((lift-flvector2 flvector- array-) xs ys)]))
;; One argument: runs unary `flvector/' on the backing data.
;; Two arguments: pointwise `flvector/' for equal shapes, otherwise `array/'
;; converted back to a FlArray.
(: flarray/ (case-> (FlArray -> FlArray)
                    (FlArray FlArray -> FlArray)))
(define flarray/
  (case-lambda:
    [([xs : FlArray])
     ((lift-flvector1 flvector/) xs)]
    [([xs : FlArray] [ys : FlArray])
     ((lift-flvector2 flvector/ array/) xs ys)]))
;; Binary pointwise minimum and maximum. Equal shapes run the flvector
;; operation directly; other shapes go through `array-min' / `array-max'.
(: flarray-min (FlArray FlArray -> FlArray))
(define (flarray-min xs ys) ((lift-flvector2 flvector-min array-min) xs ys))

(: flarray-max (FlArray FlArray -> FlArray))
(define (flarray-max xs ys) ((lift-flvector2 flvector-max array-max) xs ys))