diff --git a/collects/tests/typed-racket/succeed/poly-same-annotation.rkt b/collects/tests/typed-racket/succeed/poly-same-annotation.rkt new file mode 100644 index 0000000000..fcf7f4ddcf --- /dev/null +++ b/collects/tests/typed-racket/succeed/poly-same-annotation.rkt @@ -0,0 +1,11 @@ +#lang typed/racket/base +(require racket/list) + +(: f1 (All (A) (Listof A) -> (Listof A))) +(define (f1 a) + (map (λ: ([a : A]) a) empty)) + +(: f2 (All (A) (Listof A) -> (Listof A))) +(define (f2 a) + (map (λ: ([a : A]) a) empty)) + diff --git a/collects/tests/typed-racket/succeed/pr13094.rkt b/collects/tests/typed-racket/succeed/pr13094.rkt index 9abc1baa9e..9405a28bc8 100644 --- a/collects/tests/typed-racket/succeed/pr13094.rkt +++ b/collects/tests/typed-racket/succeed/pr13094.rkt @@ -2,7 +2,7 @@ ;; Test alpha equivalent types -(: x (All (A) (A -> A))) +(: x (All (A) (A -> A))) (define x (plambda: (C) ((f : C)) f)) (: y (All (A) (A A -> A))) diff --git a/collects/tests/typed-racket/succeed/scoped-type-vars.rkt b/collects/tests/typed-racket/succeed/scoped-type-vars.rkt new file mode 100644 index 0000000000..44d9624b1d --- /dev/null +++ b/collects/tests/typed-racket/succeed/scoped-type-vars.rkt @@ -0,0 +1,25 @@ +#lang typed/racket/base + + +(: f1 (All (A) (A -> A))) +(define f1 (lambda: ((x : A)) x)) + +(: f2 (All (A) (A A A -> A))) +(define f2 + (ann + (plambda: (C) ((x : A) (y : B) (z : C)) (or x y z)) + (All (B) (B B B -> B)))) + +(: f3 (All (A ...) (All (B ...) (A ... A -> (B ... B -> Natural))))) +(define f3 (lambda: (x : A ... A) (lambda: (y : B ... B) (+ (length x) (length y))))) + +;; PR 13622 +(: f4 (All (x) (All (y z) (x x x -> Any)))) +(define f4 (plambda: (x) ((x : x) (y : x) (z : x)) (or x y z))) + +;; PR 13539 +(: f5 (All (A) (All (B) (A B -> Integer)))) +(define (f5 x y) + (: z B) + (define z y) + 5) diff --git a/collects/tests/typed-racket/unit-tests/typecheck-tests.rkt b/collects/tests/typed-racket/unit-tests/typecheck-tests.rkt index 7861d5ebc3..225de0e400 100644 --- a/collects/tests/typed-racket/unit-tests/typecheck-tests.rkt +++ b/collects/tests/typed-racket/unit-tests/typecheck-tests.rkt @@ -1632,7 +1632,8 @@ (-polydots (a) (->... (list) (a a) (make-ListDots a 'a))) #:expected (ret (-polydots (a) (->... (list) (a a) (make-ListDots a 'a))))] - + [tc-e/t (ann (lambda (x) #t) (All (a) Any)) + (-poly (a) Univ)] [tc-e ((inst filter Any Symbol) symbol? null) (-lst -Symbol)] diff --git a/collects/typed-racket/env/scoped-tvar-env.rkt b/collects/typed-racket/env/scoped-tvar-env.rkt new file mode 100644 index 0000000000..c37fa5d656 --- /dev/null +++ b/collects/typed-racket/env/scoped-tvar-env.rkt @@ -0,0 +1,66 @@ +#lang racket/base + +;; Maintain mapping of type variables introduced by literal Alls in type annotations. + +(require "../utils/utils.rkt" + (for-template racket/base) + (rep type-rep) + syntax/parse + unstable/debug + syntax/id-table + racket/contract + racket/match + racket/list + racket/dict) + +(provide register-scoped-tvars lookup-scoped-tvars + add-scoped-tvars lookup-scoped-tvar-layer) + +;; tvar-stx-mapping: (hash/c syntax? (listof (listof identifier?))) +(define tvar-stx-mapping (make-weak-hash)) + +;; add-scoped-tvars: syntax? (or/c #f (listof identifier)) -> void? +;; Annotate the given expression with the given identifiers if it is safe. +;; If there are no identifiers, then nothing is done. +;; Safe expressions are lambda, case-lambda, or the expansion of keyword and opt-lambda forms. +(define (add-scoped-tvars stx vars) + (match vars + [(or #f (list)) (void)] + [else + (define (add-vars stx) + (hash-update! tvar-stx-mapping stx (lambda (old-vars) (cons vars old-vars)) null)) + (let loop ((stx stx)) + (syntax-parse stx + #:literals (#%expression #%plain-lambda let-values case-lambda) + [(#%expression e) (loop #'e)] + [(~or (case-lambda formals . body) (#%plain-lambda formals . body)) + (add-vars stx)] + [(let-values ([(f) fun]) . body) + #:when (or (syntax-property stx 'kw-lambda) + (syntax-property stx 'opt-lambda)) + (add-vars #'fun)] + [e (void)]))])) + +;; lookup-scoped-tvar-layer: syntax? -> (listof (listof identifier?)) +;; Returns the identifiers associated with a given syntax object. +;; There can be multiple sections of identifiers, which correspond to multiple poly types. +(define (lookup-scoped-tvar-layer stx) + (hash-ref tvar-stx-mapping stx null)) + +;; tvar-annotation? := (listof (listof (or/c (listof identifier?) +;; (list (listof identifier?) identifier?)))) +;; tvar-mapping: (free-id-table/c tvar-annotation?) +;; Keeps track of type variables that should be introduced when type checking +;; the definition for an identifier. +(define tvar-mapping (make-free-id-table)) + +;; lookup-scoped-tvars: identifier -> (or/c #f tvar-annotation?) +;; Lookup an indentifier in the scoped tvar-mapping. +(define (lookup-scoped-tvars id) + (dict-ref tvar-mapping id #f)) + +;; Register type variables for an indentifier in the scoped tvar-mapping. +;; register-scoped-tvars: identifier? tvar-annotation? -> void? +(define (register-scoped-tvars id tvars) + (dict-set! tvar-mapping id tvars)) + diff --git a/collects/typed-racket/private/parse-type.rkt b/collects/typed-racket/private/parse-type.rkt index c2d0e3bce8..2d6d83d00a 100644 --- a/collects/typed-racket/private/parse-type.rkt +++ b/collects/typed-racket/private/parse-type.rkt @@ -27,7 +27,9 @@ ;; context of the given syntax object [parse-type/id (syntax? c:any/c . c:-> . Type/c)] [parse-tc-results (syntax? . c:-> . tc-results/c)] - [parse-tc-results/id (syntax? c:any/c . c:-> . tc-results/c)]) + [parse-tc-results/id (syntax? c:any/c . c:-> . tc-results/c)] + [parse-literal-alls (syntax? . c:-> . (values (listof identifier?) + (listof identifier?)))]) (provide star ddd/bound) (print-complex-filters? #t) @@ -39,27 +41,35 @@ (let* ([stx* (datum->syntax loc datum loc loc)]) (p stx*))) -;; Syntax -> Type -;; Parse the body under a Forall quantifier -(define (parse-all-body s) - (syntax-parse s - [(ty) - (parse-type #'ty)] - [(x ...) +;; The body of a Forall type +(define-syntax-class all-body + #:attributes (type) + (pattern (type)) + (pattern (x ...) #:fail-unless (= 1 (length (for/list ([i (syntax->list #'(x ...))] #:when (and (identifier? i) (free-identifier=? i #'t:->))) - i))) - #f - (parse-type s)])) + i))) #f + #:attr type #'(x ...))) -;; Syntax (Syntax -> Type) -> Type +(define (parse-literal-alls stx) + (syntax-parse stx #:literals (t:All) + [(t:All (~or (vars:id ... v:id dd:ddd) (vars:id ...)) . t:all-body) + (define vars-list (syntax->list #'(vars ...))) + (cons (if (attribute v) + (list vars-list #'v) + vars-list) + (parse-literal-alls #'t.type))] + [_ null])) + + +;; Syntax -> Type ;; Parse a Forall type -(define (parse-all-type stx parse-type) +(define (parse-all-type stx) ;(printf "parse-all-type: ~a \n" (syntax->datum stx)) (syntax-parse stx #:literals (t:All) - [((~and kw t:All) (vars:id ... v:id dd:ddd) . t) + [((~and kw t:All) (vars:id ... v:id dd:ddd) . t:all-body) (when (check-duplicate-identifier (syntax->list #'(vars ... v))) (tc-error "All: duplicate type variable or index")) (let* ([vars (map syntax-e (syntax->list #'(vars ...)))] @@ -67,14 +77,14 @@ (add-disappeared-use #'kw) (extend-indexes v (extend-tvars vars - (make-PolyDots (append vars (list v)) (parse-all-body #'t)))))] - [((~and kw t:All) (vars:id ...) . t) + (make-PolyDots (append vars (list v)) (parse-type #'t.type)))))] + [((~and kw t:All) (vars:id ...) . t:all-body) (when (check-duplicate-identifier (syntax->list #'(vars ...))) (tc-error "All: duplicate type variable")) (let* ([vars (map syntax-e (syntax->list #'(vars ...)))]) (add-disappeared-use #'kw) (extend-tvars vars - (make-Poly vars (parse-all-body #'t))))] + (make-Poly vars (parse-type #'t.type))))] [(t:All (_:id ...) _ _ _ ...) (tc-error "All: too many forms in body of All type")] [(t:All . rest) (tc-error "All: bad syntax")])) @@ -252,7 +262,7 @@ (add-disappeared-use #'kw) (-val (syntax->datum #'t))] [((~and kw t:All) . rest) - (parse-all-type stx parse-type)] + (parse-all-type stx)] [((~and kw t:Opaque) p?) (add-disappeared-use #'kw) (make-Opaque #'p? (syntax-local-certifier))] diff --git a/collects/typed-racket/private/type-annotation.rkt b/collects/typed-racket/private/type-annotation.rkt index b9a312a7b0..faa8aa9a7a 100644 --- a/collects/typed-racket/private/type-annotation.rkt +++ b/collects/typed-racket/private/type-annotation.rkt @@ -1,9 +1,9 @@ #lang racket/base (require "../utils/utils.rkt" - (rep type-rep) - (utils tc-utils) - (env global-env mvar-env) + (rep type-rep) + (utils tc-utils) + (env global-env mvar-env scoped-tvar-env) (except-in (types subtype union resolve utils generalize)) (private parse-type) (contract-req) @@ -57,6 +57,7 @@ (define (type-ascription stx) (define (pt prop) + (add-scoped-tvars stx (parse-literal-alls prop)) (if (syntax? prop) (parse-tc-results prop) (parse-tc-results/id stx prop))) diff --git a/collects/typed-racket/typecheck/tc-lambda-unit.rkt b/collects/typed-racket/typecheck/tc-lambda-unit.rkt index cf7ea59a64..3fb5312a3d 100644 --- a/collects/typed-racket/typecheck/tc-lambda-unit.rkt +++ b/collects/typed-racket/typecheck/tc-lambda-unit.rkt @@ -2,14 +2,14 @@ (require "../utils/utils.rkt" racket/dict racket/list syntax/parse racket/syntax syntax/stx - racket/match syntax/id-table + racket/match syntax/id-table racket/set (contract-req) (except-in (rep type-rep) make-arr) (rename-in (except-in (types abbrev utils union) -> ->* one-of/c) [make-arr* make-arr]) (private type-annotation) (typecheck signatures tc-metafunctions tc-subst check-below) - (env type-env-structs lexical-env tvar-env index-env) + (env type-env-structs lexical-env tvar-env index-env scoped-tvar-env) (utils tc-utils) (for-template racket/base "internal-forms.rkt")) @@ -328,83 +328,124 @@ (define d (syntax-property stx 'typechecker:plambda)) (and d (car (flatten d)))) +(define (has-poly-annotation? form) + (or (plambda-prop form) (cons? (lookup-scoped-tvar-layer form)))) + +(define (remove-poly-layer tvarss) + (filter cons? (map rest tvarss))) + +(define (get-poly-layer tvarss) + (map first tvarss)) + +(define (get-poly-tvarss form) + (let ([plambda-tvars + (let ([p (plambda-prop form)]) + (match (and p (map syntax-e (syntax->list p))) + [#f #f] + [(list var ... dvar '...) + (list (list var dvar))] + [(list id ...) + (list id)]))] + [scoped-tvarss + (for/list ((tvarss (lookup-scoped-tvar-layer form))) + (for/list ((tvar tvarss)) + (match tvar + [(list (list v ...) dotted-v) + (list (map syntax-e v) (syntax-e dotted-v))] + [(list v ...) (map syntax-e v)])))]) + (if plambda-tvars + (cons plambda-tvars scoped-tvarss) + scoped-tvarss))) + + ;; tc/plambda syntax syntax-list syntax-list type -> Poly ;; formals and bodies must by syntax-lists -(define/cond-contract (tc/plambda form formals bodies expected) +(define/cond-contract (tc/plambda form tvarss-list formals bodies expected) (syntax? syntax? syntax? (or/c tc-results/c #f) . -> . Type/c) (define/cond-contract (maybe-loop form formals bodies expected) (syntax? syntax? syntax? tc-results/c . -> . Type/c) (match expected - [(tc-result1: (Function: _)) (tc/mono-lambda/type formals bodies expected)] [(tc-result1: (or (Poly: _ _) (PolyDots: _ _))) - (tc/plambda form formals bodies expected)] - [(tc-result1: (Error:)) (tc/mono-lambda/type formals bodies #f)] + (tc/plambda form (remove-poly-layer tvarss-list) formals bodies expected)] [(tc-result1: (and v (Values: _))) (maybe-loop form formals bodies (values->tc-results v #f))] - [_ (int-err "expected not an appropriate tc-result: ~a" expected)])) + [_ + (define remaining-layers (remove-poly-layer tvarss-list)) + (if (null? remaining-layers) + (tc/mono-lambda/type formals bodies expected) + (tc/plambda form remaining-layers formals bodies expected))])) + ;; check the bodies appropriately + ;; and make both annotated and declared type variables point to the + ;; same actual type variables (the fresh names) + (define (extend-and-loop form ns formals bodies expected) + (let loop ((tvarss tvarss)) + (match tvarss + [(list) (maybe-loop form formals bodies expected)] + [(cons (list (list tvars ...) dotted) rest-tvarss) + (extend-indexes dotted + (extend-tvars/new tvars ns + (loop rest-tvarss)))] + [(cons tvars rest-tvarss) + (extend-tvars/new tvars ns + (loop rest-tvarss))]))) + (define tvarss (get-poly-layer tvarss-list)) + (match expected [(tc-result1: (and t (Poly-fresh: ns fresh-ns expected*))) - (let* ([tvars (let ([p (plambda-prop form)]) - (when (and (pair? p) (eq? '... (car (last p)))) - (tc-error - "Expected a polymorphic function without ..., but given function had ...")) - (and p (map syntax-e (syntax->list p))))]) - ;; make sure the declared type variable arity matches up with the - ;; annotated type variable arity - (when tvars - (unless (= (length tvars) (length ns)) - (tc-error "Expected ~a type variables, but given ~a" - (length ns) (length tvars)))) - ;; check the bodies appropriately - (if tvars - ;; make both annotated and given type variables point to the - ;; same actual type variables (the fresh names) - (extend-tvars/new ns fresh-ns - (extend-tvars/new tvars fresh-ns - (maybe-loop form formals bodies (ret expected*)))) - ;; no plambda: type variables given - (extend-tvars/new ns fresh-ns - (maybe-loop form formals bodies (ret expected*)))) - t)] + ;; make sure the declared and annotated type variable arities match up + ;; with the expected type variable arity + (for ((tvars tvarss)) + (when (and (cons? tvars) (list? (first tvars))) + (tc-error + "Expected a polymorphic function without ..., but given function/annotation had ...")) + (unless (= (length tvars) (length fresh-ns)) + (tc-error "Expected ~a type variables, but given ~a" + (length fresh-ns) (length tvars)))) + (make-Poly #:original-names ns fresh-ns (extend-and-loop form fresh-ns formals bodies (ret expected*)))] [(tc-result1: (and t (PolyDots-names: (list ns ... dvar) expected*))) - (let-values - ([(tvars dotted) - (let ([p (plambda-prop form)]) - (if p - (match (map syntax-e (syntax->list p)) - [(list var ... dvar '...) - (values var dvar)] - [_ (tc-error "Expected a polymorphic function with ..., but given function had no ...")]) - (values ns dvar)))]) - ;; check the body for side effect - (extend-indexes dotted - (extend-tvars tvars - (maybe-loop form formals bodies (ret expected*)))) - t)] - [(or (tc-result1: _) (tc-any-results:) #f) - (match (map syntax-e (syntax->list (plambda-prop form))) - [(list tvars ... dotted-var '...) - (let* ([ty (extend-indexes dotted-var - (extend-tvars tvars - (tc/mono-lambda/type formals bodies #f)))]) - (make-PolyDots (append tvars (list dotted-var)) ty))] - [tvars - (let* (;; manually make some fresh names since - ;; we don't use a match expander - [fresh-tvars (map gensym tvars)] - [ty (extend-tvars/new tvars fresh-tvars - (tc/mono-lambda/type formals bodies #f))]) - ;(printf "plambda: ~a ~a ~a \n" literal-tvars new-tvars ty) - (make-Poly fresh-tvars ty #:original-names tvars))])] - [_ (int-err "not a good expected value: ~a" expected)])) + ;; make sure the declared and annotated type variable arities match up + ;; with the expected type variable arity + (for ((tvars tvarss)) + (match tvars + [(list (list vars ...) dotted) + (unless (= (length vars) (length ns)) + (tc-error "Expected ~a non-dotted type variables, but given ~a" + (length ns) (length vars)))] + [else + (tc-error "Expected a polymorphic function with ..., but function/annotation had no ...")])) + (make-PolyDots (append ns (list dvar)) (extend-and-loop form ns formals bodies (ret expected*)))] + [(or (tc-results: _) (tc-any-results:) #f) + (define lengths + (for/set ((tvars tvarss)) + (match tvars + [(list (list vars ...) dotted) + (length vars)] + [(list vars ...) + (length vars)]))) + (define dots + (for/set ((tvars tvarss)) + (match tvars + [(list (list vars ...) dotted) #t] + [(list vars ...) #f]))) + (unless (= 1 (set-count lengths)) + (tc-error "Expected annotations to have the same number of type variables, but given ~a" + (set->list lengths))) + (unless (= 1 (set-count dots)) + (tc-error "Expected annotations to all have ... or none to have ..., but given both")) + (define dotted (and (set-first dots) (second (first tvarss)))) + (define ns (build-list (set-first lengths) (lambda (_) (gensym)))) + (define results (extend-and-loop form ns formals bodies expected)) + (if dotted + (make-PolyDots (append ns (list dotted)) results) + (make-Poly #:original-names (first tvarss) ns results))])) ;; typecheck a sequence of case-lambda clauses, which is possibly polymorphic ;; tc/lambda/internal syntax syntax-list syntax-list option[type] -> tc-result (define (tc/lambda/internal form formals bodies expected) - (if (or (plambda-prop form) + (if (or (has-poly-annotation? form) (match expected [(tc-result1: t) (or (Poly? t) (PolyDots? t))] [_ #f])) - (ret (tc/plambda form formals bodies expected) true-filter) + (ret (tc/plambda form (get-poly-tvarss form) formals bodies expected) true-filter) (ret (tc/mono-lambda/type formals bodies expected) true-filter))) ;; tc/lambda : syntax syntax-list syntax-list -> tc-result diff --git a/collects/typed-racket/typecheck/tc-toplevel.rkt b/collects/typed-racket/typecheck/tc-toplevel.rkt index b3c040ec9f..49d9bab741 100644 --- a/collects/typed-racket/typecheck/tc-toplevel.rkt +++ b/collects/typed-racket/typecheck/tc-toplevel.rkt @@ -7,7 +7,8 @@ (rep type-rep free-variance) (types utils abbrev type-table) (private parse-type type-annotation type-contract) - (env global-env init-envs type-name-env type-alias-env lexical-env env-req mvar-env) + (env global-env init-envs type-name-env type-alias-env + lexical-env env-req mvar-env scoped-tvar-env) (utils tc-utils mutated-vars) (typecheck provide-handling def-binding tc-structs typechecker) @@ -150,7 +151,8 @@ ;; top-level type annotation [(define-values () (begin (quote-syntax (:-internal id:identifier ty)) (#%plain-app values))) - (register-type/undefined #'id (parse-type #'ty))] + (register-type/undefined #'id (parse-type #'ty)) + (register-scoped-tvars #'id (parse-literal-alls #'ty))] ;; values definitions @@ -236,6 +238,8 @@ [ts (map lookup-type vars)]) (unless (for/and ([v (syntax->list #'(var ...))]) (free-id-table-ref unann-defs v (lambda _ #f))) + (when (= 1 (length vars)) + (add-scoped-tvars #'expr (lookup-scoped-tvars (first vars)))) (tc-expr/check #'expr (ret ts))) (void))]