[math] cleaner

This commit is contained in:
ben 2016-03-09 00:55:31 -05:00
parent 59f5b165b1
commit 62109f33db
5 changed files with 225 additions and 109 deletions

View File

@ -10,91 +10,9 @@
expt: expt:
;; -- define-num: let-num:
(for-syntax
nat/expand
int/expand
number/expand)
) )
(require (for-syntax (require
typed/racket/base (only-in trivial/private/math
(only-in racket/format ~a) +: -: *: /: expt: let-num: define-num:))
(only-in racket/syntax format-id)
syntax/id-table
syntax/parse
trivial/private/common
))
;; =============================================================================
(begin-for-syntax
(define-syntax-class/predicate nat/expand exact-nonnegative-integer?)
(define-syntax-class/predicate int/expand integer?)
(define-syntax-class/predicate number/expand number?)
)
(define-syntax make-numeric-operator
(syntax-parser
[(_ f:id)
#:with f: (format-id #'f "~a:" (syntax-e #'f))
#'(define-syntax (f: stx)
(syntax-parse stx
[(g e* (... ...))
#:with e+* (for/list ([e (in-list (syntax->list #'(e* (... ...))))])
(expand-expr e))
(let ([e++ (reduce/op f (syntax->list #'e+*) #:src stx)])
(if (list? e++)
(quasisyntax/loc #'g (f #,@e++))
(quasisyntax/loc #'g #,e++)))]
[g:id
(syntax/loc #'g f)]
[(g e* (... ...))
(syntax/loc #'g (f e* (... ...)))]))]))
(make-numeric-operator +)
(make-numeric-operator -)
(make-numeric-operator *)
(make-numeric-operator /)
(define-syntax (expt: stx)
(syntax-parse stx
[(_ n1:number/expand n2:number/expand)
#:with n (expt (syntax-e #'n1.expanded) (syntax-e #'n2.expanded))
(syntax/loc stx 'n)]
[_:id
(syntax/loc stx expt)]
[(_ e* ...)
(syntax/loc stx (expt e* ...))]))
;; -----------------------------------------------------------------------------
(define-for-syntax (division-by-zero stx)
(raise-syntax-error '/ "division by zero" stx))
;; Simplify a list of expressions using an associative binary operator.
;; Return either:
;; - A numeric value
;; - A list of syntax objects, to be spliced back in the source code
(define-for-syntax (reduce/op op expr* #:src stx)
(let loop ([prev #f] ;; (U #f Number), candidate for reduction
[acc '()] ;; (Listof Syntax), irreducible arguments
[e* expr*]) ;; (Listof Syntax), arguments to process
(if (null? e*)
;; then: finished, return a number (prev) or list of expressions (acc)
(if (null? acc)
prev
(reverse (if prev (cons prev acc) acc)))
;; else: pop the next argument from e*, fold if it's a constant
(let ([v (quoted-stx-value? (car e*))])
(if (number? v)
;; then: reduce the number
(if prev
;; Watch for division-by-zero
(if (and (zero? v) (eq? / op))
(division-by-zero stx)
(loop (op prev v) acc (cdr e*)))
(loop v acc (cdr e*)))
;; else: save value in acc
(let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))])
(loop #f acc+ (cdr e*))))) )))

View File

@ -11,7 +11,10 @@
;; Otherwise, return #f. ;; Otherwise, return #f.
define-syntax-class/predicate define-syntax-class/predicate
;; (stx-> Identifier (-> Any Boolean) SyntaxClassDef) ;; TODO
lift-predicate
;; TODO
make-value-property make-value-property
;; TODO ;; TODO
@ -25,7 +28,9 @@
syntax/parse syntax/parse
syntax/id-table syntax/id-table
(for-syntax (only-in typed/racket/base let let-syntax #%app)) (for-syntax (only-in typed/racket/base let let-syntax #%app))
(for-template (only-in typed/racket/base quote))) (for-template
(prefix-in r: (only-in racket/base quote))
(prefix-in tr: (only-in typed/racket/base quote))))
;; ============================================================================= ;; =============================================================================
@ -34,10 +39,10 @@
#:attributes (evidence expanded) #:attributes (evidence expanded)
(pattern e (pattern e
#:with e+ (expand-expr #'e) #:with e+ (expand-expr #'e)
#:with p+ (p? #'e+) #:with p+ (p? (syntax/loc #'e e+))
#:when (if (syntax-e #'p+) #t (begin (printf "ERROR we failed iwth ~a\n" (syntax->datum #'e+)) #f)) ;; TODO remove this #:when (syntax-e #'p+)
#:attr evidence #'p+ #:attr evidence (syntax/loc #'e p+)
#:attr expanded #'e+))) #:attr expanded (syntax/loc #'e e+))))
(define (expand-expr stx) (define (expand-expr stx)
(local-expand stx 'expression '())) (local-expand stx 'expression '()))
@ -45,11 +50,20 @@
(define (quoted-stx-value? stx) (define (quoted-stx-value? stx)
(and (and
(syntax? stx) (syntax? stx)
(syntax-case stx (quote) (syntax-parse stx #:literals (r:quote tr:quote) #:datum-literals (quote)
[(quote v) [((~or r:quote tr:quote quote) v)
(syntax-e #'v)] (syntax-e #'v)]
[else #f]))) [else #f])))
(define (lift-predicate p?)
(lambda (stx)
(cond
[(p? stx) stx]
[(p? (syntax-e stx)) (syntax-e stx)]
[(p? (quoted-stx-value? stx))
stx]
[else #f])))
;; In: ;; In:
;; - name : Symbol, like format-spec or vector-length or db-schema ;; - name : Symbol, like format-spec or vector-length or db-schema
;; - parser : (Syntax -> Value) ;; - parser : (Syntax -> Value)
@ -77,28 +91,28 @@
(lambda (stx) (lambda (stx)
(syntax-parse stx (syntax-parse stx
[(_ name:id v) [(_ name:id v)
#:with v+ (expand-expr #'v) #:with v+ (expand-expr (syntax/loc stx v))
#:when (syntax-e #'v+) #:when (syntax-e (syntax/loc stx v+))
#:with m (f-parse #'v+) #:with m (f-parse (syntax/loc stx v+))
#:when (syntax-e #'m) #:when (syntax-e (syntax/loc stx m))
(free-id-table-set! #'name (syntax-e #'m)) #:with define-stx (format-id stx "define")
(free-id-table-set! tbl #'name (syntax-e #'m))
(syntax/loc stx (syntax/loc stx
(define name v+))] (define-stx name v+))]
[_ #f]))) [_ #f])))
(define f-let (define f-let
(lambda (stx) (lambda (stx)
(syntax-parse stx (syntax-parse stx
[(_ ([name*:id v*] ...) e* ...) [(_ ([name*:id v*] ...) e* ...)
#:with (v+* ...) (map expand-expr (syntax-e #'(v* ...))) #:with (v+* ...) (map expand-expr (syntax-e (syntax/loc stx (v* ...))))
#:when (andmap syntax-e (syntax-e #'(v+* ...))) #:with (m* ...) (map f-parse (syntax-e (syntax/loc stx (v+* ...))))
#:with (m* ...) (map f-parse (syntax-e #'(v+* ...))) #:when (andmap syntax-e (syntax-e (syntax/loc stx (m* ...))))
#:when (andmap syntax-e (syntax-e #'(m* ...)))
#:with let-stx (format-id stx "let") #:with let-stx (format-id stx "let")
#:with let-syntax-stx (format-id stx "let-syntax") #:with let-syntax-stx (format-id stx "let-syntax")
(quasisyntax/loc stx (quasisyntax/loc stx
(let-stx ([name* v+*] ...) (let-stx ([name* v+*] ...)
(let-syntax-stx ([name* (make-rename-transformer (let-syntax-stx ([name* (make-rename-transformer
(syntax-property #'name* '#,key 'm* ...))] ...) (syntax-property #'name* '#,key 'm*))] ...)
e* ...)))] e* ...)))]
[_ #f]))) [_ #f])))
(values (values
@ -107,8 +121,13 @@
f-define f-define
f-let)) f-let))
(define ((make-alias id-stx parser) stx) (define ((make-alias id-sym parser) stx)
(or (parser stx) (or (parser stx)
(syntax-parse stx (syntax-parse stx
[_:id (quasisyntax/loc stx #,id-stx)] [_:id
[(_ e* ...) (quasisyntax/loc stx (#,id-stx e* ...))]))) #:with id-stx (format-id stx "~a" id-sym)
(syntax/loc stx id-stx)]
[(_ e* ...)
#:with id-stx (format-id stx "~a" id-sym)
#:with app-stx (format-id stx "#%app")
(syntax/loc stx (app-stx id-stx e* ...))])))

119
private/math.rkt Normal file
View File

@ -0,0 +1,119 @@
#lang typed/racket/base
;; Constant-folding math operators.
;; Where possible, they simplify their arguments.
;; TODO the or- stuff is not so pretty, but it's working anyway
(provide
+: -: *: /:
;; Same signature as the racket/base operators,
;; but try to simplify arguments during expansion.
expt:
define-num: let-num:
;; --
(for-syntax
nat/expand
int/expand
num/expand)
)
(require (for-syntax
typed/racket/base
(only-in racket/format ~a)
(only-in racket/syntax format-id)
syntax/id-table
syntax/parse
trivial/private/common
))
;; =============================================================================
(begin-for-syntax
(define (division-by-zero stx)
(raise-syntax-error '/ "division by zero" stx))
;; Simplify a list of expressions using an associative binary operator.
;; Return either:
;; - A numeric value
;; - A list of syntax objects, to be spliced back in the source code
(define (reduce/op op stx)
(define expr* (syntax-e stx))
(cond
[(list? expr*)
(let loop ([prev #f] ;; (U #f Number), candidate for reduction
[acc '()] ;; (Listof Syntax), irreducible arguments
[e* expr*]) ;; (Listof Syntax), arguments to process
(if (null? e*)
;; then: finished, return a number (prev) or list of expressions (acc)
(if (null? acc)
prev
(reverse (if prev (cons prev acc) acc)))
;; else: pop the next argument from e*, fold if it's a constant
(syntax-parse (car e*)
[v:num/expand
(define v (or (quoted-stx-value? #'v.expanded)
(quoted-stx-value? #'v.evidence)))
;; then: reduce the number
(if prev
;; Watch for division-by-zero
(if (and (zero? v) (eq? / op))
(division-by-zero stx)
(loop (op prev v) acc (cdr e*)))
(loop v acc (cdr e*)))]
[v
;; else: save value in acc
(let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))])
(loop #f acc+ (cdr e*)))])))]
[else #f]))
(define-values (nat-key nat? nat-define nat-let)
(make-value-property 'number:natural (lift-predicate exact-nonnegative-integer?)))
(define-syntax-class/predicate nat/expand nat?)
(define-values (int-key int? int-define int-let)
(make-value-property 'number:integer (lift-predicate integer?)))
(define-syntax-class/predicate int/expand int?)
(define-values (num-key num? num-define num-let)
(make-value-property 'number:number (lift-predicate number?)))
(define-syntax-class/predicate num/expand num?)
)
;; -----------------------------------------------------------------------------
(define-syntax define-num: (make-alias 'define num-define))
(define-syntax let-num: (make-alias 'let num-let))
(define-syntax make-numeric-operator
(syntax-parser
[(_ f:id)
#:with f: (format-id #'f "~a:" (syntax-e #'f))
#'(define-syntax f: (make-alias #'f
(lambda (stx) (syntax-parse stx
[(_ e* (... ...))
#:with f-id (format-id stx "~a" 'f)
(let ([e+ (reduce/op f #'(e* (... ...)))])
(if (list? e+)
(quasisyntax/loc stx (#%app f-id #,@e+))
(quasisyntax/loc stx #,e+)))]
[_ #f]))))]))
(make-numeric-operator +)
(make-numeric-operator -)
(make-numeric-operator *)
(make-numeric-operator /)
(define-syntax expt: (make-alias 'expt
(lambda (stx) (syntax-parse stx
[(_ n1:num/expand n2:num/expand)
(let ([v1 (or (quoted-stx-value? #'n1.expanded)
(quoted-stx-value? #'n1.evidence))]
[v2 (or (quoted-stx-value? #'n2.expanded)
(quoted-stx-value? #'n2.evidence))])
(and v1 v2 ;; Should never fail
(quasisyntax/loc stx #,(expt v1 v2))))]
[_ #f]))))

View File

@ -14,6 +14,7 @@
(ann (let ([n 5]) (*: n 1/5 1)) One) (ann (let ([n 5]) (*: n 1/5 1)) One)
(ann (let ([n 4]) (/: n n)) One) (ann (let ([n 4]) (/: n n)) One)
(ann (let ([n 2]) (expt: 3 (-: n n))) One) (ann (let ([n 2]) (expt: 3 (-: n n))) One)
(ann (expt: 3 2) Zero)
;; -- lambda => back to racket/base ;; -- lambda => back to racket/base
(ann ((lambda ([f : (-> Natural Natural Natural)]) (f 0 0)) +:) Zero) (ann ((lambda ([f : (-> Natural Natural Natural)]) (f 0 0)) +:) Zero)
(ann ((lambda ([f : (-> Natural Natural Integer)]) (f 0 0)) -:) Zero) (ann ((lambda ([f : (-> Natural Natural Integer)]) (f 0 0)) -:) Zero)

View File

@ -19,6 +19,17 @@
(ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) +:) Integer) (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) +:) Integer)
0) 0)
(check-equal?
(let-num: ([n -4] [m 5])
(ann (+: m n -1) Zero))
0)
(check-equal?
(let ()
(define-num: n 6)
(define-num: m -8)
(ann (+: n 2 m) Zero))
0)
;; -- -: ;; -- -:
(check-equal? (ann (-: 0 0) Zero) 0) (check-equal? (ann (-: 0 0) Zero) 0)
@ -31,6 +42,17 @@
(ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) -:) Integer) (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) -:) Integer)
0) 0)
(check-equal?
(let-num: ([n 4] [m 5])
(ann (-: m n 1) Zero))
0)
(check-equal?
(let ()
(define-num: n 6)
(define-num: m -8)
(ann (-: n m 14) Zero))
0)
;; -- *: ;; -- *:
(check-equal? (ann (*: 0 1315) Zero) 0) (check-equal? (ann (*: 0 1315) Zero) 0)
@ -43,6 +65,18 @@
(ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) *:) Integer) (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) *:) Integer)
0) 0)
(check-equal?
(let-num: ([n 4] [m 5])
(ann (-: (*: m n) 20) Zero))
0)
(check-equal?
(let ()
(define-num: n 2)
(define-num: m -8)
(ann (-: (*: n -2 m) 32) Zero))
0)
;; -- /: ;; -- /:
(check-equal? (ann (/: 0 1) Zero) 0) (check-equal? (ann (/: 0 1) Zero) 0)
@ -54,6 +88,18 @@
(ann ((lambda ([f : (-> Integer Integer Exact-Rational)]) (f 1 1)) /:) Real) (ann ((lambda ([f : (-> Integer Integer Exact-Rational)]) (f 1 1)) /:) Real)
1) 1)
(check-equal?
(let-num: ([n 4] [m 12])
(ann (-: (/: m n) 3) Zero))
0)
(check-equal?
(let ()
(define-num: n 2)
(define-num: m -8)
(ann (+: (/: m n) 4) Zero))
0)
;; -- Nested ;; -- Nested
(check-equal? (check-equal?
@ -90,7 +136,20 @@
(ann (expt: (+: 5 -5) 78) Zero) (ann (expt: (+: 5 -5) 78) Zero)
0) 0)
(check-equal? (check-equal?
(ann (expt: (*: 2 2) (expt: 2 2)) Index) (ann (-: (expt: (*: 2 2) (expt: 2 2)) 256) Zero)
0)
(check-equal?
(ann (expt: (* 2 2) (expt: 2 2)) Natural) ;; Not an index
256) 256)
(check-equal?
(let-num: ([n1 5] [n2 4])
(ann (-: (expt: n1 n2) 625) Zero))
0)
(check-equal?
(let ()
(define-num: n1 8)
(define-num: n2 2)
(ann (-: (expt: n1 n2) 64) Zero))
0)
) )