diff --git a/collects/tests/typed-racket/optimizer/tests/multi-flcomp.rkt b/collects/tests/typed-racket/optimizer/tests/multi-flcomp.rkt new file mode 100644 index 00000000..e4c05dce --- /dev/null +++ b/collects/tests/typed-racket/optimizer/tests/multi-flcomp.rkt @@ -0,0 +1,22 @@ +#; +( +TR opt: multi-flcomp.rkt 18:0 (<= 1.0 2.0 3.0) -- multi float comp +TR opt: multi-flcomp.rkt 20:0 (<= 1.0 2.0 3.0 4.0) -- multi float comp +TR opt: multi-flcomp.rkt 21:0 (<= 1.0 2.0 3.0 (+ 2.0 2.0)) -- multi float comp +TR opt: multi-flcomp.rkt 21:16 (+ 2.0 2.0) -- binary float +TR opt: multi-flcomp.rkt 22:0 (<= 1.0 2.0 (+ 2.0 2.0) 3.0) -- multi float comp +TR opt: multi-flcomp.rkt 22:12 (+ 2.0 2.0) -- binary float +#t +#t +#t +#t +#f +) + +#lang typed/racket + +(<= 1.0 2.0 3.0) +(<= 1.0 2.0 3) ; unsafe, last one is not a float +(<= 1.0 2.0 3.0 4.0) +(<= 1.0 2.0 3.0 (+ 2.0 2.0)) +(<= 1.0 2.0 (+ 2.0 2.0) 3.0) diff --git a/collects/typed-racket/optimizer/float.rkt b/collects/typed-racket/optimizer/float.rkt index ed663f7e..129f7d4c 100644 --- a/collects/typed-racket/optimizer/float.rkt +++ b/collects/typed-racket/optimizer/float.rkt @@ -156,13 +156,35 @@ #:with opt (begin (log-optimization "binary float" float-opt-msg this-syntax) (n-ary->binary #'op.unsafe #'f1.opt #'f2.opt #'(fs.opt ...)))) + (pattern (#%plain-app (~var op (float-op binary-float-comps)) + f1:float-expr + f2:float-expr) + #:with opt + (begin (log-optimization "binary float comp" float-opt-msg this-syntax) + #'(op.unsafe f1.opt f2.opt))) (pattern (#%plain-app (~var op (float-op binary-float-comps)) f1:float-expr f2:float-expr fs:float-expr ...) #:with opt - (begin (log-optimization "binary float comp" float-opt-msg this-syntax) - (n-ary->binary #'op.unsafe #'f1.opt #'f2.opt #'(fs.opt ...)))) + (begin (log-optimization "multi float comp" float-opt-msg this-syntax) + ;; First, generate temps to bind the result of each f2 fs ... + ;; to avoid computing them multiple times. + (define lifted (map (lambda (x) (unboxed-gensym)) (syntax->list #'(f2 fs ...)))) + ;; Second, build the list ((op f1 tmp2) (op tmp2 tmp3) ...) + (define tests + (let loop ([res (list #`(op.unsafe f1.opt #,(car lifted)))] + [prev (car lifted)] + [l (cdr lifted)]) + (cond [(null? l) (reverse res)] + [else (loop (cons #`(op.unsafe #,prev #,(car l)) res) + (car l) + (cdr l))]))) + ;; Finally, build the whole thing. + #`(let #,(for/list ([lhs (in-list lifted)] + [rhs (in-list (syntax->list #'(f2.opt fs.opt ...)))]) + #`(#,lhs #,rhs)) + (and #,@tests)))) (pattern (#%plain-app (~and op (~literal -)) f:float-expr) #:with opt