trivial/math.rkt
2016-03-03 15:18:08 -05:00

101 lines
3.1 KiB
Racket

#lang typed/racket/base
;; Constant-folding math operators.
;; Where possible, they simplify their arguments.
(provide
+: -: *: /:
;; Same signature as the racket/base operators,
;; but try to simplify arguments during expansion.
expt:
;; --
(for-syntax
nat/expand
int/expand
number/expand)
)
(require (for-syntax
typed/racket/base
(only-in racket/format ~a)
(only-in racket/syntax format-id)
syntax/id-table
syntax/parse
trivial/private/common
))
;; =============================================================================
;; Syntax classes built from value predicates with
;; `define-syntax-class/predicate` (presumably from trivial/private/common),
;; exported for-syntax by this module. `expt:` relies on the `expanded`
;; attribute of `number/expand`.
;; NOTE(review): presumably each class matches syntax whose expansion is a
;; literal satisfying the predicate -- confirm in trivial/private/common.
(begin-for-syntax
  (define-syntax-class/predicate number/expand number?)
  (define-syntax-class/predicate int/expand integer?)
  (define-syntax-class/predicate nat/expand exact-nonnegative-integer?))
;; (make-numeric-operator f) defines `f:`, a macro version of the operator `f`.
;; The new name is `f` with ":" appended, e.g. `+` => `+:`.
;; - `(f:)` expands to `(f)`, so `f`'s own zero-argument behavior is kept
;;   (e.g. `(+:)` is `(+)`, not a bogus literal).
;; - `(f: e ...)` expands every `e` with `expand-expr` and simplifies the
;;   results with `reduce/op`. It then produces either a call `(f e+ ...)` on
;;   the simplified arguments or, when everything folded, a numeric literal.
;; - A bare `f:` (e.g. passed as a function value) expands to `f`.
(define-syntax make-numeric-operator
  (syntax-parser
   [(_ f:id)
    #:with f: (format-id #'f "~a:" (syntax-e #'f))
    #'(define-syntax (f: stx)
        (syntax-parse stx
         [(g)
          ;; No arguments: nothing to fold, defer to `f` itself.
          (syntax/loc #'g (f))]
         [(g e* (... ...))
          #:with e+* (for/list ([e (in-list (syntax->list #'(e* (... ...))))])
                       (expand-expr e))
          (let ([e++ (reduce/op f (syntax->list #'e+*) #:src stx)])
            ;; `reduce/op` returns a list of arguments to splice into a call
            ;; of `f`, or a single number when all arguments were folded.
            (if (list? e++)
              (quasisyntax/loc #'g (f #,@e++))
              (quasisyntax/loc #'g #,e++)))]
         [g:id
          (syntax/loc #'g f)]))]))
(make-numeric-operator +)
(make-numeric-operator -)
(make-numeric-operator *)
(make-numeric-operator /)
;; `expt:` behaves like `expt`, with three cases:
;; - Both arguments are numeric constants (as recognized by the
;;   `number/expand` syntax class): the power is computed during expansion and
;;   inserted as a quoted literal.
;; - Any other call form expands to a plain `expt` call.
;; - A bare `expt:` expands to `expt`.
;; When `expt` rejects the constant arguments with a divide-by-zero error
;; (e.g. `(expt 0 -1)`), a syntax error is raised on the whole form instead of
;; letting the raw exception escape from the transformer. This matches the
;; "division by zero" syntax error used by `/:`.
(define-syntax (expt: stx)
  (syntax-parse stx
   [(_ n1:number/expand n2:number/expand)
    #:with n (with-handlers ([exn:fail:contract:divide-by-zero?
                              (lambda (exn)
                                (raise-syntax-error 'expt "division by zero" stx))])
               (expt (syntax-e #'n1.expanded) (syntax-e #'n2.expanded)))
    (syntax/loc stx 'n)]
   [_:id
    (syntax/loc stx expt)]
   [(_ e* ...)
    (syntax/loc stx (expt e* ...))]))
;; -----------------------------------------------------------------------------
;; Abort expansion with a syntax error reported under the name `/` and
;; blaming the form `stx`. Called when constant folding would divide by zero.
(define-for-syntax division-by-zero
  (lambda (stx)
    (raise-syntax-error '/ "division by zero" stx)))
;; Simplify the arguments of a call to `op` by folding adjacent numeric
;; constants during expansion.
;; - `op` is the operator procedure itself (`+`, `-`, `*`, or `/`).
;; - `expr*` is the list of (already expanded) argument syntax objects.
;; - `stx` is the whole call, used as the source of any syntax error.
;; Return either:
;; - A numeric value, when all arguments folded into a single constant.
;; - A list of syntax objects (and folded numbers), to be spliced back in the
;;   source code as the new arguments of `op`. This is '() when `expr*` is
;;   empty, so the caller emits `(op)` and keeps `op`'s own arity behavior.
;; Constants are recognized with `quoted-stx-value?`.
;;
;; `+` and `*` are associative, so any run of adjacent constants can be folded
;; with `op` itself. `-` and `/` are not: every argument after the first is
;; subtracted from (resp. divides) the running result. So only the run of
;; constants at the very start of the call is folded with `op`. Any later run
;; is folded with `+` (for `-`) or `*` (for `/`), e.g.
;;   (- x 1 2)      => (- x 3)
;;   (/ 8 2 x 2 2)  => (/ 4 x 4)
;; A call with a single constant argument is folded as a unary call, so
;; (- 5) => -5 and (/ 2) => 1/2.
;; Raises a syntax error (via `division-by-zero`) when `op` is `/` and an exact
;; zero would be folded in as a divisor. Inexact zeros are not rejected,
;; because dividing by 0.0 does not raise at run time.
(define-for-syntax (reduce/op op expr* #:src stx)
  (let* ([div? (eq? / op)]
         ;; operator that merges two constants inside a non-leading run
         [tail-op (cond [(eq? - op) +] [div? *] [else op])]
         [zero-divisor? (lambda (v) (and div? (eqv? v 0)))])
    (cond
      [(null? expr*)
       '()]
      [(null? (cdr expr*))
       ;; unary call: fold `(op v)` when the sole argument is a constant
       (let ([v (quoted-stx-value? (car expr*))])
         (cond
           [(not (number? v)) (list (car expr*))]
           [(zero-divisor? v) (division-by-zero stx)]
           [else (op v)]))]
      [else
       (let loop ([prev #f]   ;; (U #f Number), constant run being folded
                  [head? #t]  ;; #t until the first irreducible argument
                  [acc '()]   ;; (Listof (U Syntax Number)), output, reversed
                  [e* expr*]) ;; (Listof Syntax), arguments to process
         (if (null? e*)
           ;; finished: a number (prev) or the simplified argument list
           (if (null? acc)
             prev
             (reverse (if prev (cons prev acc) acc)))
           (let ([e (car e*)]
                 [v (quoted-stx-value? (car e*))])
             (cond
               [(not (number? v))
                ;; irreducible: flush the pending constant, then keep `e`
                (loop #f #f (cons e (if prev (cons prev acc) acc)) (cdr e*))]
               [(not prev)
                ;; start a new run of constants
                (loop v head? acc (cdr e*))]
               [(zero-divisor? v)
                (division-by-zero stx)]
               [else
                ;; leading run folds with `op`; later runs with `tail-op`
                (loop ((if head? op tail-op) prev v) head? acc (cdr e*))]))))])))