diff --git a/collects/tests/typed-racket/fail/case-lambda1.rkt b/collects/tests/typed-racket/fail/case-lambda1.rkt new file mode 100644 index 0000000000..6d98ac7668 --- /dev/null +++ b/collects/tests/typed-racket/fail/case-lambda1.rkt @@ -0,0 +1,12 @@ +#; +(exn-pred 1) +#lang typed/racket +(: f (case-> + (Symbol Symbol * -> Integer) + (Symbol * -> Symbol))) +(define f (case-lambda + ((x . y) 4) + (w 'x) + )) + +((ann f (Symbol * -> Symbol)) 'x) diff --git a/collects/tests/typed-racket/fail/case-lambda2.rkt b/collects/tests/typed-racket/fail/case-lambda2.rkt new file mode 100644 index 0000000000..73bf35840a --- /dev/null +++ b/collects/tests/typed-racket/fail/case-lambda2.rkt @@ -0,0 +1,11 @@ +#lang typed/racket +(: f (case-> + (Symbol -> Symbol) + (Symbol Symbol -> Symbol))) +(define f (case-lambda + ((x) x) + (w w) + ((x y) x) + )) + +(f 'x 'y) diff --git a/collects/tests/typed-racket/fail/case-lambda3.rkt b/collects/tests/typed-racket/fail/case-lambda3.rkt new file mode 100644 index 0000000000..e34df02640 --- /dev/null +++ b/collects/tests/typed-racket/fail/case-lambda3.rkt @@ -0,0 +1,16 @@ +#; +(exn-pred 1) +#lang typed/racket + + +(: f (case-> + (String -> String) + (String String -> String) + (String Symbol * -> String))) +(define f + (case-lambda + ((x) x) + ((x y) y) + ((x . w) x))) + +(f "x" 'y) diff --git a/collects/tests/typed-racket/fail/case-lambda4.rkt b/collects/tests/typed-racket/fail/case-lambda4.rkt new file mode 100644 index 0000000000..71e7c63595 --- /dev/null +++ b/collects/tests/typed-racket/fail/case-lambda4.rkt @@ -0,0 +1,12 @@ +#; +(exn-pred 1) +#lang typed/racket + + +(: f (case-> + (Symbol * -> String))) +(define f + (case-lambda + ((x . w) "hello"))) + +(f 'x 'y) diff --git a/collects/tests/typed-racket/fail/missing-rest-arguments.rkt b/collects/tests/typed-racket/fail/missing-rest-arguments.rkt new file mode 100644 index 0000000000..ee3ca0f409 --- /dev/null +++ b/collects/tests/typed-racket/fail/missing-rest-arguments.rkt @@ -0,0 +1,8 @@ +#; +(exn-pred 2) +#lang typed/racket + +(: f (Symbol Symbol * -> Symbol)) +(: g (All (A ...) (Symbol A ... A -> Symbol))) +(define (f x y) x) +(define (g x y) x) diff --git a/collects/tests/typed-racket/optimizer/tests/case-lambda-dead-branch.rkt b/collects/tests/typed-racket/optimizer/tests/case-lambda-dead-branch.rkt new file mode 100644 index 0000000000..e64dc39be6 --- /dev/null +++ b/collects/tests/typed-racket/optimizer/tests/case-lambda-dead-branch.rkt @@ -0,0 +1,18 @@ +#; +( +TR opt: case-lambda-dead-branch.rkt 12:5 (x y) -- dead case-lambda branch +TR opt: case-lambda-dead-branch.rkt 18:5 (x y) -- dead case-lambda branch +) +#lang typed/racket + +(: f (case-> (Symbol Symbol -> String))) +(define f + (case-lambda + (w "hello") + ((x y) (add1 "hello")))) + +(: g (case-> (Symbol -> String))) +(define g + (case-lambda + ((x) "hello") + ((x y) (add1 "hello")))) diff --git a/collects/tests/typed-racket/succeed/case-lambda1.rkt b/collects/tests/typed-racket/succeed/case-lambda1.rkt new file mode 100644 index 0000000000..fa3cb79c80 --- /dev/null +++ b/collects/tests/typed-racket/succeed/case-lambda1.rkt @@ -0,0 +1,10 @@ +#lang typed/racket + + +(: f (case-> + (String Symbol * -> (U String Symbol)))) +(define f + (case-lambda + (w (first w)))) + +(f "x" 'y) diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/optimizer/dead-code.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/optimizer/dead-code.rkt index d60b84f41b..50a14243ed 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/optimizer/dead-code.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/optimizer/dead-code.rkt @@ -1,6 +1,6 @@ #lang racket/base -(require syntax/parse +(require syntax/parse syntax/stx (for-template racket/base racket/flonum racket/fixnum) "../utils/utils.rkt" (types type-table) @@ -54,4 +54,20 @@ (quasisyntax/loc/origin this-syntax #'kw (#%expression (begin #,(optimize/drop-pure #'tst) - #,((optimize) #'els))))))) + #,((optimize) #'els)))))) + (pattern ((~and kw (~literal case-lambda)) (formals . bodies) ...) + #:when (for/or ((formals (syntax->list #'(formals ...)))) + (dead-case-lambda-branch? formals)) + #:with opt + (quasisyntax/loc/origin + this-syntax #'kw + (case-lambda + #,@(for/list ((formals (syntax->list #'(formals ...))) + (bodies (syntax->list #'(bodies ...))) + #:unless (and (dead-case-lambda-branch? formals) + (log-optimization + "dead case-lambda branch" + "Unreachable case-lambda branch elimination." + formals))) + (cons formals (stx-map (optimize) bodies))))))) + diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/typecheck/tc-lambda-unit.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/typecheck/tc-lambda-unit.rkt index b745c1a84d..ad2e94104c 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/typecheck/tc-lambda-unit.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/typecheck/tc-lambda-unit.rkt @@ -8,6 +8,7 @@ (rename-in (except-in (types abbrev utils union) -> ->* one-of/c) [make-arr* make-arr]) (private type-annotation syntax-properties) + (types type-table) (typecheck signatures tc-metafunctions tc-subst check-below) (env type-env-structs lexical-env tvar-env index-env scoped-tvar-env) (utils tc-utils) @@ -88,23 +89,20 @@ (with-lexical-env/extend arg-list arg-types (make-lam-result (for/list ([al (in-list arg-list)] - [at (in-list arg-types)] - [a-ty (in-list arg-tys)]) (list al at)) + [at (in-list arg-types)]) + (list al at)) null (and rest-ty (list (or rest (generate-temporary)) rest-ty)) - ;; make up a fake name if none exists, this is an error case anyway (and drest (list (or rest (generate-temporary)) drest)) (tc-exprs/check (syntax->list body) ret-ty)))) ;; Check that the number of formal arguments is valid for the expected type. ;; Thus it must be able to accept the number of arguments that the expected - ;; type has. So we check for three cases, if the function doesn't accept - ;; enough arguments, if it requires too many arguments, or if it doesn't - ;; support rest arguments if needed. + ;; type has. So we check for two cases: if the function doesn't accept + ;; enough arguments, or if it requires too many arguments. ;; This allows a form like (lambda args body) to have the type (-> Symbol ;; Number) with out a rest arg. (when (or (and (< arg-len tys-len) (not rest)) - (> arg-len tys-len) - (and (or rest-ty drest) (not rest))) + (and (> arg-len tys-len) (not (or rest-ty drest)))) (tc-error/delayed (expected-str tys-len rest-ty drest arg-len rest))) (cond [(not rest) @@ -121,13 +119,20 @@ (list rest) (list (make-ListDots dty b)) (check-body))))] [else - (let ([rest-type (cond - [rest-ty rest-ty] - [(type-annotation rest) (get-type rest #:default Univ)] - [else Univ])]) - (with-lexical-env/extend - (list rest) (list (-lst rest-type)) - (check-body rest-type)))]))) + (define base-rest-type + (cond + [rest-ty rest-ty] + [(type-annotation rest) (get-type rest #:default Univ)] + [else Univ])) + (define extra-types + (if (<= arg-len tys-len) + (drop arg-tys arg-len) + null)) + (define rest-type (apply Un base-rest-type extra-types)) + + (with-lexical-env/extend + (list rest) (list (-lst rest-type)) + (check-body rest-type))]))) ;; typecheck a single lambda, with argument list and body ;; drest-ty and drest-bound are both false or not false @@ -236,16 +241,19 @@ #f (tc-exprs (syntax->list body))))]))])) -(struct formals (positional rest) #:transparent) +;; positional: natural? - the number of positional arguments +;; rest: boolean? - if there is a positional argument +;; syntax: syntax? - the improper syntax list of identifiers +(struct formals (positional rest syntax) #:transparent) -(define (make-formals s) - (let loop ([s s] [acc null]) +(define (make-formals stx) + (let loop ([s stx] [acc null]) (cond [(pair? s) (loop (cdr s) (cons (car s) acc))] - [(null? s) (formals (reverse acc) #f)] + [(null? s) (formals (reverse acc) #f stx)] [(pair? (syntax-e s)) (loop (stx-cdr s) (cons (stx-car s) acc))] - [(null? (syntax-e s)) (formals (reverse acc) #f)] - [else (formals (reverse acc) s)]))) + [(null? (syntax-e s)) (formals (reverse acc) #f stx)] + [else (formals (reverse acc) s stx)]))) (define (formals->list formals) (append @@ -254,14 +262,46 @@ (list (formals-rest formals)) empty))) -;; TODO Not use this bad broken definition of arity + +;; An arity is a list (List Natural Boolean), with the number of positional +;; arguments and whether there is a rest argument. +;; +;; An arities-seen is a list (List (Listof Natural) (U Natural Infinity)), +;; with the list of positional only arities seen and the least number of +;; positional arguments on an arity with a rest argument seen. (define (formals->arity formals) - (+ (length (formals-positional formals)) (if (formals-rest formals) 1 0))) + (list + (length (formals-positional formals)) + (and (formals-rest formals) #t))) + +(define initial-arities-seen (list empty +inf.0)) + +;; arities-seen-add : arities-seen? arity? -> arities-seen? +;; Adds the arity to the arities encoded in the arity-seen. +(define (arities-seen-add arities-seen arity) + (match-define (list positionals min-rest) arities-seen) + (match-define (list new-positional new-rest) arity) + (define new-min-rest + (if new-rest + (min new-positional min-rest) + min-rest)) + (list + (filter (λ (n) (< n new-min-rest)) (cons new-positional positionals)) + new-min-rest)) -;; tc/mono-lambda : (listof formals) (listof syntax?) (or/c #f tc-results) -> (listof lam-result) +;; arities-seen-seen-before? : arities-seen? arity? -> boolean? +;; Determines if the arity would have been covered by an existing arity in the arity-seen +(define (arities-seen-seen-before? arities-seen arity) + (match-define (list positionals min-rest) arities-seen) + (match-define (list new-positional new-rest) arity) + (or (>= new-positional min-rest) + (and (member new-positional positionals) (not new-rest)))) + + +;; tc/mono-lambda : (listof (list formals syntax?)) (or/c #f tc-results) -> (listof lam-result) ;; typecheck a sequence of case-lambda clauses -(define (tc/mono-lambda formals bodies expected) +(define (tc/mono-lambda formals+bodies expected) (define expected-type (match expected [(tc-result1: t) @@ -279,49 +319,52 @@ (for/list ([a (in-list argss)] [f (in-list fs)] [r (in-list rests)] [dr (in-list drests)] #:when (if (formals-rest fml) (>= (length a) (length (formals-positional fml))) - (and (not r) (not dr) (= (length a) (length (formals-positional fml)))))) - f)] + ((if (or r dr) <= =) (length a) (length (formals-positional fml))))) + f)] [_ null])) - (let go [(formals formals) - (bodies bodies) - (formals* null) - (bodies* null) - (nums-seen null)] - (cond - [(null? formals) - (apply append - (for/list ([f* (in-list formals*)] [b* (in-list bodies*)]) - (match (find-matching-arities f*) - ;; very conservative -- only do anything interesting if we get exactly one thing that matches - [(list) - (if (and (= 1 (length formals*)) expected-type) - (tc-error/expr #:return (list (lam-result null null (list (generate-temporary) Univ) - #f (ret (Un)))) - "Expected a function of type ~a, but got a function with the wrong arity" - expected-type) - (tc/lambda-clause f* b*))] - [(list (arr: argss rets rests drests '()) ...) - (for/list ([args (in-list argss)] [ret (in-list rets)] [rest (in-list rests)] [drest (in-list drests)]) - (tc/lambda-clause/check - f* b* args (values->tc-results ret (formals->list f*)) rest drest))])))] - [(member (formals->arity (car formals)) nums-seen) - ;; we check this clause, but it doesn't contribute to the overall type - (tc/lambda-clause (car formals) (car bodies)) - ;; FIXME - warn about dead clause here - (go (cdr formals) (cdr bodies) formals* bodies* nums-seen)] - [else - (go (cdr formals) - (cdr bodies) - (cons (car formals) formals*) - (cons (car bodies) bodies*) - (cons (formals->arity (car formals)) nums-seen))]))) + (define-values (used-formals+bodies arities-seen) + (for/fold ((formals+bodies* empty) (arities-seen initial-arities-seen)) + ((formal+body formals+bodies)) + (match formal+body + [(list formal body) + (define arity (formals->arity formal)) + (values + (cond + [(or (arities-seen-seen-before? arities-seen arity) + (and expected-type (null? (find-matching-arities formal)))) + (warn-unreachable body) + (add-dead-case-lambda-branch (formals-syntax formal)) + (if (check-unreachable-code?) + (cons formal+body formals+bodies*) + formals+bodies*)] + [else + (cons formal+body formals+bodies*)]) + (arities-seen-add arities-seen arity))]))) + + + (apply append + (for/list ([fb* (in-list used-formals+bodies)]) + (match-define (list f* b*) fb*) + (match (find-matching-arities f*) + [(list) + (if (and (= 1 (length used-formals+bodies)) expected-type) + ;; TODO improve error message. + (tc-error/expr #:return (list (lam-result null null (list (generate-temporary) Univ) #f (ret (Un)))) + "Expected a function of type ~a, but got a function with the wrong arity" + expected-type) + (tc/lambda-clause f* b*))] + [(list (arr: argss rets rests drests '()) ...) + (for/list ([args (in-list argss)] [ret (in-list rets)] [rest (in-list rests)] [drest (in-list drests)]) + (tc/lambda-clause/check + f* b* args (values->tc-results ret (formals->list f*)) rest drest))])))) (define (tc/mono-lambda/type formals bodies expected) (make-Function (map lam-result->type (tc/mono-lambda - (stx-map make-formals formals) - (syntax->list bodies) + (map list + (stx-map make-formals formals) + (syntax->list bodies)) expected)))) (define (plambda-prop stx) @@ -470,7 +513,3 @@ (with-lexical-env/extend (list name) (list ft) (begin (tc-exprs/check (syntax->list body) return) (ret ft)))))) - -;(trace tc/mono-lambda) - - diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/types/type-table.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/types/type-table.rkt index 8c1fe55d13..a62dd27efe 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/types/type-table.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/types/type-table.rkt @@ -100,6 +100,18 @@ (eq? t? (hash-ref tautology-contradiction-table e 'not-there))) (values (mk 'tautology) (mk 'contradiction) (mk 'neither)))) +;; keeps track of case-lambda branches that never get evaluated, so that the +;; optimizer can eliminate dead code. The key is the formals syntax object. +;; 1 possible value: #t +(define case-lambda-dead-table (make-hasheq)) + +(define (add-dead-case-lambda-branch formals) + (when (optimize?) + (hash-set! case-lambda-dead-table formals #t))) +(define (dead-case-lambda-branch? formals) + (hash-ref case-lambda-dead-table formals #f)) + + (provide/cond-contract [add-typeof-expr (syntax? tc-results/c . -> . any/c)] [type-of (syntax? . -> . tc-results/c)] @@ -116,4 +128,6 @@ [add-neither (syntax? . -> . any)] [tautology? (syntax? . -> . boolean?)] [contradiction? (syntax? . -> . boolean?)] - [neither? (syntax? . -> . boolean?)]) + [neither? (syntax? . -> . boolean?)] + [add-dead-case-lambda-branch (syntax? . -> . any)] + [dead-case-lambda-branch? (syntax? . -> . boolean?)]) diff --git a/pkgs/typed-racket-pkgs/typed-racket-tests/tests/typed-racket/unit-tests/typecheck-tests.rkt b/pkgs/typed-racket-pkgs/typed-racket-tests/tests/typed-racket/unit-tests/typecheck-tests.rkt index c93a559412..2ed17077c4 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-tests/tests/typed-racket/unit-tests/typecheck-tests.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-tests/tests/typed-racket/unit-tests/typecheck-tests.rkt @@ -1641,6 +1641,12 @@ #:expected (ret (-poly (a) (cl->* (t:-> a a) (t:-> a a a))))] [tc-err (plambda: (A) ((x : A)) x) #:expected (ret (list -Symbol -Symbol))] + + [tc-e/t + (case-lambda + [w 'result] + [(x) (add1 "hello")]) + (->* (list) Univ (-val 'result) : -true-lfilter)] ) (test-suite "check-type tests"