diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/infer/infer-unit.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/infer/infer-unit.rkt index 5ee1a5e8..b2a719bf 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/infer/infer-unit.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/infer/infer-unit.rkt @@ -723,9 +723,10 @@ #f])))) ;; C : cset? - set of constraints found by the inference engine +;; X : (listof symbol?) - type variables that must have entries ;; Y : (listof symbol?) - index variables that must have entries ;; R : Type/c - result type into which we will be substituting -(define/cond-contract (subst-gen C Y R) +(define/cond-contract (subst-gen C X Y R) (cset? (listof symbol?) (or/c Values/c AnyValues? ValuesDots?) . -> . (or/c #f substitution/c)) (define var-hash (free-vars-hash (free-vars* R))) (define idx-hash (free-vars-hash (free-idxs* R))) @@ -816,7 +817,7 @@ (for/hash ([(k v) (in-hash cmap)]) (values k (t-subst (constraint->type v var-hash)))))]) ;; verify that we got all the important variables - (and (for/and ([v (in-list (fv R))]) + (and (for/and ([v (in-list X)]) (let ([entry (hash-ref subst v #f)]) ;; Make sure we got a subst entry for a type var ;; (i.e. just a type to substitute) @@ -867,8 +868,8 @@ [cs (and expected-cset (cgen/list null X Y S T #:expected-cset expected-cset))] [cs* (% cset-meet cs expected-cset)]) - (and cs* (if R (subst-gen cs* Y R) #t)))) - infer)) ;to export a variable binding and not syntax + (and cs* (if R (subst-gen cs* X Y R) #t)))) + infer)) ;to export a variable binding and not syntax ;; like infer, but T-var is the vararg type: (define (infer/vararg X Y S T T-var R [expected #f]) @@ -903,6 +904,6 @@ #:return-unless cs #f (define m (cset-meet cs expected-cset)) #:return-unless m #f - (subst-gen m (list dotted-var) R))) + (subst-gen m X (list dotted-var) R))) diff --git a/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/infer-tests.rkt b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/infer-tests.rkt index 71d2657a..b2961484 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/infer-tests.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/unit-tests/infer-tests.rkt @@ -20,12 +20,15 @@ (list (quote elems) ...)))) (begin-for-syntax + (define-splicing-syntax-class result + (pattern (~seq) #:with v #'#f) + (pattern (~seq #:result v:expr))) (define-splicing-syntax-class vars (pattern (~seq) #:with vars #'empty) - (pattern (~seq #:vars vars:expr) )) + (pattern (~seq #:vars vars:expr))) (define-splicing-syntax-class indices (pattern (~seq) #:with indices #'empty) - (pattern (~seq #:indices indices:expr) )) + (pattern (~seq #:indices indices:expr))) (define-splicing-syntax-class pass (pattern (~seq) #:with pass #'#t) (pattern #:pass #:with pass #'#t) @@ -33,20 +36,20 @@ (define-syntax (infer-t stx) (syntax-parse stx - ([_ S:expr T:expr :vars :indices :pass] + ([_ S:expr T:expr R:result :vars :indices :pass] (syntax/loc stx (test-case (format "~a ~a~a" S T (if pass "" " should fail")) - (define result (infer vars indices (list S) (list T) #f)) - (unless (equal? result pass) + (define result (infer vars indices (list S) (list T) R.v)) + (unless (if pass result (not result)) (fail-check "Could not infer a substitution"))))))) (define-syntax (infer-l stx) (syntax-parse stx - ([_ S:expr T:expr :vars :indices :pass] + ([_ S:expr T:expr R:result :vars :indices :pass] (syntax/loc stx (test-case (format "~a ~a~a" S T (if pass "" " should fail")) - (define result (infer vars indices S T #f)) - (unless (equal? result pass) + (define result (infer vars indices S T R.v)) + (unless (if pass result (not result)) (fail-check "Could not infer a substitution"))))))) @@ -87,6 +90,7 @@ (test-suite "Tests for infer" (infer-t Univ Univ) (infer-t (-v a) Univ) + (infer-t (-v a) (-v a) #:result (-v a)) (infer-t Univ (-v a) #:fail) (infer-t Univ (-v a) #:vars '(a)) (infer-t (-v a) Univ #:vars '(a)) @@ -101,6 +105,9 @@ (infer-t (make-ListDots -Symbol 'b) (make-ListDots Univ 'b) #:indices '(b)) (infer-t (make-ListDots (-v b) 'b) (make-ListDots (-v b) 'b) #:indices '(b)) (infer-t (make-ListDots (-v b) 'b) (make-ListDots Univ 'b) #:indices '(b)) + (infer-t (-pair (-v a) (make-ListDots (-v b) 'b)) + (-pair (-v a) (make-ListDots (-v b) 'b)) + #:result (-v a)) [infer-t (->... null ((-v a) a) (-v b)) (-> -Symbol -String) #:vars '(b) #:indices '(a)] [infer-t (->... null ((-v a) a) (make-ListDots (-v a) 'a)) (-> -String -Symbol (-lst* -String -Symbol)) #:indices '(a)]