diff --git a/typed-racket-lib/typed-racket/static-contracts/optimize.rkt b/typed-racket-lib/typed-racket/static-contracts/optimize.rkt index c2617332..8820b507 100644 --- a/typed-racket-lib/typed-racket/static-contracts/optimize.rkt +++ b/typed-racket-lib/typed-racket/static-contracts/optimize.rkt @@ -117,16 +117,50 @@ [(arr/sc: args rest (list (any/sc:) ...)) (arr/sc args rest #f)] [(none/sc:) any/sc] - [(app sc-terminal-kind 'flat) any/sc] + [(or/sc: (? flat-terminal-kind?) ...) any/sc] + [(? flat-terminal-kind?) any/sc] [else sc])) +(define (flat-terminal-kind? sc) + (eq? 'flat (sc-terminal-kind sc))) +;; The side of a static contract describes the source of the values that +;; the contract needs to check. +;; - 'positive : values exported by the server module +;; - 'negative : values imported from a client module +;; - 'both : values from both server & client +(define (side? v) + (memq v '(positive negative both))) + +;; A _weak side_ is a side that is currently unsafe to optimize +;; Example: +;; when optimizing an `(or/sc scs ...)` on the 'positive side, +;; each of the `scs` should be optimized on the '(weak positive) side, +;; and their sub-contracts --- if any --- may be optimized on the 'positive side +(define (weak-side? x) + (match x + [(list 'weak (? side?)) + #true] + [_ + #false])) + +(define (strengthen-side side) + (if (weak-side? side) + (second side) + side)) + +(define (weaken-side side) + (if (weak-side? side) + side + `(weak ,side))) (define (invert-side v) - (case v - [(positive) 'negative] - [(negative) 'positive] - [(both) 'both])) + (if (weak-side? v) + (weaken-side (invert-side v)) + (case v + [(positive) 'negative] + [(negative) 'positive] + [(both) 'both]))) (define (combine-variance side var) (case var @@ -134,6 +168,55 @@ [(contravariant) (invert-side side)] [(invariant) 'both])) +;; update-side : sc? weak-side? -> weak-side? +;; Change the current side to something safe & strong-as-possible +;; for optimizing the sub-contracts of the given `sc`. +(define (update-side sc side) + (match sc + [(or/sc: scs ...) + #:when (not (andmap flat-terminal-kind? scs)) + (weaken-side side)] + [(? guarded-sc?) + (strengthen-side side)] + [_ + ;; Keep same side by default. + ;; This is precisely safe for "unguarded" static contracts like and/sc + ;; and conservatively safe for everything else. + side])) + +;; guarded-sc? : sc? -> boolean? +;; Returns #true if the given static contract represents a type with a "real" +;; type constructor. E.g. list/sc is "real" and or/sc is not. +(define (guarded-sc? sc) + (match sc + [(or (? flat-terminal-kind?) + (->/sc: _ _ _ _ _ _) + (arr/sc: _ _ _) + (async-channel/sc: _) + (box/sc: _) + (channel/sc: _) + (cons/sc: _ _) + (continuation-mark-key/sc: _) + (evt/sc: _) + (hash/sc: _ _) + (immutable-hash/sc: _ _) + (list/sc: _ ...) + (listof/sc: _) + (mutable-hash/sc: _ _) + (parameter/sc: _ _) + (promise/sc: _) + (prompt-tag/sc: _ _) + (sequence/sc: _ ...) + (set/sc: _) + (struct/sc: _ _) + (syntax/sc: _) + (vector/sc: _ ...) + (vectorof/sc: _) + (weak-hash/sc: _ _)) + #true] + [_ + #false])) + (define (remove-unused-recursive-contracts sc) (define root (generate-temporary)) (define main-table (make-free-id-table)) @@ -208,12 +291,14 @@ ;; If we trust a specific side then we drop all contracts protecting that side. (define (optimize sc #:trusted-positive [trusted-positive #f] #:trusted-negative [trusted-negative #f]) ;; single-step: reduce and trusted-side-reduce if appropriate - (define (single-step sc side) + (define (single-step sc maybe-weak-side) (define trusted - (case side - [(positive) trusted-positive] - [(negative) trusted-negative] - [(both) (and trusted-positive trusted-negative)])) + (if (weak-side? maybe-weak-side) + #false + (case maybe-weak-side + [(positive) trusted-positive] + [(negative) trusted-negative] + [(both) (and trusted-positive trusted-negative)]))) (reduce (if trusted @@ -223,8 +308,9 @@ ;; full-pass: single-step at every static contract subpart (define (full-pass sc) (define ((recur side) sc variance) - (define new-side (combine-variance side variance)) - (single-step (sc-map sc (recur new-side)) new-side)) + (define curr-side (combine-variance side variance)) + (define next-side (update-side sc curr-side)) + (single-step (sc-map sc (recur next-side)) curr-side)) ((recur 'positive) sc 'covariant)) ;; Do full passes until we reach a fix point, and then remove all unneccessary recursive parts diff --git a/typed-racket-test/succeed/issue-598.rkt b/typed-racket-test/succeed/issue-598.rkt new file mode 100644 index 00000000..aae49cf2 --- /dev/null +++ b/typed-racket-test/succeed/issue-598.rkt @@ -0,0 +1,22 @@ +#lang typed/racket/base + +(module u racket/base + (define (f b) + (set-box! b "hello")) + (provide f)) + +(define-type Maybe-Box (U #f (Boxof Integer))) + +(require/typed 'u (f (-> Maybe-Box Void))) + +(define b : Maybe-Box (box 4)) + +(module+ test + (require typed/rackunit) + + (check-exn exn:fail:contract? + (λ () (f b))) + + (check-equal? + (if (box? b) (+ 1 (unbox b)) (error 'deadcode)) + 5)) diff --git a/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt b/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt index b2f72883..1c07e751 100644 --- a/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt +++ b/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt @@ -113,7 +113,6 @@ #:neg (vectorof/sc none/sc)) ;; Heterogeneous Vectors - ;; TODO fix ability to test equality here (check-optimize (vector/sc any/sc) #:pos any/sc #:neg (vector-length/sc 1)) @@ -179,6 +178,14 @@ (check-optimize (or/sc none/sc none/sc) #:pos any/sc #:neg none/sc) + (check-optimize (or/sc set?/sc (list/sc set?/sc) (list/sc set?/sc set?/sc)) + ;; if all contracts are flat, optimize trusted positive + #:pos any/sc + #:neg (or/sc set?/sc (list/sc set?/sc) (list/sc set?/sc set?/sc))) + (check-optimize (or/sc set?/sc (list/sc (flat/sc #'symbol?)) (box/sc (flat/sc #'symbol?))) + ;; don't optimize if any contracts are non-flat --- but do optimize under guarded constructors + #:pos (or/sc set?/sc (list-length/sc 1) (box/sc (flat/sc #'symbol?))) + #:neg (or/sc set?/sc (list/sc (flat/sc #'symbol?)) (box/sc (flat/sc #'symbol?)))) ;; None (check-optimize none/sc @@ -343,6 +350,33 @@ (arr/sc empty #f (list set?/sc)) (arr/sc (list any/sc) #f (list (listof/sc set?/sc)))))) + ;; more Or case + (check-optimize + ;; (or (or ....)), both "or"s contain non-flat contracts --- don't optimize + (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc)) (box/sc cons?/sc)) + #:pos (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc)) (box/sc cons?/sc)) + #:neg (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc)) (box/sc cons?/sc))) + (check-optimize + ;; (or (or ...)), only the inner "or" contains a non-flat contract --- don't optimize + (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc))) + #:pos (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc))) + #:neg (or/sc cons?/sc (or/sc cons?/sc (box/sc cons?/sc)))) + (check-optimize + ;; (or (or ...)), only the outer "or" contains a non-flat contract --- still don't optimize + (or/sc (box/sc cons?/sc) (or/sc cons?/sc set?/sc)) + #:pos (or/sc (box/sc cons?/sc) (or/sc cons?/sc set?/sc)) + #:neg (or/sc (box/sc cons?/sc) (or/sc cons?/sc set?/sc))) + (check-optimize + ;; (or (and/sc ...)) where the "or" has a non-flat "and" is all flat --- don't optimize + ;; this is just to make sure `and/sc` isn't treated specially + (or/sc (box/sc cons?/sc) (and/sc cons?/sc list?/sc)) + #:pos (or/sc (box/sc cons?/sc) (and/sc cons?/sc list?/sc)) + #:neg (or/sc (box/sc cons?/sc) (and/sc cons?/sc list?/sc))) + (check-optimize + ;; (or (and ...)) where both contain flat contracts --- could optimize, but would need to recognize the and/c is flat + (or/sc set?/sc (and/sc cons?/sc list?/sc)) + #:pos (or/sc set?/sc (and/sc cons?/sc list?/sc)) + #:neg (or/sc set?/sc (and/sc cons?/sc list?/sc))) )) (define tests