diff --git a/typed-racket-lib/typed-racket/static-contracts/combinators/any.rkt b/typed-racket-lib/typed-racket/static-contracts/combinators/any.rkt index f39d68d0..7fec6601 100644 --- a/typed-racket-lib/typed-racket/static-contracts/combinators/any.rkt +++ b/typed-racket-lib/typed-racket/static-contracts/combinators/any.rkt @@ -28,7 +28,8 @@ [(define (sc-map v f) v) (define (sc-traverse v f) (void)) (define (sc->contract v f) #'any/c) - (define (sc->constraints v f) (simple-contract-restrict 'flat))] + (define (sc->constraints v f) (simple-contract-restrict 'flat)) + (define (sc-terminal-kind v) 'flat)] #:methods gen:custom-write [(define write-proc any-write-proc)]) (define-match-expander any/sc: diff --git a/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt b/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt index b3008b5d..3bb68cb3 100644 --- a/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt +++ b/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt @@ -31,6 +31,7 @@ [(define (sc->contract v f) (function-sc->contract v f)) (define (sc-map v f) (function-sc-map v f)) (define (sc-traverse v f) (function-sc-map v f) (void)) + (define (sc-terminal-kind v) (function-sc-terminal-kind v)) (define (sc->constraints v f) (function-sc-constraints v f))]) (define (split-function-args ctcs mand-args-end opt-args-end @@ -45,7 +46,10 @@ (and range-end (drop (take ctcs range-end) rest-end)))) (define (function-sc->contract sc recur) - (match-define (function-combinator args indices mand-kws opt-kws typed-side?) sc) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) sc) + + (define-values (mand-scs opt-scs mand-kw-scs opt-kw-scs rest-sc range-scs) + (apply split-function-args args indices)) (define-values (mand-ctcs opt-ctcs mand-kw-ctcs opt-kw-ctcs rest-ctc range-ctcs) (apply split-function-args (map recur args) indices)) @@ -62,20 +66,19 @@ #`(values #,@range-ctcs) #'any)) - (cond + (cond [(and (null? mand-kws) (null? opt-kws) (null? opt-ctcs) (not rest-ctc) ;; currently simple-result-> only handles up to arity 3 - (member (length mand-ctcs) '(1 0 2 3) #;(list 0 1)) + (member (length mand-ctcs) '(0 1 2 3)) (and range-ctcs (= 1 (length range-ctcs))) - (eq? 'flat (sc-terminal-kind (last args))) - ;(for/and ([a args]) (eq? 'flat (sc-terminal-kind a))) + (for/and ([a args]) (eq? 'flat (sc-terminal-kind a))) (not typed-side?)) #`(simple-result-> #,@range-ctcs #,(length mand-ctcs))] [else #`((#,@mand-ctcs #,@mand-kws-stx) - (#,@opt-ctcs #,@opt-kws-stx) + (#,@opt-ctcs #,@opt-kws-stx) #,@rest-ctc-stx . ->* . #,range-ctc)])) @@ -122,13 +125,13 @@ (list mand-args opt-args mand-kw-args opt-kw-args rest range)))])) (define (function-sc-map v f) - (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) (define-values (mand-args opt-args mand-kw-args opt-kw-args rest-arg range-args) (apply split-function-args args indices)) (define new-args - (append + (append (map (lambda (arg) (f arg 'contravariant)) (append mand-args opt-args mand-kw-args opt-kw-args (if rest-arg (list rest-arg) null))) (if range-args @@ -139,13 +142,36 @@ (function-combinator new-args indices mand-kws opt-kws typed-side?)) +(define (function-sc-terminal-kind v) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) + (define-values (mand-args opt-args mand-kw-args opt-kw-args rest-arg range-args) + (apply split-function-args args indices)) + (if (and (not rest-arg) + (null? (append mand-kw-args mand-args opt-kw-args opt-args)) + typed-side?) + ;; currently we only handle this trivial case + ;; we could probably look at the actual kind of `range-args` as well + (if (not range-args) 'flat #f) + #f)) + + (define (function-sc-constraints v f) - (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) - (merge-restricts* 'chaperone (map f args))) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) + (define-values (mand-args opt-args mand-kw-args opt-kw-args rest-arg range-args) + (apply split-function-args args indices)) + (if (and (not rest-arg) + (null? (append mand-kw-args mand-args opt-kw-args opt-args)) + typed-side?) + ;; arity-0 functions end up being flat contracts when they're + ;; from the typed side and the result is flat + (if range-args + (merge-restricts* 'flat (map f range-args)) + (merge-restricts* 'flat null)) + (merge-restricts* 'chaperone (map f args)))) (define (function-sc-equal? a b recur) - (match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws a-typed-side?) a) - (match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws b-typed-side?) b) + (match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws a-typed-side?) a) + (match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws b-typed-side?) b) (and (equal? a-typed-side? b-typed-side?) (recur a-indices b-indices) @@ -154,10 +180,9 @@ (recur a-args b-args))) (define (function-sc-hash v recur) - (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) + (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) (+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args))) (define (function-sc-hash2 v recur) - (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) + (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) (+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args))) - diff --git a/typed-racket-test/succeed/poly-simple-contract.rkt b/typed-racket-test/succeed/poly-simple-contract.rkt new file mode 100644 index 00000000..391521c5 --- /dev/null +++ b/typed-racket-test/succeed/poly-simple-contract.rkt @@ -0,0 +1,16 @@ +#lang typed/racket/base + + +(require/typed + racket/struct + [make-constructor-style-printer + (All (A) (-> (-> A (U Symbol String)) + (-> A (Sequenceof Any)) + (-> A Output-Port (U #t #f 0 1) Void)))]) + + +((make-constructor-style-printer + (lambda ([x : String]) x) + (lambda ([x : String]) null)) + "" + (open-output-string) #t)