diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/static-contracts/constraints.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/static-contracts/constraints.rkt index bd543879c0..c6b6b422fe 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/static-contracts/constraints.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/static-contracts/constraints.rkt @@ -39,6 +39,7 @@ racket/match racket/list racket/format + racket/function racket/contract racket/dict racket/set @@ -64,20 +65,82 @@ (module structs racket/base (require racket/contract + racket/match + racket/dict + racket/list racket/set syntax/id-table "kinds.rkt") (provide (contract-out + ;; constraint: value must be below max [struct constraint ([value kind-max?] [max contract-kind?])] - [struct kind-max ([variables free-id-table?] [max contract-kind?])] + ;; kind-max: represents the maximum kind across all of the variables and the specified kind + [struct kind-max ([variables free-id-set?] [max contract-kind?])] + ;; contract-restrict: represents a contract with value, recursive-values maps mentioned + ;; recursive parts to kind-maxes, constraints are constraints that need to hold [struct contract-restrict ([value kind-max?] [recursive-values free-id-table?] [constraints (set/c constraint?)])])) + (define free-id-set? free-id-table?) (struct constraint (value max) #:transparent) - (struct kind-max (variables max) #:transparent) - (struct contract-restrict (value recursive-values constraints) #:transparent)) + (struct kind-max (variables max) #:transparent + #:methods gen:custom-write + [(define (write-proc v port mode) + (match-define (kind-max variables max) v) + (define recur + (case mode + [(#t) write] + [(#f) display] + [else (lambda (p port) (print p port mode))])) + (define-values (open close) + (if (equal? mode 0) + (values "(" ")") + (values "#<" ">"))) + (display open port) + (fprintf port "kind-max") + (display " " port) + (display (map syntax-e (dict-keys variables)) port) + (display " " port) + (recur max port) + (display close port))]) + (struct contract-restrict (value recursive-values constraints) + #:methods gen:custom-write + [(define (write-proc v port mode) + (match-define (contract-restrict value recursive-values constraints) v) + (define recur + (case mode + [(#t) write] + [(#f) display] + [else (lambda (p port) (print p port mode))])) + (define-values (open close) + (if (equal? mode 0) + (values "(" ")") + (values "#<" ">"))) + (display open port) + (fprintf port "contract-restrict") + (display " " port) + (recur value port) + + (display " (" port) + (define (recur-pair name val) + (fprintf port "(~a " (syntax->datum name)) + (recur val port) + (display ")" port)) + (define-values (names vals) + (let ((assoc (dict->list recursive-values))) + (values (map car assoc) (map cdr assoc)))) + (when (cons? names) + (recur-pair (first names) (first vals)) + (for ((name (rest names)) + (val (rest vals))) + (display " " port) + (recur-pair name val))) + (display ") " port) + (recur constraints port) + (display close port))] + #:transparent)) (require 'structs) (provide (struct-out kind-max)) @@ -113,12 +176,22 @@ (~a "required " (name bound) " but generated " (name actual))) +(define (trivial-constraint? con) + (match con + [(constraint _ 'impersonator) + #t] + [(constraint (kind-max (app dict-count 0) actual) bound) + (contract-kind<= actual bound)] + [else #f])) + + (define (add-constraint cr max) - (if (equal? 'impersonator max) - cr - (match cr - [(contract-restrict v rec constraints) - (contract-restrict v rec (set-add constraints (constraint v max)))]))) + (match cr + [(contract-restrict v rec constraints) + (define con (constraint v max)) + (if (trivial-constraint? con) + cr + (contract-restrict v rec (set-add constraints con)))])) (define (add-recursive-values cr dict) (match cr @@ -151,6 +224,20 @@ (define (instantiate-cr cr lookup-id) + (define (instantiate-kind-max km) + (match km + [(kind-max ids actual) + (define-values (bound-ids unbound-ids) + (partition (lambda (id) (member id names)) (dict-keys ids))) + (merge-kind-maxes 'flat (cons (kind-max (apply free-id-set unbound-ids) actual) + (for/list ([id (in-list bound-ids)]) + (contract-restrict-value (lookup-id id)))))])) + + (define (instantiate-constraint con) + (match con + [(constraint km bound) + (constraint (instantiate-kind-max km) bound)])) + (match cr [(contract-restrict (kind-max ids max) rec constraints) (define-values (bound-ids unbound-ids) @@ -159,7 +246,9 @@ (contract-restrict (kind-max (apply free-id-set unbound-ids) max) rec - constraints) + (apply set + (filter (negate trivial-constraint?) + (set-map constraints instantiate-constraint)))) (map lookup-id bound-ids)))])) (for ([name names] [cr crs]) diff --git a/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/contract-tests.rkt b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/contract-tests.rkt index 19956d165c..8fe6f81fa4 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/contract-tests.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/contract-tests.rkt @@ -62,6 +62,9 @@ (t (-polydots (a) -Symbol)) (t (-polydots (a) (->... (list) (a a) -Symbol))) + (t (-mu x (-Syntax x))) + + (t/fail ((-poly (a) (-vec a)) . -> . -Symbol) "cannot generate contract for non-function polymorphic type") (t/fail (-> (-poly (a b) (-> (Un a b) (Un a b))) Univ)