racket/collects/unstable/inline.rkt
Vincent St-Amour b715a6fed5 Add define-inline.
Drop-in replacement for define that guarantees inlining.
2012-12-17 13:57:31 -05:00

220 lines
8.4 KiB
Racket

#lang racket/base
(require (for-syntax syntax/parse syntax/define
racket/syntax racket/base)
racket/stxparam)
(provide define-inline)
(begin-for-syntax
(define-splicing-syntax-class actual
(pattern (~seq (~optional kw:keyword) arg:expr)
#:with tmp (generate-temporary #'arg)
#:attr for-aux (list (attribute kw) (attribute tmp))))
(define-syntax-class formals
(pattern () ; done
#:with aux-args #'()
#:with ids #'()
#:with rest-arg #'()) ; rest-arg is () for no rest arg or (id)
(pattern rest:id
#:with aux-args #'rest
#:with ids #'()
#:with rest-arg #'(rest))
(pattern (x:id . rest:formals)
#:with aux-args #'(x . rest.aux-args)
#:with ids #'(x . rest.ids)
#:with rest-arg #'rest.rest-arg)
(pattern ([x:id default:expr] . rest:formals)
;; for aux-args, wrap defaults in syntax
#:with aux-args #'([x #'default] . rest.aux-args)
#:with ids #'(x . rest.ids)
#:with rest-arg #'rest.rest-arg)
(pattern (kw:keyword x:id . rest:formals)
#:with aux-args #'(kw x . rest.aux-args)
#:with ids #'(x . rest.ids)
#:with rest-arg #'rest.rest-arg)
(pattern (kw:keyword [x:id default:expr] . rest:formals)
#:with aux-args #'(kw [x #'default] . rest.aux-args)
#:with ids #'(x . rest.ids)
#:with rest-arg #'rest.rest-arg)))
(define-syntax (define-inline stx)
(syntax-parse stx
[(_ (name . formals) . body)
(define-values (name lam) (normalize-definition stx #'lambda #t #t))
(syntax-parse lam
[(_ args . body)
#`(define-inline-helper (#,name . args) . body)])]
[_
(raise-syntax-error
'define-inline "only supports definitions with function headers" stx)]))
(define-syntax (define-inline-helper stx)
(syntax-parse stx
[(_ (name:id . args:formals) . body)
(with-syntax* ([internal-name (format-id #'name "~a-internal" #'name)]
[inline-name (format-id #'name "~a-inline" #'name)]
[function-aux (format-id #'name "~a-aux" #'name)]
[(arg-id ...) #'args.ids]
[(rest-arg-id ...) #'args.rest-arg]
[(tmp-arg-id ...) (generate-temporaries #'(arg-id ...))]
[(tmp-rest-arg-id ...)
(generate-temporaries #'(rest-arg-id ...))]
[body*
#'(let-syntax ([name (make-rename-transformer
(quote-syntax internal-name))])
. body)])
#`(begin
;; define a function version that recursive calls fall back to, to
;; avoid infinite expansion
(define (internal-name . args) body*)
(define-syntax-parameter name
(syntax-id-rules ()
[(_ . rest) (inline-name . rest)]
[_ internal-name])) ; higher-order use
(define-syntax (inline-name stx*)
;; generate a compile-time function that takes care of linking
;; formals and actuals (so we don't have to handle keyword
;; arguments manually)
(define (function-aux . args.aux-args)
;; default values for optional arguments can refer to previous
;; arguments, which makes things tricky
(with-syntax* ([tmp-arg-id arg-id] ...
[tmp-rest-arg-id rest-arg-id] ...)
#'(let* ([arg-id tmp-arg-id] ...
[rest-arg-id
(list . tmp-rest-arg-id)] ...)
body*)))
(...
(syntax-parse stx*
[(_ arg*:actual ...)
;; let*-bind the actuals, to ensure that they're evaluated
;; only once, and in order
#`(syntax-parameterize
([name (make-rename-transformer #'internal-name)])
(let* ([arg*.tmp arg*.arg] ...)
#,(let* ([arg-entries (attribute arg*.for-aux)]
[keyword-entries (filter car arg-entries)]
[positional-entries
(filter (lambda (x) (not (car x)))
arg-entries)]
[sorted-keyword-entries
(sort keyword-entries
string<?
#:key (lambda (kw)
(keyword->string
(syntax-e kw))))])
(keyword-apply
function-aux
(map (lambda (x) (syntax-e (car x)))
sorted-keyword-entries)
(map cadr sorted-keyword-entries)
(map cadr positional-entries)))))])))))]))
(module+ test
(require rackunit)
;; Compares output to make sure that things are evaluated the right number of
;; times, and in the right order.
(define-syntax-rule (test/output expr res out)
(let ()
(define str (open-output-string))
(check-equal? (parameterize ([current-output-port str]) expr) res)
(check-equal? (get-output-string str) out)))
(define-inline (f x)
(+ x x))
(test/output (f (begin (display "arg1") 1)) 2 "arg1")
(define-inline (f2 x y)
(+ x y))
(test/output (f2 (begin (display "arg1") 1) (begin (display "arg2") 1))
2 "arg1arg2")
(define-inline (g #:x [x 0])
(f x))
(test/output (g #:x (begin (display "arg1") 1)) 2 "arg1")
(test/output (g) 0 "")
(define-inline (h #:x x)
(f x))
(test/output (h #:x (begin (display "arg1") 1)) 2 "arg1")
(define-inline (i [x 0])
(f x))
(test/output (i (begin (display "arg1") 1)) 2 "arg1")
(test/output (i) 0 "")
(define-inline (j x #:y [y 0])
(+ x y))
(test/output (j (begin (display "arg1") 1) #:y (begin (display "arg2") 2))
3 "arg1arg2")
(test/output (j #:y (begin (display "arg1") 2) (begin (display "arg2") 1))
3 "arg1arg2")
(test/output (j (begin (display "arg1") 1)) 1 "arg1")
(define-inline (k x [y x])
(+ x y))
(test/output (k (begin (display "arg1") 1) (begin (display "arg2") 2))
3 "arg1arg2")
(test/output (k (begin (display "arg1") 1)) 2 "arg1")
(define-inline (l . x)
(+ (apply + x) (apply + x)))
(test/output (l (begin (display "arg1") 1) (begin (display "arg2") 2))
6 "arg1arg2")
(define-inline (l2 y . x)
(+ y y (apply + x) (apply + x)))
(test/output (l2 (begin (display "arg0") 3)
(begin (display "arg1") 1)
(begin (display "arg2") 2))
12 "arg0arg1arg2")
(define-inline (l3 y [z 0] . x)
(+ y y z z z (apply + x) (apply + x)))
(test/output (l3 (begin (display "arg0") 3)
(begin (display "arg1") 1)
(begin (display "arg2") 2))
13 "arg0arg1arg2")
(test/output (l3 (begin (display "arg0") 3))
6 "arg0")
(define-inline (l4 #:x [x 0] . y)
(+ x x x (apply + y) (apply + y)))
(test/output (l4 #:x (begin (display "arg1") 1))
3 "arg1")
(test/output (l4 #:x (begin (display "arg1") 1)
(begin (display "arg2") 2)
(begin (display "arg3") 3))
13 "arg1arg2arg3")
(test/output (l4 (begin (display "arg2") 2)
(begin (display "arg3") 3))
10 "arg2arg3")
;; test for function fallback (recursion)
(define-inline (sum . l)
(if (null? l) 0 (+ (car l) (apply sum (cdr l)))))
(check-equal? (sum 1 2 3 4) 10)
;; higher-order use
(define-inline (add2 x)
(+ x 2))
(check-equal? (add2 3) 5)
(check-equal? (map add2 '(1 2 3)) '(3 4 5))
;; currying syntax
(define-inline ((adder x) y) (+ x y))
(test/output (let ([add2 (adder (begin (display "arg1") 2))])
(+ (add2 (begin (display "arg2") 1))
(add2 (begin (display "arg2") 2))))
7 "arg1arg2arg2")
(define-inline (even? x) (if (zero? x) #t (odd? (sub1 x))))
(define-inline (odd? x) (if (zero? x) #f (even? (sub1 x))))
(check-equal? (odd? 2) #f)
)