diff --git a/collects/tests/typed-scheme/fail/bad-ann.rkt b/collects/tests/typed-scheme/fail/bad-ann.rkt index 0c783845..1a70f854 100644 --- a/collects/tests/typed-scheme/fail/bad-ann.rkt +++ b/collects/tests/typed-scheme/fail/bad-ann.rkt @@ -3,7 +3,7 @@ #lang typed/scheme -(: f : Number -> Number) +(: f : Number -> Number) (define (f a b) (+ a b)) diff --git a/collects/tests/typed-scheme/succeed/values-dots.rkt b/collects/tests/typed-scheme/succeed/values-dots.rkt index ed8ace35..f97ef0f4 100644 --- a/collects/tests/typed-scheme/succeed/values-dots.rkt +++ b/collects/tests/typed-scheme/succeed/values-dots.rkt @@ -1,4 +1,4 @@ -#lang typed-scheme +#lang typed/scheme/base (require typed-scheme/base-env/extra-procs) diff --git a/collects/tests/typed-scheme/unit-tests/typecheck-tests.rkt b/collects/tests/typed-scheme/unit-tests/typecheck-tests.rkt index a1d376bf..8da1e06c 100644 --- a/collects/tests/typed-scheme/unit-tests/typecheck-tests.rkt +++ b/collects/tests/typed-scheme/unit-tests/typecheck-tests.rkt @@ -1329,6 +1329,15 @@ [tc-e (#%variable-reference +) -Variable-Reference] [tc-e (apply (λ: ([x : String] [y : String]) (string-append x y)) (list "foo" "bar")) -String] [tc-e (apply (plambda: (a) ([x : a] [y : a]) x) (list "foo" "bar")) -String] + + [tc-e (ann + (case-lambda [(x) (add1 x)] + [(x y) (add1 x)]) + (case-> (Integer -> Integer) + (Integer Integer -> Integer))) + #:ret (ret (cl->* (t:-> -Integer -Integer) + (t:-> -Integer -Integer -Integer)) + (-FS -top -bot))] ) (test-suite "check-type tests" diff --git a/collects/typed-scheme/typecheck/tc-lambda-unit.rkt b/collects/typed-scheme/typecheck/tc-lambda-unit.rkt index 8fe31c63..f3dafeba 100644 --- a/collects/typed-scheme/typecheck/tc-lambda-unit.rkt +++ b/collects/typed-scheme/typecheck/tc-lambda-unit.rkt @@ -193,34 +193,52 @@ (cons (stx-car s) (loop (cdr (syntax-e s))))] [(null? (syntax-e s)) null] [else (list s)]))) - (define (go formals bodies formals* bodies* nums-seen) + (define (find-expected tc-r fml) + (match tc-r + [(tc-result1: (Function: (and fs (list (arr: argss rets rests drests '()) ...)))) + (cond [(syntax->list fml) + (for/list ([a argss] [f fs] [r rests] [dr drests] + #:when (and (not r) (not dr) (= (length a) (length (syntax->list fml))))) + f)] + [else + (for/list ([a argss] [f fs] [r rests] [dr drests] + #:when (and (or r dr) (= (length a) (sub1 (syntax-len fml))))) + f)])] + [_ null])) + (define (go expected formals bodies formals* bodies* nums-seen) (cond [(null? formals) - (map tc/lambda-clause (reverse formals*) (reverse bodies*))] + (apply append + (for/list ([f* formals*] [b* bodies*]) + (match (find-expected expected f*) + ;; very conservative -- only do anything interesting if we get exactly one thing that matches + [(list) + (if (and (= 1 (length formals*)) expected) + (tc-error/expr #:return (list (lam-result null null (list #'here Univ) #f (ret (Un)))) + "Expected a function of type ~a, but got a function with the wrong arity" + (match expected [(tc-result1: t) t])) + (list (tc/lambda-clause f* b*)))] + [(list (arr: argss rets rests drests '()) ...) + (for/list ([args argss] [ret rets] [rest rests] [drest drests]) + (tc/lambda-clause/check + f* b* args (values->tc-results ret (formals->list f*)) rest drest))])))] [(memv (syntax-len (car formals)) nums-seen) ;; we check this clause, but it doesn't contribute to the overall type (tc/lambda-clause (car formals) (car bodies)) - (go (cdr formals) (cdr bodies) formals* bodies* nums-seen)] + ;; FIXME - warn about dead clause here + (go expected (cdr formals) (cdr bodies) formals* bodies* nums-seen)] [else - (go (cdr formals) (cdr bodies) + (go expected + (cdr formals) (cdr bodies) (cons (car formals) formals*) (cons (car bodies) bodies*) - (cons (syntax-len (car formals)) nums-seen))])) - (cond - ;; special case for not-case-lambda - [(and expected - (= 1 (length (syntax->list formals)))) - (let loop ([expected expected]) - (match expected - [(tc-result1: (and t (Mu: _ _))) (loop (ret (unfold t)))] - [(tc-result1: (Function: (list (arr: argss rets rests drests '()) ...))) - (let ([fmls (car (syntax->list formals))]) - (for/list ([args argss] [ret rets] [rest rests] [drest drests]) - (tc/lambda-clause/check fmls (car (syntax->list bodies)) - args (values->tc-results ret (formals->list fmls)) rest drest)))] - [_ (go (syntax->list formals) (syntax->list bodies) null null null)]))] - ;; otherwise - [else (go (syntax->list formals) (syntax->list bodies) null null null)])) + (cons (syntax-len (car formals)) nums-seen))])) + (let loop ([expected expected]) + (match expected + [(tc-result1: (and t (Mu: _ _))) (loop (ret (unfold t)))] + [(tc-result1: (Function: (list (arr: argss rets rests drests '()) ...))) + (go expected (syntax->list formals) (syntax->list bodies) null null null)] + [_ (go #f (syntax->list formals) (syntax->list bodies) null null null)]))) (define (tc/mono-lambda/type formals bodies expected) (define t (make-Function (map lam-result->type (tc/mono-lambda formals bodies expected))))