diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 86bfc51..508d515 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -375,6 +375,9 @@ (provide →/test) (define-syntax →/test (syntax-parser + [(_ (~and Xs (X:id ...)) . rst) + #:when (brace? #'Xs) + #'(∀ (X ...) (ext-stlc:→ . rst))] [(_ . rst) (let L ([Xs #'()]) ; compute unbound ids; treat as tyvars (with-handlers ([exn:fail:syntax:unbound? @@ -433,7 +436,16 @@ (syntax->datum #'e_fn) (type->str #'τ_fn)) #:with (~∀ Xs (~ext-stlc:→ τ_inX ... τ_outX)) #'τ_fn ;; ) instantiate polymorphic fn type - #:with (τ_solved ...) (solve #'Xs #'(τ_inX ...) (syntax/loc stx (e_fn e_arg ...))) + ; try to solve with expected-type first + #:with expected-ty (get-expected-type stx) + #:with maybe-solved + (and (syntax-e #'expected-ty) + (let ([cs (compute-constraints (list (list #'τ_outX ((current-type-eval) #'expected-ty))))]) + (filter (lambda (x) x) (stx-map (λ (X) (lookup X cs)) #'Xs)))) + ;; else use arg types + #:with (τ_solved ...) (if (and (syntax-e #'maybe-solved) (stx-length=? #'maybe-solved #'Xs)) + #'maybe-solved + (solve #'Xs #'(τ_inX ...) (syntax/loc stx (e_fn e_arg ...)))) ;; #:with cs (compute-constraints #'((τ_inX τ_arg) ...)) ;; #:with (τ_solved ...) (filter (λ (x) x) (stx-map (λ (y) (lookup y #'cs)) #'(X ...))) ;; #:fail-unless (stx-length=? #'(X ...) #'(τ_solved ...)) diff --git a/tapl/tests/mlish/inst.mlish b/tapl/tests/mlish/inst.mlish new file mode 100644 index 0000000..342178a --- /dev/null +++ b/tapl/tests/mlish/inst.mlish @@ -0,0 +1,24 @@ +#lang s-exp "../../mlish.rkt" +(require "../rackunit-typechecking.rkt") + +;; tests for instantiation of polymorphic functions and constructors + +(define-type (Result A B) + (Ok A) + (Error B)) + +(define {A B} (ok [a : A] -> (Result A B)) + (Ok a)) + +(check-type ok : (→/test {A B} A (Result A B))) ; test inferred +(check-type (inst ok Int String) : (→/test Int (Result Int String))) + +(define (f -> (Result Int String)) + (ok 1)) + +(check-type f : (→/test (Result Int String))) + +(define (g -> (Result Int String)) + (Ok 1)) + +(check-type g : (→/test (Result Int String))) diff --git a/tapl/tests/run-all-mlish-tests.rkt b/tapl/tests/run-all-mlish-tests.rkt index a7af038..2e9c7c6 100644 --- a/tapl/tests/run-all-mlish-tests.rkt +++ b/tapl/tests/run-all-mlish-tests.rkt @@ -18,3 +18,4 @@ (require "mlish/term.mlish") (require "mlish/find.mlish") (require "mlish/alex.mlish") +(require "mlish/inst.mlish")