diff --git a/macrotypes/type-constraints.rkt b/macrotypes/type-constraints.rkt index 8d6fd33..848a43a 100644 --- a/macrotypes/type-constraints.rkt +++ b/macrotypes/type-constraints.rkt @@ -1,6 +1,7 @@ #lang racket/base (provide add-constraints + add-constraints/var? lookup lookup-Xs/keep-unsolved inst-type @@ -23,9 +24,15 @@ ;; unification algorithm for local type inference. (define (add-constraints Xs substs new-cs [orig-cs new-cs]) (define Xs* (stx->list Xs)) + (define (X? X) + (member X Xs* free-identifier=?)) + (add-constraints/var? Xs* X? substs new-cs orig-cs)) + +(define (add-constraints/var? Xs* var? substs new-cs [orig-cs new-cs]) + (define Xs (stx->list Xs*)) (define Ys (stx-map stx-car substs)) (define-syntax-class var - [pattern x:id #:when (member #'x Xs* free-identifier=?)]) + [pattern x:id #:when (var? #'x)]) (syntax-parse new-cs [() substs] [([a:var b] . rst) @@ -36,26 +43,29 @@ ;; or #'a already maps to a type that conflicts with #'b. ;; In either case, whatever #'a maps to must be equivalent ;; to #'b, so add that to the constraints. - (add-constraints + (add-constraints/var? Xs + var? substs (cons (list (lookup #'a substs) #'b) #'rst) orig-cs)] [else (define entry (list #'a #'b)) - (add-constraints - Xs* + (add-constraints/var? + Xs + var? ;; Add the mapping #'a -> #'b to the substitution, (add-substitution-entry entry substs) ;; and substitute that in each of the constraints. (cs-substitute-entry entry #'rst) orig-cs)])] [([a b:var] . rst) - (add-constraints Xs* - substs - #'([b a] . rst) - orig-cs)] + (add-constraints/var? Xs + var? + substs + #'([b a] . rst) + orig-cs)] [([a b] . rst) ;; If #'a and #'b are base types, check that they're equal. ;; Identifers not within Xs count as base types. @@ -74,25 +84,28 @@ (string-join (map type->str (stx-map stx-car orig-cs)) ", ") (string-join (map type->str (stx-map stx-cadr orig-cs)) ", ")) #'a #'b)) - (add-constraints Xs* - substs - #'rst - orig-cs)] + (add-constraints/var? Xs + var? + substs + #'rst + orig-cs)] [else (syntax-parse #'[a b] [_ #:when (typecheck? #'a #'b) - (add-constraints Xs - substs - #'rst - orig-cs)] + (add-constraints/var? Xs + var? + substs + #'rst + orig-cs)] [((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...)) #:when (typecheck? #'tycons1 #'tycons2) #:when (stx-length=? #'[τ1 ...] #'[τ2 ...]) - (add-constraints Xs - substs - #'((τ1 τ2) ... . rst) - orig-cs)] + (add-constraints/var? Xs + var? + substs + #'((τ1 τ2) ... . rst) + orig-cs)] [else (type-error #:src (get-orig #'b) #:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a" diff --git a/turnstile/examples/infer.rkt b/turnstile/examples/infer.rkt new file mode 100644 index 0000000..4e00459 --- /dev/null +++ b/turnstile/examples/infer.rkt @@ -0,0 +1,160 @@ +#lang turnstile +(extends "ext-stlc.rkt" #:except #%app λ) +(require (only-in "sysf.rkt" ∀ ~∀ ∀? Λ)) +(reuse cons [head hd] [tail tl] nil [isnil nil?] List list #:from "stlc+cons.rkt") +(require (only-in "stlc+cons.rkt" ~List)) +(reuse tup × proj #:from "stlc+tup.rkt") +(reuse define-type-alias #:from "stlc+reco+var.rkt") +(require (for-syntax macrotypes/type-constraints)) +(provide hd tl nil? ∀) + +;; (Some [X ...] τ_body (Constraints (Constraint τ_1 τ_2) ...)) +(define-type-constructor Some #:arity = 2 #:bvs >= 0) +(define-type-constructor Constraint #:arity = 2) +(define-type-constructor Constraints #:arity >= 0) +(define-syntax Cs + (syntax-parser + [(_ [a b] ...) + (Cs #'([a b] ...))])) +(begin-for-syntax + (define (?∀ Xs τ) + (if (stx-null? Xs) + τ + #`(∀ #,Xs #,τ))) + (define (?Some Xs τ cs) + (if (and (stx-null? Xs) (stx-null? cs)) + τ + #`(Some #,Xs #,τ (Cs #,@cs)))) + (define (Cs cs) + (syntax-parse cs + [([a b] ...) + #'(Constraints (Constraint a b) ...)])) + (define-syntax ~?Some + (pattern-expander + (syntax-parser + [(?Some Xs-pat τ-pat Cs-pat) + #:with τ (generate-temporary) + #'(~and τ + (~parse (~Some Xs-pat τ-pat Cs-pat) + (if (Some? #'τ) + #'τ + ((current-type-eval) #'(Some [] τ (Cs))))))]))) + (define-syntax ~Cs + (pattern-expander + (syntax-parser #:literals (...) + [(_ [a b] ooo:...) + #:with cs (generate-temporary) + #'(~and cs + (~parse (~Constraints (~Constraint a b) ooo) + (if (syntax-e #'cs) + #'cs + ((current-type-eval) #'(Cs)))))])))) + +(begin-for-syntax + ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id) + ;; finds the free Xs in the type + (define (find-free-Xs Xs ty) + (for/list ([X (in-list (stx->list Xs))] + #:when (stx-contains-id? ty X)) + X)) + + ;; constrainable-X? : Id Solved-Constraints (Stx-Listof Id) -> Boolean + (define (constrainable-X? X cs Vs) + (for/or ([c (in-list (stx->list cs))]) + (or (bound-identifier=? X (stx-car c)) + (and (member (stx-car c) Vs bound-identifier=?) + (stx-contains-id? (stx-cadr c) X) + )))) + + ;; find-constrainable-vars : (Stx-Listof Id) Solved-Constraints (Stx-Listof Id) -> (Listof Id) + (define (find-constrainable-vars Xs cs Vs) + (for/list ([X (in-list Xs)] #:when (constrainable-X? X cs Vs)) + X)) + + ;; set-minus/Xs : (Listof Id) (Listof Id) -> (Listof Id) + (define (set-minus/Xs Xs Ys) + (for/list ([X (in-list Xs)] + #:when (not (member X Ys bound-identifier=?))) + X)) + ;; set-intersect/Xs : (Listof Id) (Listof Id) -> (Listof Id) + (define (set-intersect/Xs Xs Ys) + (for/list ([X (in-list Xs)] + #:when (member X Ys bound-identifier=?)) + X)) + + ;; some/inst/generalize : (Stx-Listof Id) Type-Stx Constraints -> Type-Stx + (define (some/inst/generalize Xs* ty* cs1) + (define Xs (stx->list Xs*)) + (define cs2 (add-constraints/var? Xs identifier? '() cs1)) + (define Vs (set-minus/Xs (stx-map stx-car cs2) Xs)) + (define constrainable-vars + (find-constrainable-vars Xs cs2 Vs)) + (define constrainable-Xs + (set-intersect/Xs Xs constrainable-vars)) + (define concrete-constrained-vars + (for/list ([X (in-list constrainable-vars)] + #:when (empty? (find-free-Xs Xs (or (lookup X cs2) X)))) + X)) + (define unconstrainable-Xs + (set-minus/Xs Xs constrainable-Xs)) + (define ty (inst-type/cs/orig constrainable-vars cs2 ty*)) + ;; pruning constraints that are useless now + (define concrete-constrainable-Xs + (for/list ([X (in-list constrainable-Xs)] + #:when (empty? (find-free-Xs constrainable-Xs (or (lookup X cs2) X)))) + X)) + (define cs3 + (for/list ([c (in-list cs2)] + #:when (not (member (stx-car c) concrete-constrainable-Xs bound-identifier=?))) + c)) + (?Some + (set-minus/Xs constrainable-Xs concrete-constrainable-Xs) + (?∀ (find-free-Xs unconstrainable-Xs ty) ty) + cs3)) + + (define (tycons id args) + (define/syntax-parse [X ...] + (for/list ([arg (in-list (stx->list args))]) + (add-orig (generate-temporary arg) (get-orig arg)))) + (define/syntax-parse [arg ...] args) + (define/syntax-parse (~∀ (X- ...) body) + ((current-type-eval) #`(∀ (X ...) (#,id X ...)))) + (inst-type/cs #'[X- ...] #'([X- arg] ...) #'body)) + + ) + +(define-typed-syntax λ + [(λ (x:id ...) body:expr) ≫ + [#:with [X ...] + (for/list ([X (in-list (generate-temporaries #'[x ...]))]) + (add-orig X X))] + [([X : #%type ≫ X-] ...) ([x : X ≫ x-] ...) + ⊢ [[body ≫ body-] ⇒ : τ_body*]] + [#:with (~?Some [V ...] τ_body (~Cs [id_2 τ_2] ...)) (syntax-local-introduce #'τ_body*)] + [#:with τ_fn (some/inst/generalize #'[X- ... V ...] + #'(→ X- ... τ_body) + #'([id_2 τ_2] ...))] + -------- + [⊢ [[_ ≫ (λ- (x- ...) body-)] ⇒ : τ_fn]]]) + +(define-typed-syntax #%app + [(_ e_fn e_arg ...) ≫ + [#:with [A ...] (generate-temporaries #'[e_arg ...])] + [#:with B (generate-temporary 'result)] + [⊢ [[e_fn ≫ e_fn-] ⇒ : τ_fn*]] + [#:with (~?Some [V1 ...] τ_fn (~Cs [τ_3 τ_4] ...)) (syntax-local-introduce #'τ_fn*)] + [#:with τ_fn-expected (tycons #'→ #'[A ... B])] + [⊢ [[e_arg ≫ e_arg-] ⇒ : τ_arg*] ...] + [#:with [(~?Some [V2 ...] τ_arg (~Cs [τ_5 τ_6] ...)) ...] + (syntax-local-introduce #'[τ_arg* ...])] + [#:with τ_out (some/inst/generalize #'[A ... B V1 ... V2 ... ...] + #'B + #'([τ_fn-expected τ_fn] + [τ_3 τ_4] ... + [A τ_arg] ... + [τ_5 τ_6] ... ...))] + -------- + [⊢ [[_ ≫ (#%app- e_fn- e_arg- ...)] ⇒ : τ_out]]]) + + + diff --git a/turnstile/examples/tests/run-all-tests.rkt b/turnstile/examples/tests/run-all-tests.rkt index 7c1366c..e2ed66d 100644 --- a/turnstile/examples/tests/run-all-tests.rkt +++ b/turnstile/examples/tests/run-all-tests.rkt @@ -32,6 +32,7 @@ ;; type inference (require macrotypes/examples/tests/infer-tests) +(require "tlb-infer-tests.rkt") ;; type and effects (require "stlc+effect-tests.rkt") diff --git a/turnstile/examples/tests/tlb-infer-tests.rkt b/turnstile/examples/tests/tlb-infer-tests.rkt new file mode 100644 index 0000000..d238b34 --- /dev/null +++ b/turnstile/examples/tests/tlb-infer-tests.rkt @@ -0,0 +1,45 @@ +#lang s-exp "../infer.rkt" +(require "rackunit-typechecking.rkt") + +(check-type (λ (x) 5) : (∀ (X) (→ X Int))) +(check-type (λ (x) x) : (∀ (X) (→ X X))) + +(check-type (λ (x) (λ (y) 6)) : (∀ (X) (→ X (∀ (Y) (→ Y Int))))) +(check-type (λ (x) (λ (y) x)) : (∀ (X) (→ X (∀ (Y) (→ Y X))))) +(check-type (λ (x) (λ (y) y)) : (∀ (X) (→ X (∀ (Y) (→ Y Y))))) + +(check-type (λ (x) (λ (y) (λ (z) 7))) : (∀ (X) (→ X (∀ (Y) (→ Y (∀ (Z) (→ Z Int))))))) +(check-type (λ (x) (λ (y) (λ (z) x))) : (∀ (X) (→ X (∀ (Y) (→ Y (∀ (Z) (→ Z X))))))) +(check-type (λ (x) (λ (y) (λ (z) y))) : (∀ (X) (→ X (∀ (Y) (→ Y (∀ (Z) (→ Z Y))))))) +(check-type (λ (x) (λ (y) (λ (z) z))) : (∀ (X) (→ X (∀ (Y) (→ Y (∀ (Z) (→ Z Z))))))) + +(check-type (+ 1 2) : Int) +(check-type (λ (x) (+ x 2)) : (→ Int Int)) +(check-type (λ (x y) (+ 1 2)) : (∀ (X Y) (→ X Y Int))) +(check-type (λ (x y) (+ x 2)) : (∀ (Y) (→ Int Y Int))) +(check-type (λ (x y) (+ 1 y)) : (∀ (X) (→ X Int Int))) +(check-type (λ (x y) (+ x y)) : (→ Int Int Int)) + +(check-type (λ (x) (λ (y) (+ 1 2))) : (∀ (X) (→ X (∀ (Y) (→ Y Int))))) +(check-type (λ (x) (λ (y) (+ x 2))) : (→ Int (∀ (Y) (→ Y Int)))) +(check-type (λ (x) (λ (y) (+ 1 y))) : (∀ (X) (→ X (→ Int Int)))) +(check-type (λ (x) (λ (y) (+ x y))) : (→ Int (→ Int Int))) + +(check-type (λ (x) (λ (y) (λ (z) (+ 1 2)))) : (∀ (X) (→ X (∀ (Y) (→ Y (∀ (Z) (→ Z Int))))))) +(check-type (λ (x) (λ (y) (λ (z) (+ x 2)))) : (→ Int (∀ (Y) (→ Y (∀ (Z) (→ Z Int)))))) +(check-type (λ (x) (λ (y) (λ (z) (+ y 2)))) : (∀ (X) (→ X (→ Int (∀ (Z) (→ Z Int)))))) +(check-type (λ (x) (λ (y) (λ (z) (+ z 2)))) : (∀ (X) (→ X (∀ (Y) (→ Y (→ Int Int)))))) +(check-type (λ (x) (λ (y) (λ (z) (+ x y)))) : (→ Int (→ Int (∀ (Z) (→ Z Int))))) +(check-type (λ (x) (λ (y) (λ (z) (+ x z)))) : (→ Int (∀ (Y) (→ Y (→ Int Int))))) +(check-type (λ (x) (λ (y) (λ (z) (+ y z)))) : (∀ (X) (→ X (→ Int (→ Int Int))))) +(check-type (λ (x) (λ (y) (λ (z) (+ (+ x y) z)))) : (→ Int (→ Int (→ Int Int)))) + +(check-type (λ (f a) (f a)) : (∀ (A B) (→ (→ A B) A B))) + +(check-type (λ (a f g) (g (f a))) + : (∀ (A C B) (→ A (→ A B) (→ B C) C))) +(check-type (λ (a f g) (g (f a) (+ (f 1) (f 2)))) + : (∀ (C) (→ Int (→ Int Int) (→ Int Int C) C))) +(check-type (λ (a f g) (g (λ () (f a)) (+ (f 1) (f 2)))) + : (∀ (C) (→ Int (→ Int Int) (→ (→ Int) Int C) C))) +