diff --git a/collects/typed-scheme/private/optimize.rkt b/collects/typed-scheme/private/optimize.rkt index e34c5b65..5d466deb 100644 --- a/collects/typed-scheme/private/optimize.rkt +++ b/collects/typed-scheme/private/optimize.rkt @@ -2,7 +2,7 @@ (require syntax/parse (for-template scheme/base scheme/flonum scheme/unsafe/ops) "../utils/utils.rkt" unstable/match scheme/match unstable/syntax - (rep type-rep) + (rep type-rep) syntax/id-table racket/dict (types abbrev type-table utils subtype)) (provide optimize) @@ -26,21 +26,22 @@ [(tc-result1: (== -Flonum type-equal?)) #t] [_ #f]) #:with opt #'e.opt)) -(define-syntax-class float-binary-op - #:literals (+ - * / = <= < > >= min max - fl+ fl- fl* fl/ fl= fl<= fl< fl> fl>= flmin flmax) - (pattern (~and i:id (~or + - * / = <= < > >= min max)) - #:with unsafe (format-id #'here "unsafe-fl~a" #'i)) - (pattern (~and i:id (~or fl+ fl- fl* fl/ fl= fl<= fl< fl> fl>= flmin flmax)) - #:with unsafe (format-id #'here "unsafe-~a" #'i))) +(define (mk-float-tbl generic) + (for/fold ([h (make-immutable-free-id-table)]) ([g generic]) + (let ([f (format-id g "fl~a" g)] [u (format-id g "unsafe-fl~a" g)]) + (dict-set (dict-set h g u) f u)))) -(define-syntax-class float-unary-op - #:literals (abs sin cos tan asin acos atan log exp sqrt round floor ceiling truncate - flabs flsin flcos fltan flasin flacos flatan fllog flexp flsqrt flround flfloor flceiling fltruncate) - (pattern (~and i:id (~or abs sin cos tan asin acos atan log exp sqrt round floor ceiling truncate)) - #:with unsafe (format-id #'here "unsafe-fl~a" #'i)) - (pattern (~and i:id (~or flabs flsin flcos fltan flasin flacos flatan fllog flexp flsqrt flround flfloor flceiling fltruncate)) - #:with unsafe (format-id #'here "unsafe-~a" #'i))) +(define binary-float-ops + (mk-float-tbl (list #'+ #'- #'* #'/ #'= #'<= #'< #'> #'>= #'min #'max))) + +(define unary-float-ops + (mk-float-tbl (list #'abs #'sin #'cos #'tan #'asin #'acos #'atan #'log #'exp + #'sqrt #'round #'floor #'ceiling #'truncate))) + +(define-syntax-class (float-op tbl) + (pattern i:id + #:when (dict-ref tbl #'i #f) + #:with unsafe (dict-ref tbl #'i))) (define-syntax-class pair-opt-expr (pattern e:opt-expr @@ -50,9 +51,8 @@ #:with opt #'e.opt)) (define-syntax-class pair-unary-op - #:literals (car cdr) - (pattern (~and i:id (~or car cdr)) - #:with unsafe (format-id #'here "unsafe-~a" #'i))) + (pattern (~literal car) #:with unsafe #'unsafe-car) + (pattern (~literal cdr) #:with unsafe #'unsafe-cdr)) (define-syntax-class opt-expr (pattern e:opt-expr* @@ -72,12 +72,12 @@ #:literal-sets (kernel-literals) ;; interesting cases, where something is optimized - (pattern (#%plain-app op:float-unary-op f:float-opt-expr) + (pattern (#%plain-app (~var op (float-op unary-float-ops)) f:float-opt-expr) #:with opt (begin (log-optimization "unary float" #'op) #'(op.unsafe f.opt))) ;; unlike their safe counterparts, unsafe binary operators can only take 2 arguments - (pattern (~and res (#%plain-app op:float-binary-op f1:float-arg-expr f2:float-arg-expr fs:float-arg-expr ...)) + (pattern (~and res (#%plain-app (~var op (float-op binary-float-ops)) f1:float-arg-expr f2:float-arg-expr fs:float-arg-expr ...)) #:when (match (type-of #'res) [(tc-result1: (== -Flonum type-equal?)) #t] [_ #f]) #:with opt