diff --git a/typed-racket-lib/typed-racket/private/type-contract.rkt b/typed-racket-lib/typed-racket/private/type-contract.rkt index cd7a6e85..25fe11e2 100644 --- a/typed-racket-lib/typed-racket/private/type-contract.rkt +++ b/typed-racket-lib/typed-racket/private/type-contract.rkt @@ -170,6 +170,7 @@ typed-racket/utils/evt-contract typed-racket/utils/sealing-contract typed-racket/utils/promise-not-name-contract + typed-racket/utils/simple-result-arrow racket/sequence racket/contract/parametric)) @@ -664,7 +665,7 @@ (map conv opt-kws)))) (define range (map t->sc rngs)) (define rest (and rst (listof/sc (t->sc/neg rst)))) - (function/sc (process-dom mand-args) opt-args mand-kws opt-kws rest range)) + (function/sc (from-typed? typed-side) (process-dom mand-args) opt-args mand-kws opt-kws rest range)) (handle-range first-arr convert-arr)] [else (define ((f case->) a) @@ -681,6 +682,7 @@ (and rst (listof/sc (t->sc/neg rst))) (map t->sc rngs)) (function/sc + (from-typed? typed-side) (process-dom (map t->sc/neg dom)) null (map conv mand-kws) diff --git a/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt b/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt index 37c1dfae..b66a23d9 100644 --- a/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt +++ b/typed-racket-lib/typed-racket/static-contracts/combinators/function.rkt @@ -6,12 +6,13 @@ (require "../structures.rkt" "../constraints.rkt" racket/list racket/match racket/contract - (for-template racket/base racket/contract/base) + (for-template racket/base racket/contract/base "../../utils/simple-result-arrow.rkt") (for-syntax racket/base syntax/parse)) (provide (contract-out - [function/sc (-> (listof static-contract?) + [function/sc (-> boolean? + (listof static-contract?) (listof static-contract?) (listof (list/c keyword? static-contract?)) (listof (list/c keyword? static-contract?)) @@ -21,7 +22,7 @@ ->/sc:) -(struct function-combinator combinator (indices mand-kws opt-kws) +(struct function-combinator combinator (indices mand-kws opt-kws typed-side?) #:property prop:combinator-name "->/sc" #:methods gen:equal+hash [(define (equal-proc a b recur) (function-sc-equal? a b recur)) (define (hash-proc v recur) (function-sc-hash v recur)) @@ -44,7 +45,7 @@ (and range-end (drop (take ctcs range-end) rest-end)))) (define (function-sc->contract sc recur) - (match-define (function-combinator args indices mand-kws opt-kws) sc) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) sc) (define-values (mand-ctcs opt-ctcs mand-kw-ctcs opt-kw-ctcs rest-ctc range-ctcs) (apply split-function-args (map recur args) indices)) @@ -61,14 +62,23 @@ #`(values #,@range-ctcs) #'any)) - - #`((#,@mand-ctcs #,@mand-kws-stx) - (#,@opt-ctcs #,@opt-kws-stx) - #,@rest-ctc-stx - . ->* . #,range-ctc)) + (cond + [(and (null? mand-kws) (null? opt-kws) + (null? opt-ctcs) + (not rest-ctc) + (= 1 (length mand-ctcs)) + (and range-ctcs (= 1 (length range-ctcs))) + (eq? 'flat (sc-terminal-kind (last args))) + (not typed-side?)) + #`(simple-result-> #,@range-ctcs)] + [else + #`((#,@mand-ctcs #,@mand-kws-stx) + (#,@opt-ctcs #,@opt-kws-stx) + #,@rest-ctc-stx + . ->* . #,range-ctc)])) -(define (function/sc mand-args opt-args mand-kw-args opt-kw-args rest range) +(define (function/sc typed-side? mand-args opt-args mand-kw-args opt-kw-args rest range) (define mand-args-end (length mand-args)) (define opt-args-end (+ mand-args-end (length opt-args))) (define mand-kw-args-end (+ opt-args-end (length mand-kw-args))) @@ -90,14 +100,15 @@ (or range null)) end-indices mand-kws - opt-kws)) + opt-kws + typed-side?)) (define-match-expander ->/sc: (syntax-parser [(_ mand-args opt-args mand-kw-args opt-kw-args rest range) #'(and (? function-combinator?) (app (match-lambda - [(function-combinator args indices mand-kws opt-kws) + [(function-combinator args indices mand-kws opt-kws typed-side?) (define-values (mand-args* opt-args* mand-kw-args* opt-kw-args* rest* range*) (apply split-function-args args indices)) (list @@ -109,7 +120,7 @@ (list mand-args opt-args mand-kw-args opt-kw-args rest range)))])) (define (function-sc-map v f) - (match-define (function-combinator args indices mand-kws opt-kws) v) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) (define-values (mand-args opt-args mand-kw-args opt-kw-args rest-arg range-args) (apply split-function-args args indices)) @@ -124,26 +135,27 @@ empty))) - (function-combinator new-args indices mand-kws opt-kws)) + (function-combinator new-args indices mand-kws opt-kws typed-side?)) (define (function-sc-constraints v f) - (match-define (function-combinator args indices mand-kws opt-kws) v) + (match-define (function-combinator args indices mand-kws opt-kws typed-side?) v) (merge-restricts* 'chaperone (map f args))) (define (function-sc-equal? a b recur) - (match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws) a) - (match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws) b) + (match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws a-typed-side?) a) + (match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws b-typed-side?) b) (and + (equal? a-typed-side? b-typed-side?) (recur a-indices b-indices) (recur a-mand-kws b-mand-kws) (recur a-opt-kws b-opt-kws) (recur a-args b-args))) (define (function-sc-hash v recur) - (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws) v) + (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) (+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args))) (define (function-sc-hash2 v recur) - (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws) v) + (match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v) (+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args))) diff --git a/typed-racket-lib/typed-racket/static-contracts/optimize.rkt b/typed-racket-lib/typed-racket/static-contracts/optimize.rkt index a20f9f5a..e3466019 100644 --- a/typed-racket-lib/typed-racket/static-contracts/optimize.rkt +++ b/typed-racket-lib/typed-racket/static-contracts/optimize.rkt @@ -93,6 +93,7 @@ (fail)) ;; All the checks passed (function/sc + #t (take longest-args (length shortest-args)) (drop longest-args (length shortest-args)) empty @@ -110,7 +111,7 @@ (define (trusted-side-reduce sc) (match sc [(->/sc: mand-args opt-args mand-kw-args opt-kw-args rest-arg (list (any/sc:) ...)) - (function/sc mand-args opt-args mand-kw-args opt-kw-args rest-arg #f)] + (function/sc #t mand-args opt-args mand-kw-args opt-kw-args rest-arg #f)] [(arr/sc: args rest (list (any/sc:) ...)) (arr/sc args rest #f)] [(none/sc:) any/sc] diff --git a/typed-racket-lib/typed-racket/utils/simple-result-arrow.rkt b/typed-racket-lib/typed-racket/utils/simple-result-arrow.rkt new file mode 100644 index 00000000..2b9ee10c --- /dev/null +++ b/typed-racket-lib/typed-racket/utils/simple-result-arrow.rkt @@ -0,0 +1,80 @@ +#lang racket/base +(require racket/unsafe/ops racket/contract/base racket/contract/combinator + racket/format) + +(provide simple-result->) + +(define (simple-result-> c) + (define c* (coerce-flat-contract 'simple-result-> c)) + (define pred (flat-contract-predicate c*)) + (define n (contract-name c*)) + (make-chaperone-contract + #:name `(-> any/c ,n) + #:first-order (λ (v) (and (procedure? v) (procedure-arity-includes? v 1))) + #:late-neg-projection + (λ (blm) + (lambda (v neg) + (if (and (equal? 1 (procedure-arity v)) + (equal? 1 (procedure-result-arity v))) + (unsafe-chaperone-procedure + v + (λ (arg) + (define res (v arg)) + (unless (with-contract-continuation-mark (cons blm neg) (pred res)) + (raise-blame-error + blm #f + (list 'expected: (~s n) 'given: (~s res)))) + res)) + (unsafe-chaperone-procedure + v + (case-lambda + [(arg) + (call-with-values (λ () (v arg)) + (case-lambda [(res) + (unless (with-contract-continuation-mark + (cons blm neg) + (pred res)) + (raise-blame-error + blm #f + (list 'expected: (~s n) 'given: (~s res)))) + res] + [results + (raise-blame-error + v results + (list 'expected "one value" + 'given (~a (length results) + " values")))]))] + [args + (raise-blame-error + blm #f + (list 'expected: "one argument" + 'given: (~a (length args) " arguments")))]))))))) + +(module+ test + (struct m (x)) + (define val (m 1)) + (define c0 (-> any/c real?)) + (define c1 (unconstrained-domain-> real?)) + (define c2 (simple-result-> real?)) + + (define f0 (contract c0 m-x 'pos 'neg)) + (define f1 (contract c1 m-x 'pos 'neg)) + (define f2 (contract c2 m-x 'pos 'neg)) + (define f3 (contract c2 number->string 'pos 'neg)) + (define N 1000000) + (collect-garbage) + 'f0 + (time (for/sum ([i (in-range N)]) + (f0 val))) + (collect-garbage) + 'f1 + (time (for/sum ([i (in-range N)]) + (f1 val))) + (collect-garbage) + 'f2 + (time (for/sum ([i (in-range N)]) + (f2 val))) + + 'm-x + (time (for/sum ([i (in-range N)]) + (m-x val)))) diff --git a/typed-racket-test/unit-tests/contract-tests.rkt b/typed-racket-test/unit-tests/contract-tests.rkt index 2890b104..8eb8c75a 100644 --- a/typed-racket-test/unit-tests/contract-tests.rkt +++ b/typed-racket-test/unit-tests/contract-tests.rkt @@ -67,6 +67,7 @@ (namespace-require 'typed-racket/utils/any-wrap) (namespace-require 'typed-racket/utils/evt-contract) (namespace-require 'typed-racket/utils/opaque-object) + (namespace-require 'typed-racket/utils/simple-result-arrow) (namespace-require '(submod typed-racket/private/type-contract predicates)) (namespace-require 'typed/racket/class) (current-namespace))) diff --git a/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt b/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt index 5139323a..9d75e8e3 100644 --- a/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt +++ b/typed-racket-test/unit-tests/static-contract-optimizer-tests.rkt @@ -220,42 +220,48 @@ #:neg (promise/sc (box/sc set?/sc))) (check-optimize - (function/sc (list (listof/sc any/sc)) + (function/sc #t + (list (listof/sc any/sc)) (list) (list) (list) #f (list (listof/sc any/sc))) #:pos - (function/sc (list list?/sc) + (function/sc #t + (list list?/sc) (list) (list) (list) #f #f) #:neg - (function/sc (list any/sc) + (function/sc #t + (list any/sc) (list) (list) (list) #f (list list?/sc))) (check-optimize - (function/sc (list (listof/sc any/sc)) + (function/sc #t + (list (listof/sc any/sc)) (list) (list) (list) #f (list any/sc)) #:pos - (function/sc (list list?/sc) + (function/sc #t + (list list?/sc) (list) (list) (list) #f #f) #:neg - (function/sc (list any/sc) + (function/sc #t + (list any/sc) (list) (list) (list) @@ -323,7 +329,8 @@ (list (arr/sc empty #f (list set?/sc)) (arr/sc (list identifier?/sc) #f (list (listof/sc set?/sc))))) - #:pos (function/sc (list) + #:pos (function/sc #t + (list) (list identifier?/sc) (list) (list)