From 62109f33db483fad110922c01b05763b2a9afe2b Mon Sep 17 00:00:00 2001 From: ben Date: Wed, 9 Mar 2016 00:55:31 -0500 Subject: [PATCH] [math] cleaner --- math.rkt | 90 ++-------------------------------- private/common.rkt | 63 +++++++++++++++--------- private/math.rkt | 119 +++++++++++++++++++++++++++++++++++++++++++++ test/math-fail.rkt | 1 + test/math-pass.rkt | 61 ++++++++++++++++++++++- 5 files changed, 225 insertions(+), 109 deletions(-) create mode 100644 private/math.rkt diff --git a/math.rkt b/math.rkt index 633fa54..c4d10a5 100644 --- a/math.rkt +++ b/math.rkt @@ -10,91 +10,9 @@ expt: - ;; -- - (for-syntax - nat/expand - int/expand - number/expand) + define-num: let-num: ) -(require (for-syntax - typed/racket/base - (only-in racket/format ~a) - (only-in racket/syntax format-id) - syntax/id-table - syntax/parse - trivial/private/common -)) - -;; ============================================================================= - -(begin-for-syntax - (define-syntax-class/predicate nat/expand exact-nonnegative-integer?) - (define-syntax-class/predicate int/expand integer?) - (define-syntax-class/predicate number/expand number?) -) - -(define-syntax make-numeric-operator - (syntax-parser - [(_ f:id) - #:with f: (format-id #'f "~a:" (syntax-e #'f)) - #'(define-syntax (f: stx) - (syntax-parse stx - [(g e* (... ...)) - #:with e+* (for/list ([e (in-list (syntax->list #'(e* (... ...))))]) - (expand-expr e)) - (let ([e++ (reduce/op f (syntax->list #'e+*) #:src stx)]) - (if (list? e++) - (quasisyntax/loc #'g (f #,@e++)) - (quasisyntax/loc #'g #,e++)))] - [g:id - (syntax/loc #'g f)] - [(g e* (... ...)) - (syntax/loc #'g (f e* (... ...)))]))])) - -(make-numeric-operator +) -(make-numeric-operator -) -(make-numeric-operator *) -(make-numeric-operator /) - -(define-syntax (expt: stx) - (syntax-parse stx - [(_ n1:number/expand n2:number/expand) - #:with n (expt (syntax-e #'n1.expanded) (syntax-e #'n2.expanded)) - (syntax/loc stx 'n)] - [_:id - (syntax/loc stx expt)] - [(_ e* ...) - (syntax/loc stx (expt e* ...))])) - -;; ----------------------------------------------------------------------------- - -(define-for-syntax (division-by-zero stx) - (raise-syntax-error '/ "division by zero" stx)) - -;; Simplify a list of expressions using an associative binary operator. -;; Return either: -;; - A numeric value -;; - A list of syntax objects, to be spliced back in the source code -(define-for-syntax (reduce/op op expr* #:src stx) - (let loop ([prev #f] ;; (U #f Number), candidate for reduction - [acc '()] ;; (Listof Syntax), irreducible arguments - [e* expr*]) ;; (Listof Syntax), arguments to process - (if (null? e*) - ;; then: finished, return a number (prev) or list of expressions (acc) - (if (null? acc) - prev - (reverse (if prev (cons prev acc) acc))) - ;; else: pop the next argument from e*, fold if it's a constant - (let ([v (quoted-stx-value? (car e*))]) - (if (number? v) - ;; then: reduce the number - (if prev - ;; Watch for division-by-zero - (if (and (zero? v) (eq? / op)) - (division-by-zero stx) - (loop (op prev v) acc (cdr e*))) - (loop v acc (cdr e*))) - ;; else: save value in acc - (let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))]) - (loop #f acc+ (cdr e*))))) ))) +(require + (only-in trivial/private/math + +: -: *: /: expt: let-num: define-num:)) diff --git a/private/common.rkt b/private/common.rkt index 83bebb2..a59c104 100644 --- a/private/common.rkt +++ b/private/common.rkt @@ -11,7 +11,10 @@ ;; Otherwise, return #f. define-syntax-class/predicate - ;; (stx-> Identifier (-> Any Boolean) SyntaxClassDef) + ;; TODO + + lift-predicate + ;; TODO make-value-property ;; TODO @@ -25,7 +28,9 @@ syntax/parse syntax/id-table (for-syntax (only-in typed/racket/base let let-syntax #%app)) - (for-template (only-in typed/racket/base quote))) + (for-template + (prefix-in r: (only-in racket/base quote)) + (prefix-in tr: (only-in typed/racket/base quote)))) ;; ============================================================================= @@ -34,10 +39,10 @@ #:attributes (evidence expanded) (pattern e #:with e+ (expand-expr #'e) - #:with p+ (p? #'e+) - #:when (if (syntax-e #'p+) #t (begin (printf "ERROR we failed iwth ~a\n" (syntax->datum #'e+)) #f)) ;; TODO remove this - #:attr evidence #'p+ - #:attr expanded #'e+))) + #:with p+ (p? (syntax/loc #'e e+)) + #:when (syntax-e #'p+) + #:attr evidence (syntax/loc #'e p+) + #:attr expanded (syntax/loc #'e e+)))) (define (expand-expr stx) (local-expand stx 'expression '())) @@ -45,11 +50,20 @@ (define (quoted-stx-value? stx) (and (syntax? stx) - (syntax-case stx (quote) - [(quote v) + (syntax-parse stx #:literals (r:quote tr:quote) #:datum-literals (quote) + [((~or r:quote tr:quote quote) v) (syntax-e #'v)] [else #f]))) +(define (lift-predicate p?) + (lambda (stx) + (cond + [(p? stx) stx] + [(p? (syntax-e stx)) (syntax-e stx)] + [(p? (quoted-stx-value? stx)) + stx] + [else #f]))) + ;; In: ;; - name : Symbol, like format-spec or vector-length or db-schema ;; - parser : (Syntax -> Value) @@ -77,28 +91,28 @@ (lambda (stx) (syntax-parse stx [(_ name:id v) - #:with v+ (expand-expr #'v) - #:when (syntax-e #'v+) - #:with m (f-parse #'v+) - #:when (syntax-e #'m) - (free-id-table-set! #'name (syntax-e #'m)) + #:with v+ (expand-expr (syntax/loc stx v)) + #:when (syntax-e (syntax/loc stx v+)) + #:with m (f-parse (syntax/loc stx v+)) + #:when (syntax-e (syntax/loc stx m)) + #:with define-stx (format-id stx "define") + (free-id-table-set! tbl #'name (syntax-e #'m)) (syntax/loc stx - (define name v+))] + (define-stx name v+))] [_ #f]))) (define f-let (lambda (stx) (syntax-parse stx [(_ ([name*:id v*] ...) e* ...) - #:with (v+* ...) (map expand-expr (syntax-e #'(v* ...))) - #:when (andmap syntax-e (syntax-e #'(v+* ...))) - #:with (m* ...) (map f-parse (syntax-e #'(v+* ...))) - #:when (andmap syntax-e (syntax-e #'(m* ...))) + #:with (v+* ...) (map expand-expr (syntax-e (syntax/loc stx (v* ...)))) + #:with (m* ...) (map f-parse (syntax-e (syntax/loc stx (v+* ...)))) + #:when (andmap syntax-e (syntax-e (syntax/loc stx (m* ...)))) #:with let-stx (format-id stx "let") #:with let-syntax-stx (format-id stx "let-syntax") (quasisyntax/loc stx (let-stx ([name* v+*] ...) (let-syntax-stx ([name* (make-rename-transformer - (syntax-property #'name* '#,key 'm* ...))] ...) + (syntax-property #'name* '#,key 'm*))] ...) e* ...)))] [_ #f]))) (values @@ -107,8 +121,13 @@ f-define f-let)) -(define ((make-alias id-stx parser) stx) +(define ((make-alias id-sym parser) stx) (or (parser stx) (syntax-parse stx - [_:id (quasisyntax/loc stx #,id-stx)] - [(_ e* ...) (quasisyntax/loc stx (#,id-stx e* ...))]))) + [_:id + #:with id-stx (format-id stx "~a" id-sym) + (syntax/loc stx id-stx)] + [(_ e* ...) + #:with id-stx (format-id stx "~a" id-sym) + #:with app-stx (format-id stx "#%app") + (syntax/loc stx (app-stx id-stx e* ...))]))) diff --git a/private/math.rkt b/private/math.rkt new file mode 100644 index 0000000..b76e10c --- /dev/null +++ b/private/math.rkt @@ -0,0 +1,119 @@ +#lang typed/racket/base + +;; Constant-folding math operators. +;; Where possible, they simplify their arguments. + +;; TODO the or- stuff is not so pretty, but it's working anyway + +(provide + +: -: *: /: + ;; Same signature as the racket/base operators, + ;; but try to simplify arguments during expansion. + + expt: + + define-num: let-num: + + ;; -- + (for-syntax + nat/expand + int/expand + num/expand) +) + +(require (for-syntax + typed/racket/base + (only-in racket/format ~a) + (only-in racket/syntax format-id) + syntax/id-table + syntax/parse + trivial/private/common +)) + +;; ============================================================================= + +(begin-for-syntax + (define (division-by-zero stx) + (raise-syntax-error '/ "division by zero" stx)) + + ;; Simplify a list of expressions using an associative binary operator. + ;; Return either: + ;; - A numeric value + ;; - A list of syntax objects, to be spliced back in the source code + (define (reduce/op op stx) + (define expr* (syntax-e stx)) + (cond + [(list? expr*) + (let loop ([prev #f] ;; (U #f Number), candidate for reduction + [acc '()] ;; (Listof Syntax), irreducible arguments + [e* expr*]) ;; (Listof Syntax), arguments to process + (if (null? e*) + ;; then: finished, return a number (prev) or list of expressions (acc) + (if (null? acc) + prev + (reverse (if prev (cons prev acc) acc))) + ;; else: pop the next argument from e*, fold if it's a constant + (syntax-parse (car e*) + [v:num/expand + (define v (or (quoted-stx-value? #'v.expanded) + (quoted-stx-value? #'v.evidence))) + ;; then: reduce the number + (if prev + ;; Watch for division-by-zero + (if (and (zero? v) (eq? / op)) + (division-by-zero stx) + (loop (op prev v) acc (cdr e*))) + (loop v acc (cdr e*)))] + [v + ;; else: save value in acc + (let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))]) + (loop #f acc+ (cdr e*)))])))] + [else #f])) + + (define-values (nat-key nat? nat-define nat-let) + (make-value-property 'number:natural (lift-predicate exact-nonnegative-integer?))) + (define-syntax-class/predicate nat/expand nat?) + + (define-values (int-key int? int-define int-let) + (make-value-property 'number:integer (lift-predicate integer?))) + (define-syntax-class/predicate int/expand int?) + + (define-values (num-key num? num-define num-let) + (make-value-property 'number:number (lift-predicate number?))) + (define-syntax-class/predicate num/expand num?) +) + +;; ----------------------------------------------------------------------------- + +(define-syntax define-num: (make-alias 'define num-define)) +(define-syntax let-num: (make-alias 'let num-let)) + +(define-syntax make-numeric-operator + (syntax-parser + [(_ f:id) + #:with f: (format-id #'f "~a:" (syntax-e #'f)) + #'(define-syntax f: (make-alias #'f + (lambda (stx) (syntax-parse stx + [(_ e* (... ...)) + #:with f-id (format-id stx "~a" 'f) + (let ([e+ (reduce/op f #'(e* (... ...)))]) + (if (list? e+) + (quasisyntax/loc stx (#%app f-id #,@e+)) + (quasisyntax/loc stx #,e+)))] + [_ #f]))))])) + +(make-numeric-operator +) +(make-numeric-operator -) +(make-numeric-operator *) +(make-numeric-operator /) + +(define-syntax expt: (make-alias 'expt + (lambda (stx) (syntax-parse stx + [(_ n1:num/expand n2:num/expand) + (let ([v1 (or (quoted-stx-value? #'n1.expanded) + (quoted-stx-value? #'n1.evidence))] + [v2 (or (quoted-stx-value? #'n2.expanded) + (quoted-stx-value? #'n2.evidence))]) + (and v1 v2 ;; Should never fail + (quasisyntax/loc stx #,(expt v1 v2))))] + [_ #f])))) diff --git a/test/math-fail.rkt b/test/math-fail.rkt index 617c730..332fdc3 100644 --- a/test/math-fail.rkt +++ b/test/math-fail.rkt @@ -14,6 +14,7 @@ (ann (let ([n 5]) (*: n 1/5 1)) One) (ann (let ([n 4]) (/: n n)) One) (ann (let ([n 2]) (expt: 3 (-: n n))) One) + (ann (expt: 3 2) Zero) ;; -- lambda => back to racket/base (ann ((lambda ([f : (-> Natural Natural Natural)]) (f 0 0)) +:) Zero) (ann ((lambda ([f : (-> Natural Natural Integer)]) (f 0 0)) -:) Zero) diff --git a/test/math-pass.rkt b/test/math-pass.rkt index 8fa1d91..4df45b8 100644 --- a/test/math-pass.rkt +++ b/test/math-pass.rkt @@ -19,6 +19,17 @@ (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) +:) Integer) 0) + (check-equal? + (let-num: ([n -4] [m 5]) + (ann (+: m n -1) Zero)) + 0) + + (check-equal? + (let () + (define-num: n 6) + (define-num: m -8) + (ann (+: n 2 m) Zero)) + 0) ;; -- -: (check-equal? (ann (-: 0 0) Zero) 0) @@ -31,6 +42,17 @@ (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) -:) Integer) 0) + (check-equal? + (let-num: ([n 4] [m 5]) + (ann (-: m n 1) Zero)) + 0) + + (check-equal? + (let () + (define-num: n 6) + (define-num: m -8) + (ann (-: n m 14) Zero)) + 0) ;; -- *: (check-equal? (ann (*: 0 1315) Zero) 0) @@ -43,6 +65,18 @@ (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) *:) Integer) 0) + (check-equal? + (let-num: ([n 4] [m 5]) + (ann (-: (*: m n) 20) Zero)) + 0) + + (check-equal? + (let () + (define-num: n 2) + (define-num: m -8) + (ann (-: (*: n -2 m) 32) Zero)) + 0) + ;; -- /: (check-equal? (ann (/: 0 1) Zero) 0) @@ -54,6 +88,18 @@ (ann ((lambda ([f : (-> Integer Integer Exact-Rational)]) (f 1 1)) /:) Real) 1) + (check-equal? + (let-num: ([n 4] [m 12]) + (ann (-: (/: m n) 3) Zero)) + 0) + + (check-equal? + (let () + (define-num: n 2) + (define-num: m -8) + (ann (+: (/: m n) 4) Zero)) + 0) + ;; -- Nested (check-equal? @@ -90,7 +136,20 @@ (ann (expt: (+: 5 -5) 78) Zero) 0) (check-equal? - (ann (expt: (*: 2 2) (expt: 2 2)) Index) + (ann (-: (expt: (*: 2 2) (expt: 2 2)) 256) Zero) + 0) + (check-equal? + (ann (expt: (* 2 2) (expt: 2 2)) Natural) ;; Not an index 256) + (check-equal? + (let-num: ([n1 5] [n2 4]) + (ann (-: (expt: n1 n2) 625) Zero)) + 0) + (check-equal? + (let () + (define-num: n1 8) + (define-num: n2 2) + (ann (-: (expt: n1 n2) 64) Zero)) + 0) )