diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 690eb0a..865667f 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -114,21 +114,23 @@ (define (inst-types tys-solved Xs tys) (stx-map (lambda (t) (inst-type tys-solved Xs t)) tys)) - ;; computes unbound ids in tys, to be used as tyvars + ;; compute unbound tyvars in one unexpanded type ty + (define (compute-tyvar1 ty) + (syntax-parse ty + [X:id #'(X)] + [() #'()] + [(C t ...) (stx-appendmap compute-tyvar1 #'(t ...))])) + ;; computes unbound ids in (unexpanded) tys, to be used as tyvars (define (compute-tyvars tys) - (if (stx-null? tys) - #'() - (let L ([Xs #'()]) ; compute unbound ids; treat as tyvars - (define ctx (stx-car tys)) - (with-handlers - ([exn:fail:syntax:unbound? - (λ (e) - (define X (stx-car (exn:fail:syntax-exprs e))) - ;; X is tainted, so need to launder it - (define Y (datum->syntax ctx (syntax->datum X))) - (L (cons Y Xs)))]) - ((current-type-eval) #`(∀ #,Xs (ext-stlc:→ . #,tys))) - (stx-sort Xs)))))) + (define Xs (stx-appendmap compute-tyvar1 tys)) + (filter + (lambda (X) + (with-handlers + ([exn:fail:syntax:unbound? (lambda (e) #t)] + [exn:fail:type:infer? (lambda (e) #t)]) + (let ([X+ ((current-type-eval) X)]) + (not (or (tyvar? X+) (type? X+)))))) + (stx-remove-dups Xs)))) ;; define -------------------------------------------------- ;; for function defs, define infers type variables @@ -267,16 +269,20 @@ [_:id #:when (and (not (stx-null? #'(X ...))) (not (stx-null? #'(τ ...)))) - (type-error - #:src stx - #:msg - (string-append - (format "constructor ~a must instantiate ~a type argument(s): " - 'Cons (stx-length #'(X ...))) - (string-join (stx-map type->str #'(X ...)) ", ") - "\n" - (format "and be applied to ~a arguments with type(s): "(stx-length #'(τ ...))) - (string-join (stx-map type->str #'(τ ...)) ", ")))] + (raise + (exn:fail:type:infer + (string-append + (format "TYPE-ERROR: ~a (~a:~a):" + (syntax-source stx) (syntax-line stx) (syntax-column stx)) + "\n" + (format "Constructor ~a must be applied to ~a argument(s) with type(s): " + 'Cons (stx-length #'(τ ...))) + (string-join (stx-map type->str #'(τ ...)) ", ") + "\n" + (format "The arguments should instantiate ~a type argument(s): " + (stx-length #'(X ...))) + (string-join (stx-map type->str #'(X ...)) ", ")) + (current-continuation-marks)))] [(C τs e_arg ...) #:when (brace? #'τs) ; commit to this clause #:with {~! τ_X:type (... ...)} #'τs @@ -905,7 +911,7 @@ (⊢ (for*/list ([x- e-] ...) body-) : (List ty_body))]) (define-typed-syntax for/fold [(_ ([acc init]) ([x:id e] ...) body) - #:with [init- ty_init] (infer+erase #'init) + #:with [init- ty_init] (infer+erase #`(pass-expected init #,stx)) #:with ([e- (ty)] ...) (⇑s (e ...) as Sequence) #:with [(acc- x- ...) body- ty_body] (infer/ctx+erase #'([acc : ty_init][x : ty] ...) #'body) diff --git a/tapl/stx-utils.rkt b/tapl/stx-utils.rkt index 914b62c..e07801c 100644 --- a/tapl/stx-utils.rkt +++ b/tapl/stx-utils.rkt @@ -57,6 +57,9 @@ (define (stx-appendmap f stx) (stx-flatten (stx-map f stx))) +(define (stx-remove-dups Xs) + (remove-duplicates (stx->list Xs) free-identifier=?)) + (define (stx-drop stx n) (drop (stx->list stx) n)) diff --git a/tapl/tests/mlish-tests.rkt b/tapl/tests/mlish-tests.rkt index 5cb5f52..409704b 100644 --- a/tapl/tests/mlish-tests.rkt +++ b/tapl/tests/mlish-tests.rkt @@ -52,7 +52,7 @@ (check-type (g2 (Nil {Bool})) : (List Bool) ⇒ (Nil {Bool})) (check-type (g2 (Nil {(List Int)})) : (List (List Int)) ⇒ (Nil {(List Int)})) (check-type (g2 (Nil {(→ Int Int)})) : (List (→ Int Int)) ⇒ (Nil {(List (→ Int Int))})) -;; same as tests above, but without annotations +;; annotations unneeded: same as tests above, but without annotations (check-type (g2 Nil) : (List Int) ⇒ Nil) (check-type (g2 Nil) : (List Bool) ⇒ Nil) (check-type (g2 Nil) : (List (List Int)) ⇒ Nil) @@ -278,6 +278,39 @@ (define-type (Pairof A B) (C A B)) (check-type (match (C 1 2) with [C a b -> None]) : (Option Int) -> None) +;; type variable inference + +; F should remain valid tyvar, even though it's bound +(define (F [x : X] -> X) x) +(define (tvf1 [x : F] -> F) x) +(check-type tvf1 : (→/test X X)) + +; G should remain valid tyvar +(define-type (Type1 X) (G X)) +(define (tvf5 [x : G] -> G) x) +(check-type tvf5 : (→/test X X)) + +; TY should not be tyvar, bc it's a valid type +(define-type-alias TY (Pairof Int Int)) +(define (tvf2 [x : TY] -> TY) x) +(check-not-type tvf2 : (→/test X X)) + +; same with Bool +(define (tvf3 [x : Bool] -> Bool) x) +(check-not-type tvf3 : (→/test X X)) + +;; X in lam should not be a new tyvar +(define (tvf4 [x : X] -> (→ X X)) + (λ ([y : X]) x)) +(check-type tvf4 : (→/test X (→ X X))) +(check-not-type tvf4 : (→/test X (→ Y X))) + +(check-type (λ ([x : X]) (λ ([y : X]) y)) : (→/test X (→ X X))) +(check-not-type (λ ([x : X]) (λ ([y : X]) y)) : (→/test {X} X (→/test {Y} Y Y))) +(check-type (λ ([x : X]) (λ ([y : Y]) y)) : (→/test {X} X (→/test {Y} Y Y))) +(check-not-type (λ ([x : X]) (λ ([y : Y]) x)) : (→/test X (→ X X))) + + ; ext-stlc tests -------------------------------------------------- ; tests for stlc extensions diff --git a/tapl/tests/mlish/polyrecur.mlish b/tapl/tests/mlish/polyrecur.mlish index cf30d45..8854b72 100644 --- a/tapl/tests/mlish/polyrecur.mlish +++ b/tapl/tests/mlish/polyrecur.mlish @@ -79,7 +79,7 @@ [Leaf x -> (list x)] [Node x rst -> (cons x - (for/fold ([acc (nil {X})]) ([p (in-list (flatten rst))]) + (for/fold ([acc nil]) ([p (in-list (flatten rst))]) (match p with [x y -> (cons x (cons y acc))])))])) diff --git a/tapl/tests/mlish/result.mlish b/tapl/tests/mlish/result.mlish index dfb3c63..109f5aa 100644 --- a/tapl/tests/mlish/result.mlish +++ b/tapl/tests/mlish/result.mlish @@ -14,13 +14,15 @@ (provide ok) (provide error) +(check-type ok : (→/test A (Result A B))) +(check-type error : (→/test B (Result A B))) (check-type (inst ok Int String) : (→ Int (Result Int String))) -(check-type (inst error Int String) : (→ String (Result Int String))) +(check-type (inst error String Int) : (→ String (Result Int String))) (check-type - (list (Ok {Int String} 3) (Error "abject failure") (Ok 4)) + (list (Ok 3) (Error "abject failure") (Ok 4)) : (List (Result Int String)) - -> (list (Ok {Int String} 3) (Error "abject failure") (Ok 4))) + -> (list (Ok 3) (Error "abject failure") (Ok 4))) (define (result-bind [a : (Result A Er)] [f : (→ A (Result B Er))] → (Result B Er)) @@ -74,59 +76,58 @@ (define (read-tree [str : (List Char)] → (Read-Result (Tree Int))) (cond - [(isnil str) - (error "expected a tree of integers, given nothing")] - [(equal? (head str) #\( ) - (do result-bind - [tree1+str : (× (Tree Int) (List Char)) - <- (read-tree (tail str))] - [(cond [(equal? (head (proj tree1+str 1)) #\space) - ((inst ok Unit String) (void))] - [else - ((inst error Unit String) "expected a space")])] - [int+str : (× Int (List Char)) - <- (read-int (tail (proj tree1+str 1)) nil)] - [(cond [(equal? (head (proj int+str 1)) #\space) - ((inst ok Unit String) (void))] - [else - ((inst error Unit String) "expected a space")])] - [tree2+str : (× (Tree Int) (List Char)) - <- (read-tree (tail (proj int+str 1)))] - [(cond [(equal? (head (proj tree2+str 1)) #\) ) - ((inst ok Unit String) (void))] - [else - ((inst error Unit String) "expected a `)`")])] - ((inst ok (× (Tree Int) (List Char)) String) - (tup (Node (proj tree1+str 0) - (proj int+str 0) - (proj tree2+str 0)) - (tail (proj tree2+str 1)))))] - [(digit? (head str)) - (do result-bind - [int+str : (× Int (List Char)) - <- (read-int str nil)] - ((inst ok (× (Tree Int) (List Char)) String) - (tup (Leaf (proj int+str 0)) - (proj int+str 1))))] - [else - (error "expected either a `(` or a digit")])) + [(isnil str) + (error "expected a tree of integers, given nothing")] + [(equal? (head str) #\( ) + (let ([do-ok (inst ok Unit String)] + [do-error (inst error String Unit)]) + (do result-bind + [tree1+str : (× (Tree Int) (List Char)) + <- (read-tree (tail str))] + [(cond [(equal? (head (proj tree1+str 1)) #\space) + (do-ok (void))] + [else (do-error "expected a space")])] + [int+str : (× Int (List Char)) + <- (read-int (tail (proj tree1+str 1)) nil)] + [(cond [(equal? (head (proj int+str 1)) #\space) + (do-ok (void))] + [else (do-error "expected a space")])] + [tree2+str : (× (Tree Int) (List Char)) + <- (read-tree (tail (proj int+str 1)))] + [(cond [(equal? (head (proj tree2+str 1)) #\) ) + (do-ok (void))] + [else (do-error "expected a `)`")])] + (ok + (tup (Node (proj tree1+str 0) + (proj int+str 0) + (proj tree2+str 0)) + (tail (proj tree2+str 1))))))] + [(digit? (head str)) + (do result-bind + [int+str : (× Int (List Char)) + <- (read-int str nil)] + (ok + (tup (Leaf (proj int+str 0)) + (proj int+str 1))))] + [else + (error "expected either a `(` or a digit")])) (check-type (read-tree (string->list "42")) : (Read-Result (Tree Int)) - -> ((inst ok (× (Tree Int) (List Char)) String) + -> (ok (tup (Leaf 42) nil))) (check-type (read-tree (string->list "x")) : (Read-Result (Tree Int)) - -> ((inst error (× (Tree Int) (List Char)) String) + -> (error "expected either a `(` or a digit")) (check-type (read-tree (string->list "(42 43 (44 45 46))")) : (Read-Result (Tree Int)) - -> ((inst ok (× (Tree Int) (List Char)) String) + -> (ok (tup (Node (Leaf 42) 43 (Node (Leaf 44) 45 (Leaf 46))) nil))) diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt index 659ff2d..2bf2f9a 100644 --- a/tapl/typecheck.rkt +++ b/tapl/typecheck.rkt @@ -135,6 +135,10 @@ (define-syntax add-expected (syntax-parser [(_ e τ) (syntax-property #'e 'expected-type #'τ)])) +(define-syntax pass-expected + (syntax-parser + [(_ e stx) (syntax-property #'e 'expected-type + (syntax-property #'stx 'expected-type))])) (define-for-syntax (add-expected-ty e ty) (or (and (syntax-e ty) (syntax-property e 'expected-type ((current-type-eval) ty))) @@ -170,6 +174,8 @@ (define ty (syntax-property stx tag)) (if (cons? ty) (car ty) ty)) + (define (tyvar? X) (syntax-property X 'tyvar)) + (define type-pat "[A-Za-z]+") ;; - infers type of e @@ -284,9 +290,11 @@ (expand/df #`(λ (tv ...) (let-syntax ([tv (make-rename-transformer - (assign-type - (assign-type #'tv #'k) - #'ok #:tag '#,tag))] ...) + (syntax-property + (assign-type + (assign-type #'tv #'k) + #'ok #:tag '#,tag) + 'tyvar #t))] ...) (λ (x ...) (let-syntax ([x