Add the simple-result-> combinator to Typed Racket.

This is used for functions with a single argument imported with
`require/typed`, and avoids unneccessary checks. This produces a
3x speedup on the following benchmark:

  #lang racket/base
  (module m racket/base
    (provide f)
    (define (f x) x))
  (module n typed/racket/base
    (require/typed
     (submod ".." m)
     [f (-> Integer Integer)])
    (time
     (for ([x (in-range 1000000)])
       (f 1) (f 2) (f 3) (f 4))))
  (require 'n)

on top of the previous improvment from using `unsafe-procedure-chaperone`
and `procedure-result-arity`.
This commit is contained in:
Sam Tobin-Hochstadt 2016-01-16 19:18:34 -05:00
parent 7217e2e531
commit 838431c176
6 changed files with 131 additions and 28 deletions

View File

@ -170,6 +170,7 @@
typed-racket/utils/evt-contract
typed-racket/utils/sealing-contract
typed-racket/utils/promise-not-name-contract
typed-racket/utils/simple-result-arrow
racket/sequence
racket/contract/parametric))
@ -664,7 +665,7 @@
(map conv opt-kws))))
(define range (map t->sc rngs))
(define rest (and rst (listof/sc (t->sc/neg rst))))
(function/sc (process-dom mand-args) opt-args mand-kws opt-kws rest range))
(function/sc (from-typed? typed-side) (process-dom mand-args) opt-args mand-kws opt-kws rest range))
(handle-range first-arr convert-arr)]
[else
(define ((f case->) a)
@ -681,6 +682,7 @@
(and rst (listof/sc (t->sc/neg rst)))
(map t->sc rngs))
(function/sc
(from-typed? typed-side)
(process-dom (map t->sc/neg dom))
null
(map conv mand-kws)

View File

@ -6,12 +6,13 @@
(require "../structures.rkt" "../constraints.rkt"
racket/list racket/match
racket/contract
(for-template racket/base racket/contract/base)
(for-template racket/base racket/contract/base "../../utils/simple-result-arrow.rkt")
(for-syntax racket/base syntax/parse))
(provide
(contract-out
[function/sc (-> (listof static-contract?)
[function/sc (-> boolean?
(listof static-contract?)
(listof static-contract?)
(listof (list/c keyword? static-contract?))
(listof (list/c keyword? static-contract?))
@ -21,7 +22,7 @@
->/sc:)
(struct function-combinator combinator (indices mand-kws opt-kws)
(struct function-combinator combinator (indices mand-kws opt-kws typed-side?)
#:property prop:combinator-name "->/sc"
#:methods gen:equal+hash [(define (equal-proc a b recur) (function-sc-equal? a b recur))
(define (hash-proc v recur) (function-sc-hash v recur))
@ -44,7 +45,7 @@
(and range-end (drop (take ctcs range-end) rest-end))))
(define (function-sc->contract sc recur)
(match-define (function-combinator args indices mand-kws opt-kws) sc)
(match-define (function-combinator args indices mand-kws opt-kws typed-side?) sc)
(define-values (mand-ctcs opt-ctcs mand-kw-ctcs opt-kw-ctcs rest-ctc range-ctcs)
(apply split-function-args (map recur args) indices))
@ -61,14 +62,23 @@
#`(values #,@range-ctcs)
#'any))
#`((#,@mand-ctcs #,@mand-kws-stx)
(#,@opt-ctcs #,@opt-kws-stx)
#,@rest-ctc-stx
. ->* . #,range-ctc))
(cond
[(and (null? mand-kws) (null? opt-kws)
(null? opt-ctcs)
(not rest-ctc)
(= 1 (length mand-ctcs))
(and range-ctcs (= 1 (length range-ctcs)))
(eq? 'flat (sc-terminal-kind (last args)))
(not typed-side?))
#`(simple-result-> #,@range-ctcs)]
[else
#`((#,@mand-ctcs #,@mand-kws-stx)
(#,@opt-ctcs #,@opt-kws-stx)
#,@rest-ctc-stx
. ->* . #,range-ctc)]))
(define (function/sc mand-args opt-args mand-kw-args opt-kw-args rest range)
(define (function/sc typed-side? mand-args opt-args mand-kw-args opt-kw-args rest range)
(define mand-args-end (length mand-args))
(define opt-args-end (+ mand-args-end (length opt-args)))
(define mand-kw-args-end (+ opt-args-end (length mand-kw-args)))
@ -90,14 +100,15 @@
(or range null))
end-indices
mand-kws
opt-kws))
opt-kws
typed-side?))
(define-match-expander ->/sc:
(syntax-parser
[(_ mand-args opt-args mand-kw-args opt-kw-args rest range)
#'(and (? function-combinator?)
(app (match-lambda
[(function-combinator args indices mand-kws opt-kws)
[(function-combinator args indices mand-kws opt-kws typed-side?)
(define-values (mand-args* opt-args* mand-kw-args* opt-kw-args* rest* range*)
(apply split-function-args args indices))
(list
@ -109,7 +120,7 @@
(list mand-args opt-args mand-kw-args opt-kw-args rest range)))]))
(define (function-sc-map v f)
(match-define (function-combinator args indices mand-kws opt-kws) v)
(match-define (function-combinator args indices mand-kws opt-kws typed-side?) v)
(define-values (mand-args opt-args mand-kw-args opt-kw-args rest-arg range-args)
(apply split-function-args args indices))
@ -124,26 +135,27 @@
empty)))
(function-combinator new-args indices mand-kws opt-kws))
(function-combinator new-args indices mand-kws opt-kws typed-side?))
(define (function-sc-constraints v f)
(match-define (function-combinator args indices mand-kws opt-kws) v)
(match-define (function-combinator args indices mand-kws opt-kws typed-side?) v)
(merge-restricts* 'chaperone (map f args)))
(define (function-sc-equal? a b recur)
(match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws) a)
(match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws) b)
(match-define (function-combinator a-args a-indices a-mand-kws a-opt-kws a-typed-side?) a)
(match-define (function-combinator b-args b-indices b-mand-kws b-opt-kws b-typed-side?) b)
(and
(equal? a-typed-side? b-typed-side?)
(recur a-indices b-indices)
(recur a-mand-kws b-mand-kws)
(recur a-opt-kws b-opt-kws)
(recur a-args b-args)))
(define (function-sc-hash v recur)
(match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws) v)
(match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v)
(+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args)))
(define (function-sc-hash2 v recur)
(match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws) v)
(match-define (function-combinator v-args v-indices v-mand-kws v-opt-kws typed-side?) v)
(+ (recur v-indices) (recur v-mand-kws) (recur v-opt-kws) (recur v-args)))

View File

@ -93,6 +93,7 @@
(fail))
;; All the checks passed
(function/sc
#t
(take longest-args (length shortest-args))
(drop longest-args (length shortest-args))
empty
@ -110,7 +111,7 @@
(define (trusted-side-reduce sc)
(match sc
[(->/sc: mand-args opt-args mand-kw-args opt-kw-args rest-arg (list (any/sc:) ...))
(function/sc mand-args opt-args mand-kw-args opt-kw-args rest-arg #f)]
(function/sc #t mand-args opt-args mand-kw-args opt-kw-args rest-arg #f)]
[(arr/sc: args rest (list (any/sc:) ...))
(arr/sc args rest #f)]
[(none/sc:) any/sc]

View File

@ -0,0 +1,80 @@
#lang racket/base
(require racket/unsafe/ops racket/contract/base racket/contract/combinator
racket/format)
(provide simple-result->)
(define (simple-result-> c)
(define c* (coerce-flat-contract 'simple-result-> c))
(define pred (flat-contract-predicate c*))
(define n (contract-name c*))
(make-chaperone-contract
#:name `(-> any/c ,n)
#:first-order (λ (v) (and (procedure? v) (procedure-arity-includes? v 1)))
#:late-neg-projection
(λ (blm)
(lambda (v neg)
(if (and (equal? 1 (procedure-arity v))
(equal? 1 (procedure-result-arity v)))
(unsafe-chaperone-procedure
v
(λ (arg)
(define res (v arg))
(unless (with-contract-continuation-mark (cons blm neg) (pred res))
(raise-blame-error
blm #f
(list 'expected: (~s n) 'given: (~s res))))
res))
(unsafe-chaperone-procedure
v
(case-lambda
[(arg)
(call-with-values (λ () (v arg))
(case-lambda [(res)
(unless (with-contract-continuation-mark
(cons blm neg)
(pred res))
(raise-blame-error
blm #f
(list 'expected: (~s n) 'given: (~s res))))
res]
[results
(raise-blame-error
v results
(list 'expected "one value"
'given (~a (length results)
" values")))]))]
[args
(raise-blame-error
blm #f
(list 'expected: "one argument"
'given: (~a (length args) " arguments")))])))))))
(module+ test
(struct m (x))
(define val (m 1))
(define c0 (-> any/c real?))
(define c1 (unconstrained-domain-> real?))
(define c2 (simple-result-> real?))
(define f0 (contract c0 m-x 'pos 'neg))
(define f1 (contract c1 m-x 'pos 'neg))
(define f2 (contract c2 m-x 'pos 'neg))
(define f3 (contract c2 number->string 'pos 'neg))
(define N 1000000)
(collect-garbage)
'f0
(time (for/sum ([i (in-range N)])
(f0 val)))
(collect-garbage)
'f1
(time (for/sum ([i (in-range N)])
(f1 val)))
(collect-garbage)
'f2
(time (for/sum ([i (in-range N)])
(f2 val)))
'm-x
(time (for/sum ([i (in-range N)])
(m-x val))))

View File

@ -67,6 +67,7 @@
(namespace-require 'typed-racket/utils/any-wrap)
(namespace-require 'typed-racket/utils/evt-contract)
(namespace-require 'typed-racket/utils/opaque-object)
(namespace-require 'typed-racket/utils/simple-result-arrow)
(namespace-require '(submod typed-racket/private/type-contract predicates))
(namespace-require 'typed/racket/class)
(current-namespace)))

View File

@ -220,42 +220,48 @@
#:neg (promise/sc (box/sc set?/sc)))
(check-optimize
(function/sc (list (listof/sc any/sc))
(function/sc #t
(list (listof/sc any/sc))
(list)
(list)
(list)
#f
(list (listof/sc any/sc)))
#:pos
(function/sc (list list?/sc)
(function/sc #t
(list list?/sc)
(list)
(list)
(list)
#f
#f)
#:neg
(function/sc (list any/sc)
(function/sc #t
(list any/sc)
(list)
(list)
(list)
#f
(list list?/sc)))
(check-optimize
(function/sc (list (listof/sc any/sc))
(function/sc #t
(list (listof/sc any/sc))
(list)
(list)
(list)
#f
(list any/sc))
#:pos
(function/sc (list list?/sc)
(function/sc #t
(list list?/sc)
(list)
(list)
(list)
#f
#f)
#:neg
(function/sc (list any/sc)
(function/sc #t
(list any/sc)
(list)
(list)
(list)
@ -323,7 +329,8 @@
(list
(arr/sc empty #f (list set?/sc))
(arr/sc (list identifier?/sc) #f (list (listof/sc set?/sc)))))
#:pos (function/sc (list)
#:pos (function/sc #t
(list)
(list identifier?/sc)
(list)
(list)