From 2261861b559134bf156ceb7f4311a6d6a47e8a66 Mon Sep 17 00:00:00 2001 From: AlexKnauth Date: Mon, 9 May 2016 11:51:32 -0400 Subject: [PATCH] infer instantiations for polymorphic arguments to polymorphic functions --- macrotypes/examples/mlish.rkt | 126 ++++++++----- macrotypes/examples/tests/mlish/match2.mlish | 6 +- macrotypes/examples/tests/mlish/poly-vals.rkt | 176 ++++++++++++++++++ 3 files changed, 262 insertions(+), 46 deletions(-) create mode 100644 macrotypes/examples/tests/mlish/poly-vals.rkt diff --git a/macrotypes/examples/mlish.rkt b/macrotypes/examples/mlish.rkt index 960fc51..d220194 100644 --- a/macrotypes/examples/mlish.rkt +++ b/macrotypes/examples/mlish.rkt @@ -110,22 +110,48 @@ (~and (~not (~∀ _ _)) (~parse vars-pat #'()) body-pat))])))) - + + ;; matching possibly polymorphic types with renamings + (define-syntax ~?∀* + (pattern-expander + (lambda (stx) + (syntax-case stx () + [(?∀* vars-pat body-pat) + #'(~and (~?∀ vars body) + (~parse vars* (generate-temporaries #'vars)) + (~parse vars** (stx-map add-orig #'vars* #'vars)) + (~parse body* (inst-type #'vars** #'vars #'body)) + (~parse vars-pat #'vars**) + (~parse body-pat #'body*))])))) + ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id) ;; finds the free Xs in the type (define (find-free-Xs Xs ty) (for/list ([X (in-list (stx->list Xs))] #:when (stx-contains-id? ty X)) X)) - + + ;; wrap-∀/free-Xs : (Syntax-Listof Id) Type -> Type + ;; If the type has free Xs, this wraps the type in an forall with those free Xs. + (define (wrap-∀/free-Xs Xs ty) + (define free-Xs (find-free-Xs Xs ty)) + (if (stx-null? free-Xs) + ty + (syntax-parse ty + [(~?∀ (Y ...) ty) + #:with [X ...] free-Xs + ((current-type-eval) + (datum->syntax #'ty (list #'∀ #'(X ... Y ...) #'ty) #'ty #'ty))]))) + ;; solve for Xs by unifying quantified fn type with the concrete types of stx's args ;; stx = the application stx = (#%app e_fn e_arg ...) ;; tyXs = input and output types from fn type ;; ie (typeof e_fn) = (-> . tyXs) ;; It infers the types of arguments from left-to-right, ;; and it expands and returns all of the arguments. - ;; It returns list of 3 values if successful, else throws a type error + ;; It returns list of 4 values if successful, else throws a type error ;; - a list of all the arguments, expanded + ;; - a list of all the argument types ;; - a list of all the type variables ;; - the constraints for substituting the types (define (solve Xs tyXs stx) @@ -140,18 +166,21 @@ '())) (syntax-parse stx [(_ e_fn . args) - (define-values (as- cs) - (for/fold ([as- null] [cs initial-cs]) + (define-values (Xs* as- a-tys cs) + (for/fold ([Xs Xs] [as- null] [a-tys null] [cs initial-cs]) ([a (in-list (syntax->list #'args))] [tyXin (in-list (syntax->list #'(τ_inX ...)))]) (define ty_in (inst-type/cs Xs cs tyXin)) - (define/with-syntax [a- ty_a] + (define/syntax-parse [a- (~?∀* Ys ty_a)] (infer+erase (if (empty? (find-free-Xs Xs ty_in)) (add-expected-ty a ty_in) a))) - (values + (define XYs (stx-append Xs #'Ys)) + (values + XYs (cons #'a- as-) - (add-constraints Xs cs (list (list ty_in #'ty_a)) + (cons #'ty_a a-tys) + (add-constraints XYs cs (list (list ty_in #'ty_a)) (list (list (inst-type/cs/orig Xs cs ty_in (λ (id1 id2) @@ -159,7 +188,7 @@ (syntax->datum id2)))) #'ty_a)))))) - (list (reverse as-) Xs cs)])])) + (list (reverse as-) (reverse a-tys) (stx-append #'Vs Xs*) cs)])])) (define (raise-app-poly-infer-error stx expected-tys given-tys e_fn) (type-error #:src stx @@ -172,6 +201,15 @@ (string-join (stx-map type->str expected-tys) ", ") (string-join (stx-map type->str given-tys) ", ")))) + ;; inst-type/cs/∀ : (Stx-Listof Id) Constraints Type-Stx -> Type-Stx + ;; Instantiates ty with the substitution, possibly wrapping it in a forall. + (define (inst-type/cs/∀ Xs cs ty) + (wrap-∀/free-Xs Xs (inst-type/cs Xs cs ty))) + ;; inst-types/cs/∀ : (Stx-Listof Id) Constraints (Stx-Listof Type-Stx) -> (Listof Type-Stx) + ;; the plural version of inst-type/cs/∀ + (define (inst-types/cs/∀ Xs cs tys) + (stx-map (lambda (t) (inst-type/cs/∀ Xs cs t)) tys)) + ;; covariant-Xs? : Type -> Bool ;; Takes a possibly polymorphic type, and returns true if all of the ;; type variables are in covariant positions within the type, false @@ -352,7 +390,25 @@ [exn:fail:type:infer? (lambda (e) #t)]) (let ([X+ ((current-type-eval) X)]) (not (or (tyvar? X+) (type? X+)))))) - (stx-remove-dups Xs)))) + (stx-remove-dups Xs))) + + (define old-join (current-join)) + + ;; new-join : Type-Stx Type-Stx -> Type-Stx + ;; Computes the join of two possibly polymorphic types, by solving the + ;; constraint that they should be equal once instantiated. + (define (new-join t1 t2) + (syntax-parse (list t1 t2) + [[(~?∀ (X ...) t1) (~?∀ (Y ...) t2)] + #:with Xs #'(X ... Y ...) + #:with cs (add-constraints #'Xs '() #'([t1 t2])) + #:with [t1* t2*] (inst-types/cs #'Xs #'cs #'[t1 t2]) + #:with t1** ((current-type-eval) #`(?∀ #,(find-free-Xs #'Xs #'t1*) t1*)) + #:with t2** ((current-type-eval) #`(?∀ #,(find-free-Xs #'Xs #'t2*) t2*)) + (old-join #'t1** #'t2**)])) + + (current-join new-join) + ) ;; define -------------------------------------------------- ;; for function defs, define infers type variables @@ -503,22 +559,9 @@ . rst)])) ... (define-syntax (Cons stx) (syntax-parse stx - ; no args and not polymorphic - [C:id #:when (and (stx-null? #'(X ...)) (stx-null? #'(τ ...))) #'(C)] - ; no args but polymorphic, check inferred type - [C:id - #:when (stx-null? #'(τ ...)) - #:with τ-expected (syntax-property #'C 'expected-type) - #:fail-unless (syntax-e #'τ-expected) - (raise - (exn:fail:type:infer - (format "~a (~a:~a): ~a: ~a" - (syntax-source stx) (syntax-line stx) (syntax-column stx) - (syntax-e #'C) - (no-expected-type-fail-msg)) - (current-continuation-marks))) - #:with (NameExpander τ-expected-arg (... ...)) ((current-type-eval) #'τ-expected) - #'(C {τ-expected-arg (... ...)})] + ; no args expected, expand to value + [C:id #:when (stx-null? #'(τ ...)) #'(C)] + ; no args given, expand to function [_:id (⊢ StructName (?∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))] ; HO fn [(C τs e_arg ...) #:when (brace? #'τs) ; commit to this clause @@ -729,7 +772,8 @@ (syntax-parse stx #:datum-literals (with) [(_ e with . clauses) #:fail-when (null? (syntax->list #'clauses)) "no clauses" - #:with [e- τ_e] (infer+erase #'e) + #:with [e- (~?∀ Xs τ_e)] (infer+erase #'e) + #:fail-unless (stx-null? #'Xs) "add annotations" (syntax-parse #'clauses #:datum-literals (->) [([(~seq p ...) -> e_body] ...) #:with (pat ...) (stx-map ; use brace to indicate root pattern @@ -748,7 +792,8 @@ (define-typed-syntax match #:datum-literals (with) [(_ e with . clauses) #:fail-when (null? (syntax->list #'clauses)) "no clauses" - #:with [e- τ_e] (infer+erase #'e) + #:with [e- (~?∀ Xs τ_e)] (infer+erase #'e) + #:fail-unless (stx-null? #'Xs) "add annotations" #:with t_expect (syntax-property stx 'expected-type) ; propagate inferred type (cond [(×? #'τ_e) ;; e is tuple @@ -885,31 +930,26 @@ ;; compute fn type (ie ∀ and →) #:with [e_fn- (~?∀ Xs (~ext-stlc:→ . tyX_args))] (infer+erase #'e_fn) ;; solve for type variables Xs - #:with [(e_arg- ...) Xs* cs] (solve #'Xs #'tyX_args stx) + #:with [(e_arg- ...) (τ_arg- ...) Xs* cs] (solve #'Xs #'tyX_args stx) ;; instantiate polymorphic function type - #:with [τ_in ... τ_out] (inst-types/cs #'Xs* #'cs #'tyX_args) - #:with (unsolved-X ...) (find-free-Xs #'Xs* #'τ_out) + #:with [τ_in ... τ_out] (inst-types/cs/∀ #'Xs* #'cs #'tyX_args) ;; arity check #:fail-unless (stx-length=? #'(τ_in ...) #'e_args) (num-args-fail-msg #'e_fn #'(τ_in ...) #'e_args) ;; compute argument types - #:with (τ_arg ...) (stx-map typeof #'(e_arg- ...)) + #:with (τ_arg ...) (inst-types/cs/∀ #'Xs* #'cs #'(τ_arg- ...)) ;; typecheck args #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...)) (typecheck-fail-msg/multi #'(τ_in ...) #'(τ_arg ...) #'e_args) - #:with τ_out* (if (stx-null? #'(unsolved-X ...)) - #'τ_out - (syntax-parse #'τ_out - [(~?∀ (Y ...) τ_out) - (unless (→? #'τ_out) - (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn)) - (for ([X (in-list (syntax->list #'(unsolved-X ...)))]) - (unless (covariant-X? X #'τ_out) - (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn))) - #'(∀ (unsolved-X ... Y ...) τ_out)])) + #:with τ_out* (syntax-parse #'τ_out + [(~?∀ (X ...) (~?∀ (Y ...) τ_out)) + (for ([X (in-list (syntax->list #'(X ...)))] + #:when (stx-contains-id? #'Xs* X)) + (unless (covariant-X? X #'τ_out) + (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn))) + #'(?∀ (X ... Y ...) τ_out)]) (⊢ (#%app- e_fn- e_arg- ...) : τ_out*)]) - ;; cond and other conditionals (define-typed-syntax cond [(_ [(~or (~and (~datum else) (~parse test #'(ext-stlc:#%datum . #t))) diff --git a/macrotypes/examples/tests/mlish/match2.mlish b/macrotypes/examples/tests/mlish/match2.mlish index 96b3835..d64c350 100644 --- a/macrotypes/examples/tests/mlish/match2.mlish +++ b/macrotypes/examples/tests/mlish/match2.mlish @@ -17,14 +17,14 @@ (match2 (B (tup 2 3)) with [A x -> x] [C (x,y) -> y] - [B x -> x]) #:with-msg "branches have incompatible types: Int and \\(× Int Int\\)") + [B x -> x]) #:with-msg "couldn't unify \\(× Int Int\\) and Int") (typecheck-fail (match2 (B (tup 2 3)) with [A x -> (tup x x)] [C x -> x] [B x -> x]) - #:with-msg "branches have incompatible types: \\(× Int Int\\) and \\(× Int \\(× Int Int\\)\\)") + #:with-msg "couldn't unify \\(× Int Int\\) and Int") (check-type (match2 (B (tup 2 3)) with @@ -52,7 +52,7 @@ (match2 (A (tup 2 3)) with [B (x,y) -> y] [A x -> x] - [C x -> x]) #:with-msg "branches have incompatible types") + [C x -> x]) #:with-msg "couldn't unify \\(× Int Int\\) and Int") (check-type (match2 (A 1) with diff --git a/macrotypes/examples/tests/mlish/poly-vals.rkt b/macrotypes/examples/tests/mlish/poly-vals.rkt new file mode 100644 index 0000000..0dcea59 --- /dev/null +++ b/macrotypes/examples/tests/mlish/poly-vals.rkt @@ -0,0 +1,176 @@ +#lang s-exp "../../mlish.rkt" + +(require "../rackunit-typechecking.rkt") + +(define-type (Listof X) + Nil + (Cons X (Listof X))) + +(define-type (Option X) + None + (Some X)) + +(check-type (λ () Nil) : (→/test (Listof X))) +(check-type (λ () (Cons 1 Nil)) : (→ (Listof Int))) +(check-type (λ () (Cons Nil Nil)) : (→/test (Listof (Listof X)))) + +(define (nil* → (List X)) nil) +(define (cons* [x : X] [xs : (List X)] → (List X)) (cons x xs)) +(define (tup* [x : X] [y : Y] → (× X Y)) (tup x y)) + +(check-type (λ () (cons* 1 (nil*))) : (→/test (List Int))) +(check-type (λ () (cons* (nil*) (nil*))) : (→/test (List (List X)))) + +(check-type (λ () (tup* 1 2)) : (→/test (× Int Int))) +(check-type (λ () (tup* (nil*) (nil*))) : (→/test (× (List X) (List Y)))) + +(define (f [x : X] [y : Y] → (× X Y)) + (tup* x y)) + +(check-type f : (→/test X Y (× X Y))) + +(check-type + (tup* 1 2) + : (× Int Int) + -> (tup* 1 2)) + +(check-type (λ () (tup* Nil Nil)) : (→/test (× (Listof X) (Listof Y)))) + +(check-type + (if #t + Nil + (Cons 1 Nil)) + : (Listof Int)) + +(check-type + (λ () + (if #t + Nil + (Cons 1 Nil))) + : (→ (Listof Int))) + +(check-type + (λ () + (if #t + Nil + Nil)) + : (→/test (Listof X))) + +(check-type + (λ () + (if #t + Nil + (Cons Nil Nil))) + : (→/test (Listof (Listof X)))) + + +(define (g [t : (× Int Float)] → (× Int Float)) + (for/fold ([t t]) + () + (match t with + [c c. -> + (tup* c c.)]))) + +(check-type + (λ () + (let () + (tup* 1 2))) + : (→/test (× Int Int))) + +(define (zipwith [f : (→ X Y Z)] [xs : (Listof X)] [ys : (Listof Y)] -> (Listof Z)) + (match xs with + [Nil -> Nil] + [Cons x xs -> + (match ys with + [Nil -> Nil] + [Cons y ys -> + (Cons (f x y) (zipwith f xs ys))])])) + +(check-type + (zipwith Cons + (Cons 1 (Cons 2 (Cons 3 Nil))) + (Cons (Cons 2 (Cons 3 Nil)) + (Cons (Cons 4 (Cons 6 Nil)) + (Cons (Cons 6 (Cons 9 Nil)) + Nil)))) + : (Listof (Listof Int)) + -> (Cons (Cons 1 (Cons 2 (Cons 3 Nil))) + (Cons (Cons 2 (Cons 4 (Cons 6 Nil))) + (Cons (Cons 3 (Cons 6 (Cons 9 Nil))) + Nil)))) + +(check-type + (zipwith cons* + (Cons 1 (Cons 2 (Cons 3 Nil))) + (Cons (list 2 3) (Cons (list 4 6) (Cons (list 6 9) Nil)))) + : (Listof (List Int)) + -> (Cons (list 1 2 3) (Cons (list 2 4 6) (Cons (list 3 6 9) Nil)))) + +(define (first [xs : (Listof X)] → (Option X)) + (match xs with + [Nil -> None] + [Cons x xs -> (Some x)])) + +(define (map [f : (→ X Y)] [xs : (Listof X)] -> (Listof Y)) + (match xs with + [Nil -> Nil] + [Cons x xs -> + (Cons (f x) (map f xs))])) + +(check-type + (map first (Cons (Cons 1 (Cons 2 Nil)) Nil)) + : (Listof (Option Int)) + -> (Cons (Some 1) Nil)) + +(check-type + (map first (Cons Nil Nil)) + : (Listof (Option Int)) + -> (Cons None Nil)) + +(check-type + (λ () + (map first (Cons Nil Nil))) + : (→/test (Listof (Option X)))) + +(check-type + (λ () + (map first Nil)) + : (→/test (Listof (Option X)))) + +(define (last [xs : (List X)] → (Option X)) + (for/fold ([res None]) + ([x (in-list xs)]) + (Some x))) + +(check-type + (map last (Cons (cons* 1 (cons* 2 (nil*))) Nil)) + : (Listof (Option Int)) + -> (Cons (Some 2) Nil)) + +(check-type + (map last (Cons (nil*) Nil)) + : (Listof (Option Int)) + -> (Cons None Nil)) + +(check-type + (λ () + (map last (Cons (nil*) Nil))) + : (→/test (Listof (Option X)))) + +(check-type + (λ () + (map last Nil)) + : (→/test (Listof (Option X)))) + +(check-type + (λ (x) (add1 x)) + : (→ Int Int)) + +(define (h → (→ A (Listof A) (Listof A))) + (λ (x xs) + (Cons x xs))) + +(check-type + h + : (→/test (→ A (Listof A) (Listof A)))) +