Made arrays strict by default; please merge to release

* Added parameter `array-strictness', default #t

* Added `array-default-strict!' and `array-default-strict', which act
  like the functions without "default" in the name when
  `array-strictness' is #t; otherwise they do nothing

* Lots of small changes to existing array functions, mostly to ensure
  computations are done using nonstrict arrays, but return values are
  strict when `array-strictness' is #t

* Added strictness tests

* Added tests to ensure untyped code can use `math/array'

* Rewrote `array-map' exported to untyped code using untyped Racket

* Rearranged a lot of `math/array' documentation
This commit is contained in:
Neil Toronto 2013-01-15 13:23:08 -07:00
parent 131af9955d
commit 986e695bd5
21 changed files with 1507 additions and 759 deletions

View File

@ -8,7 +8,6 @@
"private/array/array-transform.rkt"
"private/array/array-convert.rkt"
"private/array/array-fold.rkt"
"private/array/array-special-folds.rkt"
"private/array/array-unfold.rkt"
"private/array/array-print.rkt"
"private/array/array-fft.rkt"
@ -36,7 +35,6 @@
"private/array/array-transform.rkt"
"private/array/array-convert.rkt"
"private/array/array-fold.rkt"
"private/array/array-special-folds.rkt"
"private/array/array-unfold.rkt"
"private/array/array-print.rkt"
"private/array/array-syntax.rkt"

View File

@ -40,7 +40,11 @@
(: array-broadcast (All (A) ((Array A) Indexes -> (Array A))))
(define (array-broadcast arr ds)
(if (equal? ds (array-shape arr)) arr (shift-stretch-axes arr ds)))
(cond [(equal? ds (array-shape arr)) arr]
[else (define new-arr (shift-stretch-axes arr ds))
(if (or (array-strict? arr) ((array-size new-arr) . fx<= . (array-size arr)))
new-arr
(array-default-strict new-arr))]))
(: shape-insert-axes (Indexes Fixnum -> Indexes))
(define (shape-insert-axes ds n)

View File

@ -51,8 +51,9 @@
[(= k (- dims 1))
(fcarray-last-axis-fft (array->fcarray arr))]
[else
(array-axis-swap (fcarray-last-axis-fft (array->fcarray (array-axis-swap arr k (- dims 1))))
k (- dims 1))]))
(parameterize ([array-strictness #f])
(array-axis-swap (fcarray-last-axis-fft (array->fcarray (array-axis-swap arr k (- dims 1))))
k (- dims 1)))]))
(: fcarray-fft (FCArray -> FCArray))
(define (fcarray-fft arr)

View File

@ -1,6 +1,9 @@
#lang racket/base
(require (for-syntax racket/base)
(only-in typed/racket/base assert index?)
"array-struct.rkt"
"array-pointwise.rkt"
"typed-array-fold.rkt")
;; ===================================================================================================
@ -28,6 +31,22 @@
(define-all-fold array-all-min min)
(define-all-fold array-all-max max)
(define-syntax-rule (array-count f arr ...)
(assert
(parameterize ([array-strictness #f])
(array-all-sum (inline-array-map (λ (b) (if b 1 0))
(array-map f arr ...))
0))
index?))
(define-syntax-rule (array-andmap pred? arr ...)
(parameterize ([array-strictness #f])
(array-all-and (array-map pred? arr ...))))
(define-syntax-rule (array-ormap pred? arr ...)
(parameterize ([array-strictness #f])
(array-all-or (array-map pred? arr ...))))
(provide array-axis-fold
array-axis-sum
array-axis-prod
@ -44,6 +63,9 @@
array-all-max
array-all-and
array-all-or
array-count
array-andmap
array-ormap
array-axis-reduce
unsafe-array-axis-reduce
array->list-array)

View File

@ -1,27 +0,0 @@
#lang typed/racket/base
(require "array-struct.rkt"
"array-fold.rkt"
"array-pointwise.rkt")
(provide array-count)
(: array-count
(All (A B T ...)
(case-> ((A -> Any) (Array A) -> Index)
((A B T ... T -> Any) (Array A) (Array B) (Array T) ... T -> Index))))
(define array-count
(case-lambda:
[([f : (A -> Any)] [arr0 : (Array A)])
(assert (array-all-sum (inline-array-map (λ: ([a : A]) (if (f a) 1 0)) arr0) 0) index?)]
[([f : (A B -> Any)] [arr0 : (Array A)] [arr1 : (Array B)])
(assert
(array-all-sum (inline-array-map (λ: ([a : A] [b : B]) (if (f a b) 1 0)) arr0 arr1) 0)
index?)]
[([f : (A B T ... T -> Any)] [arr0 : (Array A)] [arr1 : (Array B)] . [arrs : (Array T) ... T])
(assert
(array-all-sum (apply array-map
(λ: ([a : A] [b : B] . [ts : T ... T]) (if (apply f a b ts) 1 0))
arr0 arr1 arrs)
0)
index?)]))

View File

@ -6,14 +6,14 @@
"array-syntax.rkt"
(except-in "typed-array-struct.rkt"
build-array
build-strict-array
build-simple-array
list->array))
(require/untyped-contract
(begin (require "typed-array-struct.rkt"))
"typed-array-struct.rkt"
[build-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
[build-strict-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
[build-simple-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
[list->array (All (A) (case-> ((Listof A) -> (Array A))
((Vectorof Integer) (Listof A) -> (Array A))))])
@ -29,15 +29,18 @@
array-shape
array-dims
array-size
array-strictness
array-strict
array-strict!
array-default-strict
array-default-strict!
array-strict?
build-array
build-strict-array
build-simple-array
list->array
make-unsafe-array-proc
unsafe-build-array
unsafe-build-strict-array
unsafe-build-simple-array
unsafe-list->array
unsafe-array-proc
array-lazy
@ -65,3 +68,8 @@
(let ([arr arr-expr])
(array-strict! arr)
arr))
(define-syntax-rule (array-default-strict arr-expr)
(let ([arr arr-expr])
(array-default-strict! arr)
arr))

View File

@ -33,7 +33,8 @@
(define (array-axis-expand arr k dk f)
(let ([k (check-array-axis 'array-axis-expand arr k)])
(cond [(not (index? dk)) (raise-argument-error 'array-axis-expand "Index" 2 arr k dk f)]
[else (unsafe-array-axis-expand arr k dk f)])))
[else (array-default-strict
(unsafe-array-axis-expand arr k dk f))])))
;; ===================================================================================================
;; Specific unfolds/expansions
@ -46,6 +47,7 @@
(let ([arr (array-strict (array-map (inst list->vector A) arr))])
;(define dks (remove-duplicates (array->list (array-map vector-length arr))))
(define dk (array-all-min (array-map vector-length arr)))
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A)))]
(array-default-strict
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A))))]
[else
(raise-argument-error 'list-array->array (format "Index <= ~a" dims) 1 arr k)]))

View File

@ -10,7 +10,7 @@
(define (make-array ds v)
(let ([ds (check-array-shape
ds (λ () (raise-argument-error 'make-array "(Vectorof Index)" 0 ds v)))])
(unsafe-build-strict-array ds (λ (js) v))))
(unsafe-build-simple-array ds (λ (js) v))))
(: axis-index-array (In-Indexes Integer -> (Array Index)))
(define (axis-index-array ds k)
@ -18,21 +18,21 @@
ds (λ () (raise-argument-error 'axis-index-array "(Vectorof Index)" 0 ds k)))]
[dims (vector-length ds)])
(cond [(and (0 . <= . k) (k . < . dims))
(unsafe-build-strict-array ds (λ: ([js : Indexes]) (unsafe-vector-ref js k)))]
(unsafe-build-simple-array ds (λ: ([js : Indexes]) (unsafe-vector-ref js k)))]
[else (raise-argument-error 'axis-index-array (format "Index < ~a" dims) 1 ds k)])))
(: index-array (In-Indexes -> (Array Index)))
(define (index-array ds)
(let ([ds (check-array-shape
ds (λ () (raise-argument-error 'index-array "(Vectorof Index)" ds)))])
(unsafe-build-strict-array ds (λ: ([js : Indexes])
(unsafe-build-simple-array ds (λ: ([js : Indexes])
(assert (unsafe-array-index->value-index ds js) index?)))))
(: indexes-array (In-Indexes -> (Array Indexes)))
(define (indexes-array ds)
(let ([ds (check-array-shape
ds (λ () (raise-argument-error 'indexes-array "(Vectorof Index)" ds)))])
(unsafe-build-strict-array ds (λ: ([js : Indexes]) (vector-copy-all js)))))
(unsafe-build-simple-array ds (λ: ([js : Indexes]) (vector-copy-all js)))))
(: diagonal-array (All (A) (Integer Integer A A -> (Array A))))
(define (diagonal-array dims size on-value off-value)
@ -42,15 +42,15 @@
(define: ds : Indexes (make-vector dims size))
;; specialize for various cases
(cond [(or (dims . <= . 1) (size . <= . 1))
(unsafe-build-strict-array ds (λ: ([js : Indexes]) on-value))]
(unsafe-build-simple-array ds (λ: ([js : Indexes]) on-value))]
[(= dims 2)
(unsafe-build-strict-array
(unsafe-build-simple-array
ds (λ: ([js : Indexes])
(define j0 (unsafe-vector-ref js 0))
(define j1 (unsafe-vector-ref js 1))
(if (= j0 j1) on-value off-value)))]
[else
(unsafe-build-strict-array
(unsafe-build-simple-array
ds (λ: ([js : Indexes])
(define j0 (unsafe-vector-ref js 0))
(let: loop : A ([i : Nonnegative-Fixnum 1])

View File

@ -38,15 +38,16 @@
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
(define (array-axis-reduce arr k f)
(let ([k (check-array-axis 'array-axis-reduce arr k)])
(unsafe-array-axis-reduce
arr k
(λ: ([dk : Index] [proc : (Index -> A)])
(: safe-proc (Integer -> A))
(define (safe-proc jk)
(cond [(or (jk . < . 0) (jk . >= . dk))
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
[else (proc jk)]))
(f dk safe-proc)))))
(array-default-strict
(unsafe-array-axis-reduce
arr k
(λ: ([dk : Index] [proc : (Index -> A)])
(: safe-proc (Integer -> A))
(define (safe-proc jk)
(cond [(or (jk . < . 0) (jk . >= . dk))
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
[else (proc jk)]))
(f dk safe-proc))))))
(: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B))))
(define (array-axis-fold/init arr k f init)
@ -72,8 +73,8 @@
((Array A) Integer (A B -> B) B -> (Array B)))))
(define array-axis-fold
(case-lambda
[(arr k f) (array-axis-fold/no-init arr k f)]
[(arr k f init) (array-axis-fold/init arr k f init)]))
[(arr k f) (array-default-strict (array-axis-fold/no-init arr k f))]
[(arr k f init) (array-default-strict (array-axis-fold/init arr k f init))]))
;; ===================================================================================================
;; Whole-array folds
@ -93,13 +94,18 @@
(define array-all-fold
(case-lambda
[(arr f)
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
(array-axis-fold arr k f)))
#())]
;; Though `f' is folded over multiple axes, each element of `arr' is referred to only once, so
;; turning strictness off can't hurt performance
(parameterize ([array-strictness #f])
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
(array-axis-fold arr k f)))
#()))]
[(arr f init)
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
(array-axis-fold arr k f init)))
#())]))
;; See above for why non-strictness is okay
(parameterize ([array-strictness #f])
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
(array-axis-fold arr k f init)))
#()))]))
) ; begin-encourage-inline
@ -109,13 +115,14 @@
(: array-axis-count (All (A) ((Array A) Integer (A -> Any) -> (Array Index))))
(define (array-axis-count arr k pred?)
(let ([k (check-array-axis 'array-axis-count arr k)])
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0])
(if (jk . fx< . dk)
(cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))]
[else (loop (fx+ jk 1) acc)])
(assert acc index?)))))))
(array-default-strict
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0])
(if (jk . fx< . dk)
(cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))]
[else (loop (fx+ jk 1) acc)])
(assert acc index?))))))))
;; ===================================================================================================
;; Short-cutting axis folds
@ -123,30 +130,36 @@
(: array-axis-and (All (A) ((Array A) Integer -> (Array (U A Boolean)))))
(define (array-axis-and arr k)
(let ([k (check-array-axis 'array-axis-and arr k)])
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t])
(cond [(jk . fx< . dk) (define v (and acc (proc jk)))
(if v (loop (fx+ jk 1) v) v)]
[else acc]))))))
(array-default-strict
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t])
(cond [(jk . fx< . dk) (define v (and acc (proc jk)))
(if v (loop (fx+ jk 1) v) v)]
[else acc])))))))
(: array-axis-or (All (A) ((Array A) Integer -> (Array (U A #f)))))
(define (array-axis-or arr k)
(let ([k (check-array-axis 'array-axis-or arr k)])
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f])
(cond [(jk . fx< . dk) (define v (or acc (proc jk)))
(if v v (loop (fx+ jk 1) v))]
[else acc]))))))
(array-default-strict
(unsafe-array-axis-reduce
arr k (λ: ([dk : Index] [proc : (Index -> A)])
(let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f])
(cond [(jk . fx< . dk) (define v (or acc (proc jk)))
(if v v (loop (fx+ jk 1) v))]
[else acc])))))))
(: array-all-and (All (A B) ((Array A) -> (U A Boolean))))
(define (array-all-and arr)
(array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #()))
;; See `array-all-fold' for why non-strictness is okay
(parameterize ([array-strictness #f])
(array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #())))
(: array-all-or (All (A B) ((Array A) -> (U A #f))))
(define (array-all-or arr)
(array-ref ((inst array-fold (U A #f)) arr array-axis-or) #()))
;; See `array-all-fold' for why non-strictness is okay
(parameterize ([array-strictness #f])
(array-ref ((inst array-fold (U A #f)) arr array-axis-or) #())))
;; ===================================================================================================
;; Other folds
@ -156,6 +169,7 @@
(define (array->list-array arr [k 0])
(define dims (array-dims arr))
(cond [(and (k . >= . 0) (k . < . dims))
(unsafe-array-axis-reduce arr k (inst build-list A))]
(array-default-strict
(unsafe-array-axis-reduce arr k (inst build-list A)))]
[else
(raise-argument-error 'array->list-array (format "Index < ~a" dims) 1 arr k)]))

View File

@ -46,7 +46,8 @@
(define (array-indexes-ref arr idxs)
(define ds (array-shape idxs))
(define idxs-proc (unsafe-array-proc idxs))
(unsafe-build-array ds (λ: ([js : Indexes]) (array-ref arr (idxs-proc js)))))
(array-default-strict
(unsafe-build-array ds (λ: ([js : Indexes]) (array-ref arr (idxs-proc js))))))
(: array-indexes-set! (All (A) ((Settable-Array A) (Array In-Indexes) (Array A) -> Void)))
(define (array-indexes-set! arr idxs vals)
@ -232,14 +233,19 @@
(unless (= dims num-specs)
(error 'array-slice-ref "expected list with ~e slice specifications; given ~e in ~e"
dims num-specs orig-slices))
(let-values ([(arr jss) (slices->array-axis-transform 'array-slice-ref arr slices)])
(for/fold ([arr (unsafe-array-axis-transform arr jss)]) ([na (in-list new-axes)])
(match-define (cons k dk) na)
(array-axis-insert arr k dk)))))
(array-default-strict
(parameterize ([array-strictness #f])
(let-values ([(arr jss) (slices->array-axis-transform 'array-slice-ref arr slices)])
(for/fold: : (Array A) ([arr : (Array A) (unsafe-array-axis-transform arr jss)]
) ([na (in-list new-axes)])
(match-define (cons k dk) na)
(array-axis-insert arr k dk)))))))
(: array-slice-set! (All (A) ((Settable-Array A) (Listof Slice-Spec) (Array A) -> Void)))
(define (array-slice-set! arr slices vals)
(let ([idxs (array-slice-ref (indexes-array (array-shape arr)) slices)])
;; No reason to make `idxs' strict, since we build it ourselves and don't return it
(let ([idxs (parameterize ([array-strictness #f])
(array-slice-ref (indexes-array (array-shape arr)) slices))])
(array-indexes-set! arr idxs vals)))
;; ---------------------------------------------------------------------------------------------------

View File

@ -29,6 +29,7 @@
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))
(define gs (map unsafe-array-proc arrs))
(unsafe-build-array
ds (λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> T)]) (g js)) gs)))))]))
(array-default-strict
(unsafe-build-array
ds (λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> T)]) (g js)) gs))))))]))

View File

@ -100,10 +100,11 @@
(let ([arrs (list->vector (map (λ: ([arr : (Array A)]) (array-broadcast arr ds)) arrs))])
(define dk (vector-length arrs))
(define new-ds (unsafe-vector-insert ds k dk))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define jk (unsafe-vector-ref js k))
(let ([old-js (unsafe-vector-remove js k)])
((unsafe-array-proc (unsafe-vector-ref arrs jk)) old-js)))))]
(array-default-strict
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define jk (unsafe-vector-ref js k))
(let ([old-js (unsafe-vector-remove js k)])
((unsafe-array-proc (unsafe-vector-ref arrs jk)) old-js))))))]
[else
(error 'array-list->array (format "expected axis Index <= ~e; given ~e" dims k))]))

View File

@ -8,6 +8,9 @@
(provide (all-defined-out))
(: array-strictness (Parameterof (U #f #t)))
(define array-strictness (make-parameter #t))
;; ===================================================================================================
;; Equality and hashing
@ -74,6 +77,13 @@
((Array-strict! arr))
(set-box! strict? #t)))
(: array-default-strict! (All (A) ((Array A) -> Void)))
(define (array-default-strict! arr)
(define strict? (Array-strict? arr))
(when (and (not (unbox strict?)) (array-strictness))
((Array-strict! arr))
(set-box! strict? #t)))
(: unsafe-build-array (All (A) (Indexes (Indexes -> A) -> (Array A))))
(define (unsafe-build-array ds f)
;; This box's contents get replaced when the array we're constructing is made strict, so that
@ -94,29 +104,32 @@
(λ: ([js : Indexes]) ((unbox f) js)))
(Array ds size ((inst box Boolean) #f) strict! unsafe-proc)))
(: unsafe-build-strict-array (All (A) (Indexes (Indexes -> A) -> (Array A))))
(define (unsafe-build-strict-array ds f)
(define size (check-array-shape-size 'unsafe-build-strict-array ds))
(: unsafe-build-simple-array (All (A) (Indexes (Indexes -> A) -> (Array A))))
(define (unsafe-build-simple-array ds f)
(define size (check-array-shape-size 'unsafe-build-simple-array ds))
(Array ds size (box #t) void f))
(: build-array (All (A) (In-Indexes (Indexes -> A) -> (Array A))))
(define (build-array ds proc)
(let ([ds (check-array-shape
ds (λ () (raise-argument-error 'build-array "(Vectorof Index)" 0 ds proc)))])
(unsafe-build-array ds (λ: ([js : Indexes])
(proc (vector->immutable-vector js))))))
(define arr
(unsafe-build-array ds (λ: ([js : Indexes])
(proc (vector->immutable-vector js)))))
(array-default-strict! arr)
arr))
(: build-strict-array (All (A) (In-Indexes (Indexes -> A) -> (Array A))))
(define (build-strict-array ds proc)
(: build-simple-array (All (A) (In-Indexes (Indexes -> A) -> (Array A))))
(define (build-simple-array ds proc)
(let ([ds (check-array-shape
ds (λ () (raise-argument-error 'build-strict-array "(Vectorof Index)" 0 ds proc)))])
(unsafe-build-strict-array ds (λ: ([js : Indexes])
ds (λ () (raise-argument-error 'build-simple-array "(Vectorof Index)" 0 ds proc)))])
(unsafe-build-simple-array ds (λ: ([js : Indexes])
(proc (vector->immutable-vector js))))))
(: unsafe-list->array (All (A) (Indexes (Listof A) -> (Array A))))
(define (unsafe-list->array ds xs)
(define vs (list->vector xs))
(unsafe-build-strict-array
(unsafe-build-simple-array
ds (λ: ([js : Indexes]) (unsafe-vector-ref vs (unsafe-array-index->value-index ds js)))))
(: list->array (All (A) (case-> ((Listof A) -> (Array A))

View File

@ -35,16 +35,17 @@
1 arr perm)))])
(define dims (vector-length ds))
(define old-js (make-thread-local-indexes dims))
(unsafe-array-transform
arr ds
(λ: ([js : Indexes])
(let ([old-js (old-js)])
(let: loop : Indexes ([i : Nonnegative-Fixnum 0])
(cond [(i . < . dims) (unsafe-vector-set! old-js
(unsafe-vector-ref perm i)
(unsafe-vector-ref js i))
(loop (+ i 1))]
[else old-js])))))))
(array-default-strict
(unsafe-array-transform
arr ds
(λ: ([js : Indexes])
(let ([old-js (old-js)])
(let: loop : Indexes ([i : Nonnegative-Fixnum 0])
(cond [(i . < . dims) (unsafe-vector-set! old-js
(unsafe-vector-ref perm i)
(unsafe-vector-ref js i))
(loop (+ i 1))]
[else old-js]))))))))
(: array-axis-swap (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-swap arr i0 i1)
@ -62,16 +63,17 @@
(unsafe-vector-set! new-ds i0 j1)
(unsafe-vector-set! new-ds i1 j0)
(define proc (unsafe-array-proc arr))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define j0 (unsafe-vector-ref js i0))
(define j1 (unsafe-vector-ref js i1))
(unsafe-vector-set! js i0 j1)
(unsafe-vector-set! js i1 j0)
(define v (proc js))
(unsafe-vector-set! js i0 j0)
(unsafe-vector-set! js i1 j1)
v))]))
(array-default-strict
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define j0 (unsafe-vector-ref js i0))
(define j1 (unsafe-vector-ref js i1))
(unsafe-vector-set! js i0 j1)
(unsafe-vector-set! js i1 j0)
(define v (proc js))
(unsafe-vector-set! js i0 j0)
(unsafe-vector-set! js i1 j1)
v)))]))
;; ===================================================================================================
;; Adding/removing axes
@ -88,9 +90,8 @@
[else
(define new-ds (unsafe-vector-insert ds k dk))
(define proc (unsafe-array-proc arr))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(proc (unsafe-vector-remove js k))))]))
(array-default-strict
(unsafe-build-array new-ds (λ: ([js : Indexes]) (proc (unsafe-vector-remove js k)))))]))
(: array-axis-ref (All (A) ((Array A) Integer Integer -> (Array A))))
(define (array-axis-ref arr k jk)
@ -104,9 +105,8 @@
[else
(define new-ds (unsafe-vector-remove ds k))
(define proc (unsafe-array-proc arr))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(proc (unsafe-vector-insert js k jk))))]))
(array-default-strict
(unsafe-build-array new-ds (λ: ([js : Indexes]) (proc (unsafe-vector-insert js k jk)))))]))
;; ===================================================================================================
;; Reshape
@ -124,12 +124,13 @@
(define old-dims (vector-length old-ds))
(define g (unsafe-array-proc arr))
(define old-js (make-thread-local-indexes old-dims))
(unsafe-build-array
ds (λ: ([js : Indexes])
(let ([old-js (old-js)])
(define j (unsafe-array-index->value-index ds js))
(unsafe-value-index->array-index! old-ds j old-js)
(g old-js))))])))
(array-default-strict
(unsafe-build-array
ds (λ: ([js : Indexes])
(let ([old-js (old-js)])
(define j (unsafe-array-index->value-index ds js))
(unsafe-value-index->array-index! old-ds j old-js)
(g old-js)))))])))
(: array-flatten (All (A) ((Array A) -> (Array A))))
(define (array-flatten arr)
@ -141,12 +142,13 @@
(define old-dims (vector-length old-ds))
(define g (unsafe-array-proc arr))
(define old-js (make-thread-local-indexes old-dims))
(unsafe-build-array
ds (λ: ([js : Indexes])
(let ([old-js (old-js)])
(define j (unsafe-vector-ref js 0))
(unsafe-value-index->array-index! old-ds j old-js)
(g old-js))))]))
(array-default-strict
(unsafe-build-array
ds (λ: ([js : Indexes])
(let ([old-js (old-js)])
(define j (unsafe-vector-ref js 0))
(unsafe-value-index->array-index! old-ds j old-js)
(g old-js)))))]))
;; ===================================================================================================
;; Append
@ -200,11 +202,11 @@
(unsafe-vector-set! old-jks jk i)
(i-loop (+ i 1) (unsafe-fx+ jk 1))]
[else (arrs-loop (cdr arrs) (cdr dks) jk)]))))
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define jk (unsafe-vector-ref js k))
(unsafe-vector-set! js k (unsafe-vector-ref old-jks jk))
(define v ((unsafe-vector-ref old-procs jk) js))
(unsafe-vector-set! js k jk)
v))])))
(array-default-strict
(unsafe-build-array
new-ds (λ: ([js : Indexes])
(define jk (unsafe-vector-ref js k))
(unsafe-vector-set! js k (unsafe-vector-ref old-jks jk))
(define v ((unsafe-vector-ref old-procs jk) js))
(unsafe-vector-set! js k jk)
v)))])))

View File

@ -84,31 +84,31 @@
(raise-argument-error 'list*->array "rectangular (Listof* A)" lst))
(define ds (list-shape lst pred?))
(cond [(pred? lst) (unsafe-build-strict-array #() (λ (js) lst))]
(cond [(pred? lst) (unsafe-build-simple-array #() (λ (js) lst))]
[ds (let ([ds (check-array-shape ds raise-shape-error)])
(define size (array-shape-size ds))
(unsafe-vector->array ds (list*->flat-vector lst size pred?)))]
[else (raise-shape-error)]))
(: vector*->array (All (A) ((Vectorof* A) ((Vectorof* A) -> Any : A) -> (Array A))))
(: vector*->array (All (A) ((Vectorof* A) ((Vectorof* A) -> Any : A) -> (Mutable-Array A))))
(define (vector*->array vec pred?)
(define (raise-shape-error)
;; don't have to worry about non-Index size - can't fit in memory anyway
(raise-argument-error 'vector*->array "rectangular (Vectorof* A)" vec))
(define ds (vector-shape vec pred?))
(cond [(pred? vec) (unsafe-build-strict-array #() (λ (js) vec))]
(cond [(pred? vec) (array->mutable-array (unsafe-build-simple-array #() (λ (js) vec)))]
[ds (let ([ds (check-array-shape ds raise-shape-error)])
(define dims (vector-length ds))
(unsafe-build-array
ds (λ: ([js : Indexes])
(let: loop : A ([i : Nonnegative-Fixnum 0] [vec : (Vectorof* A) vec])
(cond [(pred? vec) vec]
[(i . < . dims)
(define j_i (unsafe-vector-ref js i))
(loop (+ i 1) (vector-ref vec j_i))]
[else (error 'vector*->array "internal error")]
)))))]
(array->mutable-array
(unsafe-build-array
ds (λ: ([js : Indexes])
(let: loop : A ([i : Nonnegative-Fixnum 0] [vec : (Vectorof* A) vec])
(cond [(pred? vec) vec]
[(i . < . dims)
(define j_i (unsafe-vector-ref js i))
(loop (+ i 1) (vector-ref vec j_i))]
[else (error 'vector*->array "internal error")]))))))]
[else (raise-shape-error)]))
) ; begin
) ; make-conversion-functions

View File

@ -17,13 +17,17 @@
(define-syntax (inline-array-map stx)
(syntax-case stx ()
[(_ f) (syntax/loc stx (unsafe-build-array #() (λ (js) (f))))]
[(_ f)
(syntax/loc stx
(array-default-strict
(unsafe-build-array #() (λ (js) (f)))))]
[(_ f arr-expr)
(syntax/loc stx
(let ([arr (ensure-array 'array-map arr-expr)])
(define ds (array-shape arr))
(define proc (unsafe-array-proc arr))
(unsafe-build-array ds (λ: ([js : Indexes]) (f (proc js))))))]
(array-default-strict
(unsafe-build-array ds (λ: ([js : Indexes]) (f (proc js)))))))]
[(_ f arr-expr arr-exprs ...)
(with-syntax ([(arrs ...) (generate-temporaries #'(arr-exprs ...))]
[(procs ...) (generate-temporaries #'(arr-exprs ...))])
@ -35,42 +39,44 @@
[arrs (array-broadcast arrs ds)] ...)
(define proc (unsafe-array-proc arr))
(define procs (unsafe-array-proc arrs)) ...
(unsafe-build-array ds (λ: ([js : Indexes]) (f (proc js) (procs js) ...)))))))])))
(array-default-strict
(unsafe-build-array ds (λ: ([js : Indexes]) (f (proc js) (procs js) ...))))))))])))
(require 'syntax-defs)
(module untyped-defs typed/racket/base
(require "array-struct.rkt"
(module untyped-defs racket/base
(require racket/contract
"array-struct.rkt"
"array-broadcast.rkt"
"utils.rkt"
(submod ".." syntax-defs))
(provide array-map)
(provide (contract-out
[array-map (->i ([f (unconstrained-domain-> any/c)])
#:rest [xs (listof array?)]
#:pre/name (f xs)
"function has the wrong arity"
(procedure-arity-includes? f (length xs))
[_ array?])]))
(: array-map (All (R A) (case-> ((-> R) -> (Array R))
((A -> R) (Array A) -> (Array R))
((A A A * -> R) (Array A) (Array A) (Array A) * -> (Array R)))))
(define array-map
(case-lambda:
[([f : (-> R)])
(inline-array-map f)]
[([f : (A -> R)] [arr : (Array A)])
(inline-array-map f arr)]
[([f : (A A -> R)] [arr0 : (Array A)] [arr1 : (Array A)])
(inline-array-map f arr0 arr1)]
[([f : (A A A * -> R)] [arr0 : (Array A)] [arr1 : (Array A)] . [arrs : (Array A) *])
(case-lambda
[(f) (inline-array-map f)]
[(f arr) (inline-array-map f arr)]
[(f arr0 arr1) (inline-array-map f arr0 arr1)]
[(f arr0 arr1 . arrs)
(define ds (array-shape-broadcast (list* (array-shape arr0)
(array-shape arr1)
(map (inst array-shape A) arrs))))
(map array-shape arrs))))
(let ([arr0 (array-broadcast arr0 ds)]
[arr1 (array-broadcast arr1 ds)]
[arrs (map (λ: ([arr : (Array A)]) (array-broadcast arr ds)) arrs)])
[arrs (map (λ (arr) (array-broadcast arr ds)) arrs)])
(define g0 (unsafe-array-proc arr0))
(define g1 (unsafe-array-proc arr1))
(define gs (map (inst unsafe-array-proc A) arrs))
(unsafe-build-array
ds (λ: ([js : Indexes]) (apply f (g0 js) (g1 js)
(map (λ: ([g : (Indexes -> A)]) (g js)) gs)))))]))
(define gs (map unsafe-array-proc arrs))
(array-default-strict
(unsafe-build-array
ds (λ (js) (apply f (g0 js) (g1 js) (map (λ (g) (g js)) gs))))))]))
)
(require 'untyped-defs)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,187 @@
#lang typed/racket
(require math/array
typed/rackunit)
(define (check-always)
(printf "(array-strictness) = ~v~n" (array-strictness))
(check-true (array-strict? (make-array #(4 4) 0)))
(check-true (array-strict? (indexes-array #(4 4))))
(check-true (array-strict? (index-array #(4 4))))
(check-true (array-strict? (axis-index-array #(4 4) 0)))
(check-true (array-strict? (diagonal-array 2 6 1 0)))
(check-true (array-strict? (list->array '(1 2 3 4))))
(check-true (array-strict? (list->array #(2 2) '(1 2 3 4))))
(check-true (array-strict? (list*->array 0 exact-integer?)))
(check-true (array-strict? (list*->array '(1 2 3 4) exact-integer?)))
(check-true (array-strict? (list*->array '((1 2) (3 4)) exact-integer?)))
(check-true (array-strict? (vector->array #(1 2 3 4))))
(check-true (array-strict? (vector->array #(2 2) #(1 2 3 4))))
(check-true (array-strict? (vector*->array 0 exact-integer?)))
(check-true (array-strict? ((inst vector*->array Integer) #(1 2 3 4) exact-integer?)))
(check-true (array-strict? ((inst vector*->array Integer) #(#(1 2) #(3 4)) exact-integer?)))
(check-true (array-strict? (build-simple-array #(4 4) (λ (_) 0))))
(check-false (array-strict? (array-lazy (build-simple-array #(4 4) (λ (_) 0)))))
)
(define nonstrict-2x2-arr
(parameterize ([array-strictness #f])
(build-array #(2 2) (λ (_) 0))))
(define strict-2x2-arr
(parameterize ([array-strictness #t])
(build-array #(2 2) (λ (_) 0))))
(define 2x2-indexes-arr
(array #['#(0 0) '#(1 1)]))
(check-false (array-strict? nonstrict-2x2-arr))
(check-true (array-strict? strict-2x2-arr))
(parameterize ([array-strictness #t])
(check-always)
(check-true (array-strict? (array-list->array (list))))
(check-true (array-strict? (array-list->array (list (array #[0 1])))))
(check-true (array-strict? (array-list->array (list (array #[0 1]) (array #[2 3])))))
(check-true (andmap (inst array-strict? Integer) (sequence->list (in-array-axis (array #[0 1])))))
(check-false (array-strict? (array-broadcast nonstrict-2x2-arr ((inst vector Index) 2 2))))
(check-true (array-strict? (array-broadcast nonstrict-2x2-arr ((inst vector Index) 2 4))))
(check-true (array-strict? (array-broadcast strict-2x2-arr ((inst vector Index) 2 2))))
(check-false (array-strict? (array-broadcast strict-2x2-arr ((inst vector Index) 2 4))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-indexes-ref arr 2x2-indexes-arr))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[spec (list '(0) 0)])
(check-true (array-strict? (array-slice-ref arr (list (::) spec))))
(check-true (array-strict? (array-slice-ref arr (list (::) (::new 2) spec)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-transform arr #(2 2)
(λ: ([js : Indexes])
(vector (vector-ref js 1) (vector-ref js 0)))))))
(for: ([k (list 0 1)])
(check-true (array-strict? (array-append* (list nonstrict-2x2-arr strict-2x2-arr) k))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)]
[dk (list 0 1 2)])
(check-true (array-strict? (array-axis-insert arr k dk))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)]
[jk (list 0 1)])
(check-true (array-strict? (array-axis-ref arr k jk))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-axis-swap arr 0 1))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-axis-permute arr '(1 0)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-reshape arr #(4)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-flatten arr))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-true (array-strict? (array-axis-sum arr k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-true (array-strict? (array-axis-count arr k even?))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-true (array-strict? (array-axis-and (array-map even? arr) k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-true (array-strict? (array-axis-or (array-map even? arr) k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array->list-array arr))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (list-array->array (array->list-array arr)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-true (array-strict? (array-fold arr (inst array->list-array (Listof* Integer))))))
)
(parameterize ([array-strictness #f])
(check-always)
(check-false (array-strict? (array-list->array (list))))
(check-false (array-strict? (array-list->array (list (array #[0 1])))))
(check-false (array-strict? (array-list->array (list (array #[0 1]) (array #[2 3])))))
(check-false (ormap (inst array-strict? Integer) (sequence->list (in-array-axis (array #[0 1])))))
(check-false (array-strict? (array-broadcast nonstrict-2x2-arr ((inst vector Index) 2 2))))
(check-false (array-strict? (array-broadcast nonstrict-2x2-arr ((inst vector Index) 2 4))))
(check-true (array-strict? (array-broadcast strict-2x2-arr ((inst vector Index) 2 2))))
(check-false (array-strict? (array-broadcast strict-2x2-arr ((inst vector Index) 2 4))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-indexes-ref arr 2x2-indexes-arr))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[spec (list '(0) 0)])
(check-false (array-strict? (array-slice-ref arr (list (::) spec))))
(check-false (array-strict? (array-slice-ref arr (list (::) (::new 2) spec)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-transform arr #(2 2)
(λ: ([js : Indexes])
(vector (vector-ref js 1) (vector-ref js 0)))))))
(for: ([k (list 0 1)])
(check-false (array-strict? (array-append* (list nonstrict-2x2-arr strict-2x2-arr) k))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)] [k (list 0 1)] [dk (list 0 1 2)])
(check-false (array-strict? (array-axis-insert arr k dk))))
(for*: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)]
[jk (list 0 1)])
(check-false (array-strict? (array-axis-ref arr k jk))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-axis-swap arr 0 1))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-axis-permute arr '(1 0)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-reshape arr #(4)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-flatten arr))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-false (array-strict? (array-axis-sum arr k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-false (array-strict? (array-axis-count arr k even?))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-false (array-strict? (array-axis-and (array-map even? arr) k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)]
[k (list 0 1)])
(check-false (array-strict? (array-axis-or (array-map even? arr) k))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array->list-array arr))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (list-array->array (array->list-array arr)))))
(for: ([arr (list nonstrict-2x2-arr strict-2x2-arr)])
(check-false (array-strict? (array-fold arr (inst array->list-array (Listof* Integer))))))
)

View File

@ -25,12 +25,6 @@
(define-syntax-rule (array-axis-ormap arr k pred?)
(array-axis-or (array-map pred? arr) k))
(define-syntax-rule (array-all-andmap arr pred?)
(array-all-and (array-map pred? arr)))
(define-syntax-rule (array-all-ormap arr pred?)
(array-all-or (array-map pred? arr)))
;; ---------------------------------------------------------------------------------------------------
;; array-mutable
@ -370,32 +364,32 @@
(let ([arr (array #[#[1.0 1.0 2.0 3.0] #[0.0 -1.0 2.0 3.0]])])
(check-equal? (array-axis-andmap arr 0 positive?) (array #[#f #f #t #t]))
(check-equal? (array-axis-andmap arr 1 positive?) (array #[#t #f]))
(check-equal? (array-all-andmap arr positive?) #f))
(check-equal? (array-andmap positive? arr) #f))
(let ([arr (array #[#[1.0 1.0 2.0 3.0] #[2.0 3.0 2.0 3.0]])])
(check-equal? (array-axis-andmap arr 0 positive?) (array #[#t #t #t #t]))
(check-equal? (array-axis-andmap arr 1 positive?) (array #[#t #t]))
(check-equal? (array-all-andmap arr positive?) #t))
(check-equal? (array-andmap positive? arr) #t))
(let ([arr (array #[#[-1.0 -1.0 -2.0 -3.0] #[0.0 -1.0 2.0 3.0]])])
(check-equal? (array-axis-ormap arr 0 positive?) (array #[#f #f #t #t]))
(check-equal? (array-axis-ormap arr 1 positive?) (array #[#f #t]))
(check-equal? (array-all-ormap arr positive?) #t))
(check-equal? (array-ormap positive? arr) #t))
(let ([arr (array #[#[-1.0 -1.0 -2.0 -3.0] #[-2.0 -3.0 -2.0 -3.0]])])
(check-equal? (array-axis-ormap arr 0 positive?) (array #[#f #f #f #f]))
(check-equal? (array-axis-ormap arr 1 positive?) (array #[#f #f]))
(check-equal? (array-all-ormap arr positive?) #f))
(check-equal? (array-ormap positive? arr) #f))
(let ([arr (make-array #() 0.0)])
(check-equal? (array-count positive? arr) 0)
(check-equal? (array-all-andmap arr positive?) #f)
(check-equal? (array-all-ormap arr positive?) #f))
(check-equal? (array-andmap positive? arr) #f)
(check-equal? (array-ormap positive? arr) #f))
(let ([arr (make-array #() 1.0)])
(check-equal? (array-count positive? arr) 1)
(check-equal? (array-all-andmap arr positive?) #t)
(check-equal? (array-all-ormap arr positive?) #t))
(check-equal? (array-andmap positive? arr) #t)
(check-equal? (array-ormap positive? arr) #t))
(let ([arr (make-array #(4 0) 0.0)])
(check-equal? (array-axis-count arr 0 positive?) (array #[]))
@ -405,8 +399,8 @@
(check-equal? (array-axis-andmap arr 1 positive?) (array #[#t #t #t #t]))
(check-equal? (array-axis-ormap arr 1 positive?) (array #[#f #f #f #f]))
(check-equal? (array-count positive? arr) 0)
(check-equal? (array-all-andmap arr positive?) #t)
(check-equal? (array-all-ormap arr positive?) #f))
(check-equal? (array-andmap positive? arr) #t)
(check-equal? (array-ormap positive? arr) #f))
;; ---------------------------------------------------------------------------------------------------
;; FFT
@ -982,7 +976,7 @@
(: arr (Array Integer))
(define arr
(array-lazy
(build-array
(build-simple-array
#(12 12)
(λ: ([js : Indexes])
(match-define (vector j0 j1) js)

View File

@ -0,0 +1,292 @@
#lang racket
(require (for-syntax racket/match)
rackunit
math/array)
;; ===================================================================================================
;; Contract tests
(begin-for-syntax
(define exceptions '(array
mutable-array
flarray
fcarray
inline-array-map
array+
array*
array-
array/
array-min
array-max
array-scale
array-abs
array-sqr
array-sqrt
array-conjugate
array-real-part
array-imag-part
array-make-rectangular
array-magnitude
array-angle
array-make-polar
array<
array<=
array>
array>=
array=
array-not
array-and
array-or
array-if
array-axis-sum
array-axis-prod
array-axis-min
array-axis-max
array-all-sum
array-all-prod
array-all-min
array-all-max
array-count
array-andmap
array-ormap
inline-flarray-map
inline-fcarray-map
array-strict
array-default-strict
make-unsafe-array-proc
make-unsafe-array-set-proc
array/syntax))
(define (looks-like-value? sym)
(define str (symbol->string sym))
(and (not (char-upper-case? (string-ref str 0)))
(not (regexp-match #rx"for/" str))
(not (regexp-match #rx"for\\*/" str))
(not (member sym exceptions))))
(define array-exports
(let ()
(match-define (list (list #f _ ...)
(list 1 _ ...)
(list 0 array-exports ...))
(syntax-local-module-exports #'math/array))
(filter looks-like-value? array-exports)))
)
(define-syntax (all-exports stx)
(with-syntax ([(array-exports ...) array-exports])
(syntax/loc stx
(begin (void array-exports) ...))))
(all-exports)
;; ---------------------------------------------------------------------------------------------------
;; Comprehensions
(check-equal? (for/array #:shape #() () 3)
(mutable-array 3))
(check-equal? (for/array #:shape #() () 'foo)
(mutable-array 'foo))
(check-equal? (for/array #:shape #(2) ([x (in-naturals)]) x)
(mutable-array #[0 1]))
(check-equal? (for/array #:shape #(2 3) ([i (in-range 0 6)])
(vector (quotient i 3) (remainder i 3)))
(indexes-array #(2 3)))
(check-equal? (for*/array #:shape #() () 3)
(mutable-array 3))
(check-equal? (for*/array #:shape #() () 'foo)
(mutable-array 'foo))
(check-equal? (for*/array #:shape #(2) ([x (in-naturals)]) x)
(mutable-array #[0 1]))
(check-equal? (for*/array #:shape #(2 3) ([i (in-range 0 2)]
[j (in-range 0 3)])
(vector i j))
(indexes-array #(2 3)))
(check-equal? (for*/array #:shape #() () 3)
(for*/array #:shape #() () 3))
(check-equal? (for*/array #:shape #() () 'foo)
(for*/array #:shape #() () 'foo))
(check-equal? (for*/array #:shape #(2) ([x (in-naturals)]) x)
(for*/array #:shape #(2) ([x (in-naturals)]) x))
(check-equal? (for*/array #:shape #(2 3) ([i (in-range 0 2)]
[j (in-range 0 3)])
(list i j))
(for*/array #:shape #(2 3) ([i (in-range 0 2)]
[j (in-range 0 3)])
(list i j)))
;; ---------------------------------------------------------------------------------------------------
;; Sequences
(check-equal? (for/list ([x (in-array (array #[#[1 2 3] #[4 5 6]]))]) x)
'(1 2 3 4 5 6))
(check-equal? (for/list ([js (in-array (indexes-array #()))]) js)
'(#()))
(check-equal? (for/list ([js (in-array (indexes-array #(0)))]) js)
'())
(check-equal? (for/list ([js (in-array (indexes-array #(2 2)))]) js)
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(check-equal? (sequence->list (in-array (indexes-array #())))
'(#()))
(check-equal? (sequence->list (in-array (indexes-array #(0))))
'())
(check-equal? (sequence->list (in-array (indexes-array #(2 2))))
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(check-equal? (for/list ([js (in-array-indexes #())]) js)
'(#()))
(check-equal? (for/list ([js (in-array-indexes #(0))]) js)
'())
(check-equal? (for/list ([js (in-array-indexes #(2 2))]) js)
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(check-equal? (sequence->list (in-array-indexes #()))
'(#()))
(check-equal? (sequence->list (in-array-indexes #(0)))
'())
(check-equal? (sequence->list (in-array-indexes #(2 2)))
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(check-equal? (for/list ([js (in-unsafe-array-indexes #(2 2))]) js)
'(#(0 0) #(0 0) #(0 0) #(0 0)))
(check-equal? (for/list ([js (in-unsafe-array-indexes #())]) (vector-copy js))
'(#()))
(check-equal? (for/list ([js (in-unsafe-array-indexes #(0))]) (vector-copy js))
'())
(check-equal? (for/list ([js (in-unsafe-array-indexes #(2 2))]) (vector-copy js))
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(check-equal? (sequence->list (in-unsafe-array-indexes #()))
'(#()))
(check-equal? (sequence->list (in-unsafe-array-indexes #(0)))
'())
(check-equal? (sequence->list (in-unsafe-array-indexes #(2 2)))
'(#(0 0) #(0 1) #(1 0) #(1 1)))
(let ([arr (indexes-array #(4 5))])
(check-equal? (for/list ([brr (in-array-axis arr 0)])
(for/list ([js (in-array brr)])
js))
(array->list* arr))
(check-equal? (for/list ([brr (in-array-axis arr 1)])
(for/list ([js (in-array brr)])
js))
(array->list* (array-axis-swap arr 0 1))))
;; ---------------------------------------------------------------------------------------------------
;; Construction macros
(check-equal? (array #[1 2 3])
(list->array '(1 2 3)))
(check-equal? (array #[#[0 1] #[2 3]])
(index-array #(2 2)))
(check-equal? (mutable-array #[1 2 3])
(list->array '(1 2 3)))
(check-equal? (mutable-array #[#[0 1] #[2 3]])
(index-array #(2 2)))
(check-equal? (flarray #[1 2 3])
(array->flarray (list->array '(1 2 3))))
(check-equal? (flarray #[#[0 1] #[2 3]])
(array->flarray (index-array #(2 2))))
(check-equal? (fcarray #[1 2 3])
(array->fcarray (list->array '(1 2 3))))
(check-equal? (fcarray #[#[0 1] #[2 3]])
(array->fcarray (index-array #(2 2))))
;; ---------------------------------------------------------------------------------------------------
;; Mapping
(check-equal? (array-map +) (array 0))
(check-equal? (array-map + (array #[1 2 3]))
(array #[1 2 3]))
(check-equal? (array-map + (array #[1 2 3]) (array #[10 20 30]))
(array #[11 22 33]))
(check-equal? (array-map make-rectangular (array #[1 2 3]) (array #[10 20 30]))
(array #[1+10i 2+20i 3+30i]))
(check-exn exn:fail:contract? (λ () (array-map)))
(check-exn exn:fail:contract? (λ () (array-map 5)))
(check-exn exn:fail:contract? (λ () (array-map + 5)))
(check-exn exn:fail:contract? (λ () (array-map -)))
(check-exn exn:fail:contract? (λ () (array-map make-rectangular (array #[1 2 3]))))
(check-exn exn:fail:contract? (λ () (array-map values (array #[1 2 3]) (array #[10 20 30]))))
(check-equal? (array+ (array #[1 2 3]) (array #[10 20 30]))
(array-map + (array #[1 2 3]) (array #[10 20 30])))
(check-exn exn:fail:contract? (λ () (array+ (array #[1 2 3]) (array 'x))))
(check-equal? (array-scale (array #[1 2 3]) 2)
(array-map (λ (x) (* x 2)) (array #[1 2 3])))
(check-exn exn:fail:contract? (λ () (array-scale (array #[1 2 3]) 'x)))
(check-equal? (inline-array-map +) (array 0))
(check-equal? (inline-array-map + (array #[1 2 3]))
(array #[1 2 3]))
(check-equal? (inline-array-map + (array #[1 2 3]) (array #[10 20 30]))
(array #[11 22 33]))
(check-equal? (inline-array-map make-rectangular (array #[1 2 3]) (array #[10 20 30]))
(array #[1+10i 2+20i 3+30i]))
(check-exn exn:fail:contract? (λ () (inline-array-map 5)))
(check-exn exn:fail:contract? (λ () (inline-array-map + 5)))
(check-exn exn:fail:contract? (λ () (inline-array-map -)))
(check-exn exn:fail:contract? (λ () (inline-array-map make-rectangular (array #[1 2 3]))))
(check-exn exn:fail:contract? (λ () (inline-array-map values (array #[1 2 3]) (array #[10 20 30]))))
;; ---------------------------------------------------------------------------------------------------
;; Folding
(check-equal? (array-axis-sum (array #[1 2 3]) 0)
(array 6))
(check-equal? (array-axis-sum (array #[]) 0 0.0)
(array 0.0))
(check-equal? (array-all-sum (array #[1 2 3]))
6)
(check-equal? (array-all-sum (array #[]) 0.0)
0.0)
(check-equal? (array-count even? (array #[1 2 3]))
1)
(check-equal? (array-count equal? (array #[1 2 3]) (array #[2 1 3]))
1)
(check-false (array-andmap equal? (array #[1 2 3]) (array #[2 1 3])))
(check-true (array-ormap equal? (array #[1 2 3]) (array #[2 1 3])))
(let ([arr (parameterize ([array-strictness #f])
(array-strict
(build-array #(3 3) (λ (js) (apply + (vector->list js))))))])
(check-true (array-strict? arr))
(check-equal? arr (array #[#[0 1 2] #[1 2 3] #[2 3 4]])))
(let ([arr (parameterize ([array-strictness #f])
(array-default-strict
(build-array #(3 3) (λ (js) (apply + (vector->list js))))))])
(check-false (array-strict? arr))
(check-equal? arr (array #[#[0 1 2] #[1 2 3] #[2 3 4]]))
(array-strict! arr)
(check-true (array-strict? arr))
(check-equal? arr (array #[#[0 1 2] #[1 2 3] #[2 3 4]])))

View File

@ -21,23 +21,32 @@
divtime))))))
divtime)
(define truth
(array #[#[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
#[1 1 1 1 1 1 2 2 2 3 2 2 1 1 1]
#[1 1 1 1 1 2 2 2 3 6 20 3 2 1 1]
#[1 1 1 2 2 2 3 4 5 20 17 4 3 2 1]
#[1 1 2 2 3 3 4 11 20 20 20 10 14 2 2]
#[1 2 3 4 6 6 6 20 20 20 20 20 9 3 2]
#[2 3 4 6 18 20 14 20 20 20 20 20 20 3 2]
#[20 20 20 20 20 20 20 20 20 20 20 20 5 3 2]
#[2 3 4 6 18 20 14 20 20 20 20 20 20 3 2]
#[1 2 3 4 6 6 6 20 20 20 20 20 9 3 2]
#[1 1 2 2 3 3 4 11 20 20 20 10 14 2 2]
#[1 1 1 2 2 2 3 4 5 20 17 4 3 2 1]
#[1 1 1 1 1 2 2 2 3 6 20 3 2 1 1]
#[1 1 1 1 1 1 2 2 2 3 2 2 1 1 1]
#[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]))
(check-true
(equal? (mandelbrot 0.2 20)
(array #[#[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
#[1 1 1 1 1 1 2 2 2 3 2 2 1 1 1]
#[1 1 1 1 1 2 2 2 3 6 20 3 2 1 1]
#[1 1 1 2 2 2 3 4 5 20 17 4 3 2 1]
#[1 1 2 2 3 3 4 11 20 20 20 10 14 2 2]
#[1 2 3 4 6 6 6 20 20 20 20 20 9 3 2]
#[2 3 4 6 18 20 14 20 20 20 20 20 20 3 2]
#[20 20 20 20 20 20 20 20 20 20 20 20 5 3 2]
#[2 3 4 6 18 20 14 20 20 20 20 20 20 3 2]
#[1 2 3 4 6 6 6 20 20 20 20 20 9 3 2]
#[1 1 2 2 3 3 4 11 20 20 20 10 14 2 2]
#[1 1 1 2 2 2 3 4 5 20 17 4 3 2 1]
#[1 1 1 1 1 2 2 2 3 6 20 3 2 1 1]
#[1 1 1 1 1 1 2 2 2 3 2 2 1 1 1]
#[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1]])))
truth))
(check-true
(equal? (parameterize ([array-strictness #f])
(mandelbrot 0.2 20))
truth))
#;
(begin
(require images/flomap)