trivial/vector.rkt
2016-03-03 02:29:21 -05:00

183 lines
5.2 KiB
Racket

#lang typed/racket/base
;; TODO integrate with trivial/math to get ints from identifiers
(provide
define-vector:
let-vector:
vector-length:
vector-ref:
vector-set!:
vector-map:
vector-map!:
;vector-append:
;vector->list
;vector->immutable-vector
;vector-fill!
;
;; TODO and a few more
;; --- private
(for-syntax parse-vector-length)
)
;; -----------------------------------------------------------------------------
(require
(only-in racket/unsafe/ops
unsafe-vector-set!
unsafe-vector-ref)
racket/vector
(for-syntax
typed/racket/base
syntax/id-table
syntax/parse
syntax/stx
trivial/private/common
))
;; =============================================================================
(begin-for-syntax
  ;; Syntax-property key under which a statically-known vector length is
  ;; attached to a syntax object (set by `let-vector:`, `vector-map:` and
  ;; `vector-map!:`; read back by `parse-vector-length`).
  (define vector-length-key 'vector:length)

  ;; Free-identifier table mapping an identifier to the known length of the
  ;; vector it names. `define-vector:` writes entries; `parse-vector-length`
  ;; looks them up.
  (define id+vector-length (make-free-id-table)))
(begin-for-syntax
  ;; Syntax class `vector/length`
  ;; Matches any term for which `parse-vector-length` reports a length once
  ;; the term has been passed through `expand-expr`.
  ;; Attributes:
  ;; - `expanded` : the result of `(expand-expr #'form)`
  ;; - `length`   : the length reported by `parse-vector-length`, as syntax
  ;; The match fails (so callers fall through to their generic clauses) when
  ;; `parse-vector-length` returns #f.
  (define-syntax-class vector/length
    #:attributes (expanded length)
    (pattern form
      #:with form+ (expand-expr #'form)
      #:with size (parse-vector-length #'form+)
      ;; `#:with` wraps a #f result as syntax; unwrap it to test for failure
      #:when (syntax-e #'size)
      #:attr expanded #'form+
      #:attr length #'size)))
;; (define-vector: id vec-expr)
;; Like `define`. When `vec-expr` matches the `vector/length` syntax class,
;; the length is additionally recorded for `id` in `id+vector-length`, and
;; the definition uses the expanded form of `vec-expr`.
;; Any other shape is forwarded to `define` untouched.
(define-syntax (define-vector: stx)
  (syntax-parse stx
    [(_ x:id vec:vector/length)
     (let ([known-length (syntax-e #'vec.length)])
       ;; Side effect at expansion time: remember the length of `x`
       (free-id-table-set! id+vector-length #'x known-length)
       #'(define x vec.expanded))]
    [(_ form* ...)
     #'(define form* ...)]))
;; (let-vector: ([id vec-expr] ...) body ...)
;; Like `let`. When every right-hand side matches `vector/length`, each `id`
;; is bound to the expanded vector expression and then shadowed, inside the
;; body, by a rename transformer whose target identifier carries the length
;; under `vector-length-key`.
;; Any other shape is forwarded to `let` untouched.
(define-syntax (let-vector: stx)
  (syntax-parse stx
    [(_ ([x*:id vec*:vector/length] ...) body* ...)
     #'(let ([x* vec*.expanded] ...)
         (let-syntax ([x* (make-rename-transformer
                            (syntax-property #'x*
                                             vector-length-key
                                             'vec*.length))]
                      ...)
           body* ...))]
    [(_ form* ...)
     #'(let form* ...)]))
;; (vector-bounds-error sym v-stx i)
;; Raises a syntax error named `sym` with the message "Index out-of-bounds".
;; The datum of `v-stx` is reported as the offending expression, `i` as the
;; sub-expression, and `v-stx` itself is passed as an extra source.
;; Never returns.
(define-for-syntax (vector-bounds-error sym v-stx i)
  (let ([offending-datum (syntax->datum v-stx)]
        [extra-sources (list v-stx)])
    (raise-syntax-error sym
                        "Index out-of-bounds"
                        offending-datum
                        i
                        extra-sources)))
;; (vector-length: vec-expr)
;; When `vec-expr` matches `vector/length`, expands to the quoted length
;; (the vector expression itself does not appear in the expansion).
;; Used as a bare identifier, expands to `vector-length`.
;; Any other application is forwarded to `vector-length`.
(define-syntax (vector-length: stx)
  (syntax-parse stx
    [(_ vec:vector/length)
     (syntax/loc stx (quote vec.length))]
    [_:id
     (syntax/loc stx vector-length)]
    [(_ arg* ...)
     (syntax/loc stx (vector-length arg* ...))]))
;; (vector-ref: vec-expr index)
;; When `vec-expr` matches `vector/length` and `index` is a literal natural
;; number, the index is checked against the known length at expansion time:
;; - in bounds     : expands to `unsafe-vector-ref` on the expanded vector
;; - out of bounds : raises a syntax error via `vector-bounds-error`
;; Used as a bare identifier, expands to `vector-ref`.
;; Any other application is forwarded to `vector-ref`.
(define-syntax (vector-ref: stx)
  (syntax-parse stx
    [(_ vec:vector/length idx:nat)
     (let ([index (syntax-e #'idx)]
           [known-length (syntax-e #'vec.length)])
       (if (< index known-length)
         (syntax/loc stx (unsafe-vector-ref vec.expanded idx))
         ;; Does not return: raises a syntax error
         (vector-bounds-error 'vector-ref: #'vec index)))]
    [_:id
     (syntax/loc stx vector-ref)]
    [(_ arg* ...)
     (syntax/loc stx (vector-ref arg* ...))]))
;; (vector-set!: vec-expr index val)
;; When `vec-expr` matches `vector/length` and `index` is a literal natural
;; number, the index is checked against the known length at expansion time:
;; - in bounds     : expands to `unsafe-vector-set!` on the expanded vector
;; - out of bounds : raises a syntax error via `vector-bounds-error`
;; Used as a bare identifier, expands to `vector-set!`.
;; Any other application is forwarded to `vector-set!`.
(define-syntax (vector-set!: stx)
  (syntax-parse stx
    [(_ vec:vector/length idx:nat new-val)
     (let ([index (syntax-e #'idx)]
           [known-length (syntax-e #'vec.length)])
       (if (< index known-length)
         (syntax/loc stx (unsafe-vector-set! vec.expanded idx new-val))
         ;; Does not return: raises a syntax error
         (vector-bounds-error 'vector-set!: #'vec index)))]
    [_:id
     (syntax/loc stx vector-set!)]
    [(_ arg* ...)
     (syntax/loc stx (vector-set! arg* ...))]))
;; (vector-map: f vec-expr)
;; When `vec-expr` matches `vector/length`, expands to code specialized on
;; the known length:
;; - "small" lengths (per `small-vector-size?`) unroll into a `vector` call
;;   with one `unsafe-vector-ref` per index
;; - larger lengths use `build-vector` with the known length
;; The resulting syntax is tagged with the length under `vector-length-key`,
;; so nested uses of these macros can see the result's length.
;; Used as a bare identifier, expands to `vector-map`.
;; Any other application is forwarded to `vector-map`.
;;
;; Both `f` and the vector expression are bound to fresh temporaries and
;; therefore evaluated exactly once, in that order. Splicing the vector
;; expression at every use would re-evaluate it once per element, repeating
;; any allocation and side effects it performs.
(define-syntax (vector-map: stx)
  (syntax-parse stx
    [(_ f v:vector/length)
     #:with (i* ...) (for/list ([i (in-range (syntax-e #'v.length))]) i)
     #:with f+ (gensym 'f)
     ;; Temporary holding the vector value
     #:with vec+ (gensym 'v)
     #:with v+ (syntax-property
                 (if (small-vector-size? (syntax-e #'v.length))
                   (syntax/loc stx
                     (let ([f+ f]
                           [vec+ v.expanded])
                       (vector (f+ (unsafe-vector-ref vec+ 'i*)) ...)))
                   (syntax/loc stx
                     (let ([f+ f]
                           [vec+ v.expanded])
                       (build-vector 'v.length
                                     (lambda ([i : Integer])
                                       (f+ (vector-ref: vec+ i)))))))
                 vector-length-key
                 (syntax-e #'v.length))
     (syntax/loc stx v+)]
    [_:id
     (syntax/loc stx vector-map)]
    [(_ e* ...)
     (syntax/loc stx (vector-map e* ...))]))
;; (vector-map!: f vec-expr)
;; When `vec-expr` matches `vector/length`, expands to a loop over the known
;; length that overwrites each slot with `f` applied to its current value,
;; using `unsafe-vector-ref` / `unsafe-vector-set!`, and evaluates to the
;; mutated vector. The resulting syntax is tagged with the length under
;; `vector-length-key`.
;; Used as a bare identifier, expands to `vector-map!`.
;; Any other application is forwarded to `vector-map!`.
;;
;; The vector expression is bound to a fresh temporary and evaluated exactly
;; once. This matters for correctness, not just cost: if the expression were
;; spliced at every read, write and as the result, an expression that builds
;; a fresh vector (e.g. `(vector 1 2 3)`) would have each write land on a
;; throwaway vector and the value returned would be an unmodified one.
(define-syntax (vector-map!: stx)
  (syntax-parse stx
    [(_ f v:vector/length)
     #:with f+ (gensym 'f)
     ;; Temporary holding the one vector that is both mutated and returned
     #:with vec+ (gensym 'v)
     #:with v+ (syntax-property
                 #'(let ([f+ f]
                         [vec+ v.expanded])
                     (for ([i (in-range 'v.length)])
                       (unsafe-vector-set! vec+ i (f+ (unsafe-vector-ref vec+ i))))
                     vec+)
                 vector-length-key
                 (syntax-e #'v.length))
     (syntax/loc stx v+)]
    [_:id
     (syntax/loc stx vector-map!)]
    [(_ e* ...)
     (syntax/loc stx (vector-map! e* ...))]))
;; -----------------------------------------------------------------------------
;; (small-vector-size? n)
;; True when `n` is below 101. `vector-map:` uses this to decide whether to
;; unroll the map into one expression per element.
(define-for-syntax (small-vector-size? n)
  (> 101 n))
;; (parse-vector-length stx)
;; Assume `stx` is creating a vector; get the length of the vector to-be-made.
;; Returns the length, or #f when it cannot be determined. Sources, in order:
;; 1. a `vector-length-key` syntax property already attached to `stx`
;; 2. for identifiers, an entry in the `id+vector-length` table
;; 3. the shape of `stx`: a vector literal or `vector` call (count the
;;    elements), or a `make-vector` / `build-vector` call (inspect the size
;;    argument)
(define-for-syntax (parse-vector-length stx)
  (or
    (syntax-property stx vector-length-key)
    (if (identifier? stx)
      (free-id-table-ref id+vector-length stx #f)
      (syntax-parse stx #:literals (vector make-vector build-vector)
        ;; NOTE(review): `quote` is not listed in #:literals, so the first
        ;; alternative appears to accept any two-element form whose second
        ;; element is a vector -- confirm this is intended.
        [(~or '#(elem* ...)
              #(elem* ...)
              ;; TODO #{} #[] #6{} ...
              (_ vector elem* ...) ;; TODO the _ should be matching #%app
              (vector elem* ...))
         (length (syntax->list #'(elem* ...)))]
        [(~or (make-vector size arg* ...)
              (_ make-vector size arg* ...)
              (build-vector size arg* ...)
              (_ build-vector size arg* ...))
         (if (syntax-transforming?)
           ;; During expansion: expand the size argument and let
           ;; `quoted-stx-value?` (from trivial/private/common) extract it
           (quoted-stx-value? (expand-expr #'size))
           ;; Otherwise only a literal natural number is accepted
           (let ([size-datum (syntax-e #'size)])
             (and (exact-nonnegative-integer? size-datum)
                  size-datum)))]
        [_
         #f]))))