implement turnstile/examples/infer.rkt

This commit is contained in:
AlexKnauth 2016-06-24 12:11:17 -04:00
parent da3ecfa780
commit 310087cc97
4 changed files with 239 additions and 20 deletions

View File

@ -1,6 +1,7 @@
#lang racket/base
(provide add-constraints
add-constraints/var?
lookup
lookup-Xs/keep-unsolved
inst-type
@ -23,9 +24,15 @@
;; unification algorithm for local type inference.
(define (add-constraints Xs substs new-cs [orig-cs new-cs])
(define Xs* (stx->list Xs))
(define (X? X)
(member X Xs* free-identifier=?))
(add-constraints/var? Xs* X? substs new-cs orig-cs))
(define (add-constraints/var? Xs* var? substs new-cs [orig-cs new-cs])
(define Xs (stx->list Xs*))
(define Ys (stx-map stx-car substs))
(define-syntax-class var
[pattern x:id #:when (member #'x Xs* free-identifier=?)])
[pattern x:id #:when (var? #'x)])
(syntax-parse new-cs
[() substs]
[([a:var b] . rst)
@ -36,26 +43,29 @@
;; or #'a already maps to a type that conflicts with #'b.
;; In either case, whatever #'a maps to must be equivalent
;; to #'b, so add that to the constraints.
(add-constraints
(add-constraints/var?
Xs
var?
substs
(cons (list (lookup #'a substs) #'b)
#'rst)
orig-cs)]
[else
(define entry (list #'a #'b))
(add-constraints
Xs*
(add-constraints/var?
Xs
var?
;; Add the mapping #'a -> #'b to the substitution,
(add-substitution-entry entry substs)
;; and substitute that in each of the constraints.
(cs-substitute-entry entry #'rst)
orig-cs)])]
[([a b:var] . rst)
(add-constraints Xs*
substs
#'([b a] . rst)
orig-cs)]
(add-constraints/var? Xs
var?
substs
#'([b a] . rst)
orig-cs)]
[([a b] . rst)
;; If #'a and #'b are base types, check that they're equal.
;; Identifers not within Xs count as base types.
@ -74,25 +84,28 @@
(string-join (map type->str (stx-map stx-car orig-cs)) ", ")
(string-join (map type->str (stx-map stx-cadr orig-cs)) ", "))
#'a #'b))
(add-constraints Xs*
substs
#'rst
orig-cs)]
(add-constraints/var? Xs
var?
substs
#'rst
orig-cs)]
[else
(syntax-parse #'[a b]
[_
#:when (typecheck? #'a #'b)
(add-constraints Xs
substs
#'rst
orig-cs)]
(add-constraints/var? Xs
var?
substs
#'rst
orig-cs)]
[((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...))
#:when (typecheck? #'tycons1 #'tycons2)
#:when (stx-length=? #'[τ1 ...] #'[τ2 ...])
(add-constraints Xs
substs
#'((τ1 τ2) ... . rst)
orig-cs)]
(add-constraints/var? Xs
var?
substs
#'((τ1 τ2) ... . rst)
orig-cs)]
[else
(type-error #:src (get-orig #'b)
#:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a"

View File

@ -0,0 +1,160 @@
#lang turnstile
(extends "ext-stlc.rkt" #:except #%app λ)
(require (only-in "sysf.rkt" ~∀ ∀? Λ))
(reuse cons [head hd] [tail tl] nil [isnil nil?] List list #:from "stlc+cons.rkt")
(require (only-in "stlc+cons.rkt" ~List))
(reuse tup × proj #:from "stlc+tup.rkt")
(reuse define-type-alias #:from "stlc+reco+var.rkt")
(require (for-syntax macrotypes/type-constraints))
(provide hd tl nil? )
;; (Some [X ...] τ_body (Constraints (Constraint τ_1 τ_2) ...))
(define-type-constructor Some #:arity = 2 #:bvs >= 0)
(define-type-constructor Constraint #:arity = 2)
(define-type-constructor Constraints #:arity >= 0)
(define-syntax Cs
(syntax-parser
[(_ [a b] ...)
(Cs #'([a b] ...))]))
(begin-for-syntax
(define (?∀ Xs τ)
(if (stx-null? Xs)
τ
#`( #,Xs #,τ)))
(define (?Some Xs τ cs)
(if (and (stx-null? Xs) (stx-null? cs))
τ
#`(Some #,Xs #,τ (Cs #,@cs))))
(define (Cs cs)
(syntax-parse cs
[([a b] ...)
#'(Constraints (Constraint a b) ...)]))
(define-syntax ~?Some
(pattern-expander
(syntax-parser
[(?Some Xs-pat τ-pat Cs-pat)
#:with τ (generate-temporary)
#'(~and τ
(~parse (~Some Xs-pat τ-pat Cs-pat)
(if (Some? #'τ)
#'τ
((current-type-eval) #'(Some [] τ (Cs))))))])))
(define-syntax ~Cs
(pattern-expander
(syntax-parser #:literals (...)
[(_ [a b] ooo:...)
#:with cs (generate-temporary)
#'(~and cs
(~parse (~Constraints (~Constraint a b) ooo)
(if (syntax-e #'cs)
#'cs
((current-type-eval) #'(Cs)))))]))))
(begin-for-syntax
;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id)
;; finds the free Xs in the type
(define (find-free-Xs Xs ty)
(for/list ([X (in-list (stx->list Xs))]
#:when (stx-contains-id? ty X))
X))
;; constrainable-X? : Id Solved-Constraints (Stx-Listof Id) -> Boolean
(define (constrainable-X? X cs Vs)
(for/or ([c (in-list (stx->list cs))])
(or (bound-identifier=? X (stx-car c))
(and (member (stx-car c) Vs bound-identifier=?)
(stx-contains-id? (stx-cadr c) X)
))))
;; find-constrainable-vars : (Stx-Listof Id) Solved-Constraints (Stx-Listof Id) -> (Listof Id)
(define (find-constrainable-vars Xs cs Vs)
(for/list ([X (in-list Xs)] #:when (constrainable-X? X cs Vs))
X))
;; set-minus/Xs : (Listof Id) (Listof Id) -> (Listof Id)
(define (set-minus/Xs Xs Ys)
(for/list ([X (in-list Xs)]
#:when (not (member X Ys bound-identifier=?)))
X))
;; set-intersect/Xs : (Listof Id) (Listof Id) -> (Listof Id)
(define (set-intersect/Xs Xs Ys)
(for/list ([X (in-list Xs)]
#:when (member X Ys bound-identifier=?))
X))
;; some/inst/generalize : (Stx-Listof Id) Type-Stx Constraints -> Type-Stx
(define (some/inst/generalize Xs* ty* cs1)
(define Xs (stx->list Xs*))
(define cs2 (add-constraints/var? Xs identifier? '() cs1))
(define Vs (set-minus/Xs (stx-map stx-car cs2) Xs))
(define constrainable-vars
(find-constrainable-vars Xs cs2 Vs))
(define constrainable-Xs
(set-intersect/Xs Xs constrainable-vars))
(define concrete-constrained-vars
(for/list ([X (in-list constrainable-vars)]
#:when (empty? (find-free-Xs Xs (or (lookup X cs2) X))))
X))
(define unconstrainable-Xs
(set-minus/Xs Xs constrainable-Xs))
(define ty (inst-type/cs/orig constrainable-vars cs2 ty*))
;; pruning constraints that are useless now
(define concrete-constrainable-Xs
(for/list ([X (in-list constrainable-Xs)]
#:when (empty? (find-free-Xs constrainable-Xs (or (lookup X cs2) X))))
X))
(define cs3
(for/list ([c (in-list cs2)]
#:when (not (member (stx-car c) concrete-constrainable-Xs bound-identifier=?)))
c))
(?Some
(set-minus/Xs constrainable-Xs concrete-constrainable-Xs)
(?∀ (find-free-Xs unconstrainable-Xs ty) ty)
cs3))
(define (tycons id args)
(define/syntax-parse [X ...]
(for/list ([arg (in-list (stx->list args))])
(add-orig (generate-temporary arg) (get-orig arg))))
(define/syntax-parse [arg ...] args)
(define/syntax-parse (~∀ (X- ...) body)
((current-type-eval) #`( (X ...) (#,id X ...))))
(inst-type/cs #'[X- ...] #'([X- arg] ...) #'body))
)
(define-typed-syntax λ
[(λ (x:id ...) body:expr)
[#:with [X ...]
(for/list ([X (in-list (generate-temporaries #'[x ...]))])
(add-orig X X))]
[([X : #%type X-] ...) ([x : X x-] ...)
[[body body-] : τ_body*]]
[#:with (~?Some [V ...] τ_body (~Cs [id_2 τ_2] ...)) (syntax-local-introduce #'τ_body*)]
[#:with τ_fn (some/inst/generalize #'[X- ... V ...]
#'( X- ... τ_body)
#'([id_2 τ_2] ...))]
--------
[ [[_ (λ- (x- ...) body-)] : τ_fn]]])
(define-typed-syntax #%app
[(_ e_fn e_arg ...)
[#:with [A ...] (generate-temporaries #'[e_arg ...])]
[#:with B (generate-temporary 'result)]
[ [[e_fn e_fn-] : τ_fn*]]
[#:with (~?Some [V1 ...] τ_fn (~Cs [τ_3 τ_4] ...)) (syntax-local-introduce #'τ_fn*)]
[#:with τ_fn-expected (tycons #' #'[A ... B])]
[ [[e_arg e_arg-] : τ_arg*] ...]
[#:with [(~?Some [V2 ...] τ_arg (~Cs [τ_5 τ_6] ...)) ...]
(syntax-local-introduce #'[τ_arg* ...])]
[#:with τ_out (some/inst/generalize #'[A ... B V1 ... V2 ... ...]
#'B
#'([τ_fn-expected τ_fn]
[τ_3 τ_4] ...
[A τ_arg] ...
[τ_5 τ_6] ... ...))]
--------
[ [[_ (#%app- e_fn- e_arg- ...)] : τ_out]]])

View File

@ -32,6 +32,7 @@
;; type inference
(require macrotypes/examples/tests/infer-tests)
(require "tlb-infer-tests.rkt")
;; type and effects
(require "stlc+effect-tests.rkt")

View File

@ -0,0 +1,45 @@
#lang s-exp "../infer.rkt"
(require "rackunit-typechecking.rkt")
(check-type (λ (x) 5) : ( (X) ( X Int)))
(check-type (λ (x) x) : ( (X) ( X X)))
(check-type (λ (x) (λ (y) 6)) : ( (X) ( X ( (Y) ( Y Int)))))
(check-type (λ (x) (λ (y) x)) : ( (X) ( X ( (Y) ( Y X)))))
(check-type (λ (x) (λ (y) y)) : ( (X) ( X ( (Y) ( Y Y)))))
(check-type (λ (x) (λ (y) (λ (z) 7))) : ( (X) ( X ( (Y) ( Y ( (Z) ( Z Int)))))))
(check-type (λ (x) (λ (y) (λ (z) x))) : ( (X) ( X ( (Y) ( Y ( (Z) ( Z X)))))))
(check-type (λ (x) (λ (y) (λ (z) y))) : ( (X) ( X ( (Y) ( Y ( (Z) ( Z Y)))))))
(check-type (λ (x) (λ (y) (λ (z) z))) : ( (X) ( X ( (Y) ( Y ( (Z) ( Z Z)))))))
(check-type (+ 1 2) : Int)
(check-type (λ (x) (+ x 2)) : ( Int Int))
(check-type (λ (x y) (+ 1 2)) : ( (X Y) ( X Y Int)))
(check-type (λ (x y) (+ x 2)) : ( (Y) ( Int Y Int)))
(check-type (λ (x y) (+ 1 y)) : ( (X) ( X Int Int)))
(check-type (λ (x y) (+ x y)) : ( Int Int Int))
(check-type (λ (x) (λ (y) (+ 1 2))) : ( (X) ( X ( (Y) ( Y Int)))))
(check-type (λ (x) (λ (y) (+ x 2))) : ( Int ( (Y) ( Y Int))))
(check-type (λ (x) (λ (y) (+ 1 y))) : ( (X) ( X ( Int Int))))
(check-type (λ (x) (λ (y) (+ x y))) : ( Int ( Int Int)))
(check-type (λ (x) (λ (y) (λ (z) (+ 1 2)))) : ( (X) ( X ( (Y) ( Y ( (Z) ( Z Int)))))))
(check-type (λ (x) (λ (y) (λ (z) (+ x 2)))) : ( Int ( (Y) ( Y ( (Z) ( Z Int))))))
(check-type (λ (x) (λ (y) (λ (z) (+ y 2)))) : ( (X) ( X ( Int ( (Z) ( Z Int))))))
(check-type (λ (x) (λ (y) (λ (z) (+ z 2)))) : ( (X) ( X ( (Y) ( Y ( Int Int))))))
(check-type (λ (x) (λ (y) (λ (z) (+ x y)))) : ( Int ( Int ( (Z) ( Z Int)))))
(check-type (λ (x) (λ (y) (λ (z) (+ x z)))) : ( Int ( (Y) ( Y ( Int Int)))))
(check-type (λ (x) (λ (y) (λ (z) (+ y z)))) : ( (X) ( X ( Int ( Int Int)))))
(check-type (λ (x) (λ (y) (λ (z) (+ (+ x y) z)))) : ( Int ( Int ( Int Int))))
(check-type (λ (f a) (f a)) : ( (A B) ( ( A B) A B)))
(check-type (λ (a f g) (g (f a)))
: ( (A C B) ( A ( A B) ( B C) C)))
(check-type (λ (a f g) (g (f a) (+ (f 1) (f 2))))
: ( (C) ( Int ( Int Int) ( Int Int C) C)))
(check-type (λ (a f g) (g (λ () (f a)) (+ (f 1) (f 2))))
: ( (C) ( Int ( Int Int) ( ( Int) Int C) C)))