diff --git a/typed-racket-lib/typed-racket/optimizer/pair.rkt b/typed-racket-lib/typed-racket/optimizer/pair.rkt index 0f90ec81..a5f97a30 100644 --- a/typed-racket-lib/typed-racket/optimizer/pair.rkt +++ b/typed-racket-lib/typed-racket/optimizer/pair.rkt @@ -6,7 +6,7 @@ (for-syntax racket/base syntax/parse racket/syntax) "../utils/utils.rkt" (rep type-rep) - (types type-table utils base-abbrev) + (types type-table utils base-abbrev resolve subtype) (typecheck typechecker) (optimizer utils logging)) @@ -25,10 +25,7 @@ (define (has-pair-type? e) - (and (subtypeof? e (-pair Univ Univ)) - ;; sometimes composite operations end up with Nothing as result type, - ;; not sure why. TODO investigate - (not (isoftype? e -Bottom)))) + (subtypeof? e (-pair Univ Univ))) ;; can't do the above for mpairs, as they are invariant (define (has-mpair-type? e) (match (type-of e) ; type of the operand @@ -67,25 +64,12 @@ ;; change the source location of a given syntax object -(define (relocate stx loc-stx) +(define ((relocate loc-stx) stx) (datum->syntax stx (syntax->datum stx) loc-stx stx stx)) ;; if the equivalent sequence of cars and cdrs is guaranteed not to fail, ;; we can optimize -;; accessors is a list of syntax objects, all #'car or #'cdr -(define (gen-alt accessors op arg stx) - (define (gen-alt-helper accessors) - (for/fold [(accum arg)] [(acc (reverse accessors))] - (quasisyntax/loc stx (#%plain-app #,(relocate acc op) #,accum)))) - (let ((ty (type-of stx)) - (obj (gen-alt-helper accessors))) - ;; we're calling the typechecker, but this is just a shortcut, we're - ;; still conceptually single pass (we're not iterating). we could get - ;; the same result by statically destructing the types. - (tc-expr/check obj ty) - obj)) - (define-syntax gen-pair-derived-expr (syntax-parser [(_ name:id (orig:id seq ...) ...) @@ -96,8 +80,9 @@ (define-literal-syntax-class lit-class-name (orig)) (define-syntax-class syntax-class-name #:commit + #:attributes (arg alt) (pattern (#%plain-app (~var op lit-class-name) arg) - #:with alt (gen-alt (list seq ...) #'op #'arg this-syntax)))) ... + #:with alt (map (relocate #'op) (list seq ...))))) ... (define-merged-syntax-class name (syntax-class-name ...)))])) (gen-pair-derived-expr pair-derived-expr @@ -144,5 +129,30 @@ (define-syntax-class pair-derived-opt-expr #:commit (pattern e:pair-derived-expr - #:with e*:pair-opt-expr #'e.alt - #:with opt #'e*.opt)) + #:with opt + ;; optimize alt inside-out, as long as it's safe to + (let-values + ([(t res) + (for/fold ([t (match (type-of #'e.arg) + [(tc-result1: t) t])] + [res #'e.arg]) + ([accessor (in-list (reverse (syntax->list #'e.alt)))]) + (cond + [(subtype t (-pair Univ Univ)) ; safe to optimize this one layer + (syntax-parse accessor + [op:pair-op + (log-pair-opt) + (values + (match (resolve t) + [(Pair: a d) ; peel off one layer of the type + (syntax-parse #'op + [:car^ a] + [:cdr^ d])] + [_ ; not a pair type, give up on optimizing more + #f]) + #`(op.unsafe #,res))])] + [else ; unsafe, just rebuild the rest of the accessors + (log-pair-missed-opt accessor #'e.arg) + (values t ; stays unsafe from now on + #`(#,accessor #,res))]))]) + res))) diff --git a/typed-racket-test/optimizer/tests/derived-pair-open-terms.rkt b/typed-racket-test/optimizer/tests/derived-pair-open-terms.rkt new file mode 100644 index 00000000..221743d0 --- /dev/null +++ b/typed-racket-test/optimizer/tests/derived-pair-open-terms.rkt @@ -0,0 +1,36 @@ +#;#; +#< Integer)) +(define (f x) + (cadr x)) + +(: g ((List Integer Integer Integer) -> Integer)) +(define (g x) + (caddr x)) + +(: h ((Listof Integer) -> Integer)) +(define (h x) + (first x)) ; unsafe + +(: i ((Listof Integer) -> (Listof Integer))) +(define (i x) + (rest x)) ; unsafe + +(: j ((cons Integer (Listof Integer)) -> (Listof Integer))) +(define (j x) + (cddr x)) ; partially safe diff --git a/typed-racket-test/optimizer/tests/derived-pair.rkt b/typed-racket-test/optimizer/tests/derived-pair.rkt index d139d6ab..c04a2ccc 100644 --- a/typed-racket-test/optimizer/tests/derived-pair.rkt +++ b/typed-racket-test/optimizer/tests/derived-pair.rkt @@ -1,13 +1,13 @@ #;#; #<