310 lines
9.5 KiB
Racket
310 lines
9.5 KiB
Racket
#lang typed/racket/base
|
|
|
|
;; TODO: abstract out the design patterns repeated across the macros below
|
|
|
|
(provide
 ;; --- binding forms that record statically known vector lengths
 define-vector:
 let-vector:
 ;; --- length-aware versions of vector operations
 ;;     (each falls back to the plain operation when no length is known)
 vector-length: vector-ref: vector-set!:
 vector-map: vector-map!:
 vector-append:
 vector->list: vector->immutable-vector:
 vector-fill!:
 ;; slicing
 vector-take: vector-take-right:
 vector-drop: vector-drop-right:
 ;; vector-split-at: vector-split-at-right:   ; currently disabled
 ;; --- private
 (for-syntax parse-vector-length))
|
|
|
|
;; -----------------------------------------------------------------------------
|
|
|
|
(require
|
|
(only-in racket/unsafe/ops
|
|
unsafe-vector-set!
|
|
unsafe-vector-ref)
|
|
racket/vector
|
|
trivial/math
|
|
(for-syntax
|
|
typed/racket/base
|
|
racket/syntax
|
|
syntax/id-table
|
|
syntax/parse
|
|
syntax/stx
|
|
trivial/private/common
|
|
))
|
|
|
|
;; =============================================================================
|
|
|
|
;; Phase-1 state shared by the macros in this file.
(begin-for-syntax
  ;; Syntax-property key under which a statically known vector length is
  ;; attached to syntax produced by these macros; read by `parse-vector-length`.
  (define vector-length-key 'vector:length)
  ;; Identifier -> known length, populated by `define-vector:` and consulted
  ;; by `parse-vector-length` when it sees an identifier.
  (define id+vector-length (make-free-id-table)))
|
|
|
|
(begin-for-syntax
  ;; Matches any expression whose vector length `parse-vector-length` can
  ;; determine at compile time; fails to match otherwise.
  ;; Attributes:
  ;;   expanded : the result of `expand-expr` on the matched expression
  ;;   length   : syntax wrapping the statically known length
  (define-syntax-class vector/length
    #:attributes (expanded length)
    (pattern expr
             #:with full (expand-expr #'expr)
             #:with n (parse-vector-length #'full)
             ;; `parse-vector-length` returns #f when the length is unknown
             #:when (syntax-e #'n)
             #:attr expanded #'full
             #:attr length #'n)))
|
|
|
|
;; (define-vector: name expr)
;; Like `define`. When `expr` has a statically known length, that length is
;; recorded for `name` in `id+vector-length`, so `parse-vector-length` (and
;; therefore the `...:` operators) can see it at later uses of `name`.
;; Any other shape is passed to `define` unchanged.
(define-syntax (define-vector: stx)
  (syntax-parse stx
    [(_ id:id vec:vector/length)
     (define known-len (syntax-e #'vec.length))
     (free-id-table-set! id+vector-length #'id known-len)
     #'(define id vec.expanded)]
    [(_ form* ...)
     #'(define form* ...)]))
|
|
|
|
;; (let-vector: ([name expr] ...) body ...)
;; Like `let`. When EVERY `expr` has a statically known length, each `name` is
;; additionally rebound inside `body` to a rename transformer whose target
;; identifier carries that length under `vector-length-key`.
;; Otherwise this is exactly `let`.
(define-syntax (let-vector: stx)
  (syntax-parse stx
    [(_ ([id* vec*:vector/length] ...) body* ...)
     #'(let ([id* vec*.expanded] ...)
         (let-syntax ([id* (make-rename-transformer
                            (syntax-property #'id*
                                             vector-length-key
                                             'vec*.length))]
                      ...)
           body* ...))]
    [(_ form* ...)
     #'(let form* ...)]))
|
|
|
|
;; Raise a compile-time syntax error saying index `i` is out of bounds for the
;; vector expression `v-stx`. `sym` names the reporting operator; `v-stx` is
;; also passed as an extra source so the error carries its source location.
(define-for-syntax (vector-bounds-error sym v-stx i)
  (define v-datum (syntax->datum v-stx))
  (define srcs (list v-stx))
  (raise-syntax-error sym "Index out-of-bounds" v-datum i srcs))
|
|
|
|
;; (vector-length: v)
;; When `v` has a statically known length, the length is folded to a literal.
;; `v` itself is still evaluated first, exactly as `vector-length` would
;; evaluate it, so side effects inside `v` (element expressions of
;; `(vector ...)`, the generator passed to `build-vector`, ...) are preserved.
;; When used as an identifier, or when no length is known, behaves as
;; `vector-length`.
(define-syntax (vector-length: stx)
  (syntax-parse stx
    [(_ v:vector/length)
     ;; Previously this expanded to just 'v.length, silently dropping `v`.
     (syntax/loc stx (begin v.expanded 'v.length))]
    [_:id
     (syntax/loc stx vector-length)]
    [(_ e* ...)
     (syntax/loc stx (vector-length e* ...))]))
|
|
|
|
;; (vector-ref: v i)
;; When `v` has a statically known length and `i` matches `nat/expand` (a
;; statically known natural index), the index is checked at compile time:
;; an out-of-range index is a syntax error, an in-range one compiles to
;; `unsafe-vector-ref`. Otherwise behaves as `vector-ref`.
(define-syntax (vector-ref: stx)
  (syntax-parse stx
    [(_ v:vector/length i:nat/expand)
     (define idx (syntax-e #'i.expanded))
     (define len (syntax-e #'v.length))
     (when (>= idx len)
       (vector-bounds-error 'vector-ref: #'v idx))
     (syntax/loc stx (unsafe-vector-ref v.expanded 'i.expanded))]
    [_:id (syntax/loc stx vector-ref)]
    [(_ e* ...) (syntax/loc stx (vector-ref e* ...))]))
|
|
|
|
;; (vector-set!: v i val)
;; Same strategy as `vector-ref:`: with a known length and a statically
;; known natural index, the bounds check happens at compile time and the
;; write compiles to `unsafe-vector-set!`. Otherwise behaves as `vector-set!`.
(define-syntax (vector-set!: stx)
  (syntax-parse stx
    [(_ v:vector/length i:nat/expand val)
     (define idx (syntax-e #'i.expanded))
     (when (>= idx (syntax-e #'v.length))
       (vector-bounds-error 'vector-set!: #'v idx))
     (syntax/loc stx (unsafe-vector-set! v.expanded 'i.expanded val))]
    [_:id (syntax/loc stx vector-set!)]
    [(_ e* ...) (syntax/loc stx (vector-set! e* ...))]))
|
|
|
|
;; (vector-map: f v)
;; With a statically known length for `v`:
;;  - lengths accepted by `small-vector-size?` are unrolled into a single
;;    `(vector (f v[0]) (f v[1]) ...)`;
;;  - longer vectors use `build-vector` with the known length.
;; Either way the result is tagged with the same length under
;; `vector-length-key`. Otherwise behaves as `vector-map`.
(define-syntax (vector-map: stx)
  (syntax-parse stx
    [(_ f v:vector/length)
     #:with f+ (gensym 'f)
     #:with v+ (gensym 'v)
     (define len (syntax-e #'v.length))
     (define mapped
       (if (small-vector-size? len)
           ;; unrolled: one application of `f` per index
           (with-syntax ([(i* ...) (build-list len values)])
             (syntax/loc stx
               (let ([f+ f] [v+ v.expanded])
                 (vector (f+ (unsafe-vector-ref v+ 'i*)) ...))))
           (syntax/loc stx
             (let ([f+ f] [v+ v.expanded])
               (build-vector 'v.length
                             (lambda ([i : Integer])
                               (f+ (vector-ref: v+ i))))))))
     (syntax-property mapped vector-length-key len)]
    [_:id (syntax/loc stx vector-map)]
    [(_ e* ...) (syntax/loc stx (vector-map e* ...))]))
|
|
|
|
;; (vector-map!: f v)
;; With a statically known length for `v`, overwrites each element in place
;; using a loop over the known length and unchecked accesses, then returns
;; the vector. The result keeps `v`'s length tag. Otherwise behaves as
;; `vector-map!`.
(define-syntax (vector-map!: stx)
  (syntax-parse stx
    [(_ f v:vector/length)
     #:with f+ (gensym 'f)
     #:with v+ (gensym 'v)
     (define known-len (syntax-e #'v.length))
     (define updated
       #'(let ([f+ f]
               [v+ v.expanded])
           (for ([i (in-range 'v.length)])
             (unsafe-vector-set! v+ i (f+ (unsafe-vector-ref v+ i))))
           v+))
     (syntax-property updated vector-length-key known-len)]
    [_:id (syntax/loc stx vector-map!)]
    [(_ e* ...) (syntax/loc stx (vector-map! e* ...))]))
|
|
|
|
;; (vector-append: v1 v2)
;; When both vectors have statically known lengths l1 and l2, builds the
;; result directly and tags it with length (+ l1 l2):
;;  - if both lengths satisfy `small-vector-size?`, the result is one unrolled
;;    `(vector ...)` call;
;;  - otherwise `build-vector` is used, reading index i from v1 when i < l1
;;    and from v2 at (i - l1) otherwise.
;; Otherwise behaves as `vector-append`.
(define-syntax (vector-append: stx)
  (syntax-parse stx
    [(_ v1:vector/length v2:vector/length)
     #:with v1+ (gensym 'v1)
     #:with v2+ (gensym 'v2)
     (define l1 (syntax-e #'v1.length))
     (define l2 (syntax-e #'v2.length))
     (syntax-property
      (if (and (small-vector-size? l1)
               (small-vector-size? l2))
          (with-syntax ([(i1* ...) (for/list ([i (in-range l1)]) i)]
                        [(i2* ...) (for/list ([i (in-range l2)]) i)])
            (syntax/loc stx
              (let ([v1+ v1.expanded]
                    [v2+ v2.expanded])
                (vector (vector-ref: v1+ i1*) ...
                        (vector-ref: v2+ i2*) ...))))
          (quasisyntax/loc stx
            (let ([v1+ v1.expanded]
                  [v2+ v2.expanded])
              (build-vector
               #,(+ l1 l2)
               (lambda (i)
                 (if (< i '#,l1)
                     (unsafe-vector-ref v1+ i)
                     ;; Shift into v2's index space. Indexing v2 with the raw
                     ;; `i` (>= l1) would be an unchecked out-of-bounds read.
                     (unsafe-vector-ref v2+ (- i '#,l1))))))))
      vector-length-key
      (+ l1 l2))]
    [_:id
     (syntax/loc stx vector-append)]
    [(_ e* ...)
     (syntax/loc stx (vector-append e* ...))]))
|
|
|
|
;; (vector->list: v)
;; With a statically known length for `v`, builds the list without a runtime
;; length lookup: small lengths (per `small-vector-size?`) are unrolled into
;; one `(list ...)` call, longer ones use `build-list`. The result is a list,
;; so it carries no vector-length tag. Otherwise behaves as `vector->list`.
(define-syntax (vector->list: stx)
  (syntax-parse stx
    [(_ v:vector/length)
     #:with v+ (gensym 'v)
     (define n (syntax-e #'v.length))
     (cond
       [(small-vector-size? n)
        (with-syntax ([(i* ...) (build-list n values)])
          (syntax/loc stx
            (let ([v+ v.expanded])
              (list (unsafe-vector-ref v+ i*) ...))))]
       [else
        (syntax/loc stx
          (let ([v+ v.expanded])
            (build-list 'v.length (lambda (i) (unsafe-vector-ref v+ i)))))])]
    [_:id (syntax/loc stx vector->list)]
    [(_ e* ...) (syntax/loc stx (vector->list e* ...))]))
|
|
|
|
;; (vector->immutable-vector: v)
;; Calls `vector->immutable-vector`; when `v`'s length is statically known the
;; result is tagged with that same length so later operators can use it.
(define-syntax (vector->immutable-vector: stx)
  (syntax-parse stx
    [(_ v:vector/length)
     (define frozen (syntax/loc stx (vector->immutable-vector v.expanded)))
     (syntax-property frozen vector-length-key (syntax-e #'v.length))]
    [_:id (syntax/loc stx vector->immutable-vector)]
    [(_ e* ...) (syntax/loc stx (vector->immutable-vector e* ...))]))
|
|
|
|
;; (vector-fill!: v val)
;; With a statically known length for `v`, fills it using a loop over that
;; length with unchecked writes. This mirrors `vector-fill!`:
;;  - `v` is evaluated before `val`;
;;  - `val` is evaluated exactly once;
;;  - the result is void, so it is NOT tagged with a vector length.
;; When used as an identifier, or when no length is known, this is
;; `vector-fill!`.
(define-syntax (vector-fill!: stx)
  (syntax-parse stx
    [(_ v:vector/length val)
     #:with v+ (gensym 'v)
     #:with val+ (gensym 'val)
     (syntax/loc stx
       (let ([v+ v.expanded]
             ;; bind once: `val` may have side effects or be expensive
             [val+ val])
         (for ([i (in-range 'v.length)])
           (unsafe-vector-set! v+ i val+))))]
    [_:id
     ;; was `vector->fill!`, which does not exist
     (syntax/loc stx vector-fill!)]
    [(_ e* ...)
     (syntax/loc stx (vector-fill! e* ...))]))
|
|
|
|
(begin-for-syntax
  ;; (make-slice-op op-name left? take?) produces a transformer for a
  ;; take/drop operator.
  ;;   op-name : the runtime fallback (e.g. `vector-take`)
  ;;   left?   : #t when the `n` elements are counted from the left end
  ;;   take?   : #t to keep those `n` elements, #f to discard them
  ;; Cases of the produced transformer:
  ;;  - `v` has a known length and `n` a statically known natural value:
  ;;    the bounds check happens at compile time, and the slice [lo, hi) is
  ;;    copied with unchecked reads. The result is tagged with length hi - lo.
  ;;  - `n` is a statically known negative integer: compile-time error.
  ;;  - otherwise: defers to `op-name` at run time.
  (define-syntax-rule (make-slice-op op-name left? take?)
    (lambda (stx)
      (syntax-parse stx
        [(_ v:vector/length n:nat/expand)
         ;; [lo, hi) is the part of `v` the operation keeps.
         #:with (lo hi)
         (let ([len (syntax-e #'v.length)]
               [k (syntax-e #'n.expanded)])
           (if 'take?
               (if 'left? (list 0 k) (list (- len k) len))
               (if 'left? (list k len) (list 0 (- len k)))))
         #:with n+ (gensym 'n)
         #:with v+ (gensym 'v)
         (define k (syntax-e #'n.expanded))
         (unless (<= k (syntax-e #'v.length))
           ;; Report `n` itself; for right-anchored ops the computed lo/hi
           ;; would be a meaningless negative index here.
           (vector-bounds-error 'op-name #'v k))
         (syntax-property
          (syntax/loc stx
            (let ([v+ v.expanded]
                  [n+ (-: 'hi 'lo)])
              (build-vector n+ (lambda ([i : Integer]) (unsafe-vector-ref v+ (+: i 'lo))))))
          vector-length-key
          ;; The result has hi - lo elements, not (vector-length v).
          ;; Tagging it with v's length lets later operators emit unchecked
          ;; out-of-bounds accesses.
          (- (syntax-e #'hi) (syntax-e #'lo)))]
        [(_ v n:int/expand)
         ;; Only a negative literal is always out of bounds. A non-negative
         ;; `n` with an unknown-length `v` must be checked at run time.
         #:when (negative? (syntax-e #'n.expanded))
         (vector-bounds-error 'op-name #'v (syntax-e #'n.expanded))]
        [_:id
         (syntax/loc stx op-name)]
        [(_ e* (... ...))
         (syntax/loc stx (op-name e* (... ...)))]))))
|
|
|
|
;; Slice operators. The arguments to `make-slice-op` are:
;;   1. the runtime racket/vector fallback;
;;   2. whether the `n` elements are counted from the left end (#t) or the
;;      right end (#f);
;;   3. whether those elements are kept (take, #t) or discarded (drop, #f).
(define-syntax vector-take:       (make-slice-op vector-take       #t #t))
(define-syntax vector-drop:       (make-slice-op vector-drop       #t #f))
(define-syntax vector-take-right: (make-slice-op vector-take-right #f #t))
(define-syntax vector-drop-right: (make-slice-op vector-drop-right #f #f))
|
|
|
|
;; -----------------------------------------------------------------------------
|
|
|
|
;; Unrolling threshold. Macros in this file expand vectors with fewer than 20
;; elements element by element; longer ones go through `build-vector` /
;; `build-list`.
(define-for-syntax (small-vector-size? n)
  (> 20 n))
|
|
|
|
;; Assume `stx` is creating a vector; get the length of the vector to-be-made.
;; Returns an exact nonnegative integer, or #f when the length is not known.
;; Sources, tried in order:
;;   1. a length attached under `vector-length-key` by the macros in this file;
;;   2. an identifier registered in `id+vector-length` (by `define-vector:`);
;;   3. literal vectors and `(vector e ...)`, counted by their elements;
;;   4. `(make-vector n ...)` / `(build-vector n ...)` whose `n` is a literal.
(define-for-syntax (parse-vector-length stx)
  ;; Accept only values that can actually be a vector length, so e.g. a quoted
  ;; symbol or a negative literal as `n` is not mistaken for a length.
  ;; NOTE(review): `quoted-stx-value?` (defined elsewhere) presumably returns
  ;; the quoted datum or #f; a syntax result is unwrapped defensively.
  (define (as-length v)
    (define d (if (syntax? v) (syntax-e v) v))
    (and (exact-nonnegative-integer? d) d))
  (cond
    ;; A test-only clause returns the property value itself.
    [(syntax-property stx vector-length-key)]
    [(identifier? stx)
     (free-id-table-ref id+vector-length stx #f)]
    [else
     (syntax-parse stx #:literals (vector make-vector build-vector)
       [(~or '#(e* ...)
             #(e* ...)
             ;; TODO #{} #[] #6{} ...
             (_ vector e* ...) ;; TODO the _ should be matching #%app
             (vector e* ...))
        (length (syntax->list #'(e* ...)))]
       [(~or (make-vector n e* ...)
             (_ make-vector n e* ...)
             (build-vector n e* ...)
             (_ build-vector n e* ...))
        (as-length
         (if (syntax-transforming?)
             (quoted-stx-value? (expand-expr #'n))
             (syntax-e #'n)))]
       [_
        #f])]))
|
|
|