diff --git a/math.rkt b/math.rkt index 4b60ce0..9b53475 100644 --- a/math.rkt +++ b/math.rkt @@ -1,7 +1,7 @@ #lang typed/racket/base (provide - +: ;-: *: /: + +: -: *: /: ;; Fold syntactic constants ) @@ -22,38 +22,44 @@ #:with f: (format-id #'f "~a:" (syntax-e #'f)) #'(define-syntax f: (syntax-parser - [g:id - (syntax/loc #'g f)] [(g e* (... ...)) #:with e+* (for/list ([e (in-list (syntax->list #'(e* (... ...))))]) (expand-expr e)) - #:with e++ (reduce/op f (syntax->list #'e+*) #:src #'g) - (syntax/loc #'g e++)] + (let ([e++ (reduce/op f (syntax->list #'e+*))]) + (if (list? e++) + (quasisyntax/loc #'g (f #,@e++)) + (quasisyntax/loc #'g #,e++)))] + [g:id + (syntax/loc #'g f)] [(g e* (... ...)) (syntax/loc #'g (f e* (... ...)))]))])) (make-numeric-operator +) +(make-numeric-operator -) +(make-numeric-operator *) +(make-numeric-operator /) ;; ----------------------------------------------------------------------------- -(define-for-syntax (reduce/op op e* #:src stx) +(define-for-syntax (reduce/op op e*) (let loop ([prev #f] [acc '()] [e* e*]) (if (null? e*) - ;; then: combine `prev` and `acc` into a list or single number - (cond - [(null? acc) - (quasisyntax/loc stx #,prev)] - [else - (let ([acc+ (reverse (if prev (cons prev acc) acc))]) - (quasisyntax/loc stx (#,op #,@acc+)))]) + ;; then: finished, return a number (prev) or list of expressions (acc) + (if (null? acc) + prev + (reverse (if prev (cons prev acc) acc))) ;; else: pop the next argument from e*, fold if it's a constant - (syntax-parse (car e*) - [n:number - (if prev - ;; eval? - (loop (op prev (car e*)) acc (cdr e*)) - (loop (car e*) acc (cdr e*)))] - [e - (loop #f (cons (car e*) (if prev (cons prev acc) acc)) (cdr e*))])))) + (let ([v (quoted-stx-value? (car e*))]) + (if (number? v) + ;; then: reduce the number + (if prev + ;; Watch for division-by-zero + (if (and (zero? v) (eq? / op)) + (loop v (cons prev acc) (cdr e*)) + (loop (op prev v) acc (cdr e*))) + (loop v acc (cdr e*))) + ;; else: save value in acc + (let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))]) + (loop #f acc+ (cdr e*))))) ))) diff --git a/test/math-fail.rkt b/test/math-fail.rkt new file mode 100644 index 0000000..08b50a0 --- /dev/null +++ b/test/math-fail.rkt @@ -0,0 +1,33 @@ +#lang racket/base + +(define (expr->typed-module expr) + #`(module t typed/racket/base + (require trivial/math) + #,expr)) + +(define TEST-CASE* (map expr->typed-module '( + (ann (let ([n 2]) (+: n -2)) Zero) + (ann (let ([n 2]) (-: 2 n)) Zero) + (ann (let ([n 5]) (*: n 1/5 1)) One) + (ann (let ([n 4]) (/: n n)) One) + ;; -- lambda => back to racket/base + (ann ((lambda ([f : (-> Natural Natural Natural)]) (f 0 0)) +:) Zero) + (ann ((lambda ([f : (-> Natural Natural Integer)]) (f 0 0)) -:) Zero) + (ann ((lambda ([f : (-> Natural Natural Natural)]) (f 0 0)) *:) Zero) + (ann ((lambda ([f : (-> Natural Natural Exact-Rational)]) (f 0 0)) /:) Zero) + ;; -- dividing by zero => fall back to racket/base + (ann (/: 1 1 0) One) +))) + +(module+ test + (require + rackunit) + + (define (format-eval stx) + (lambda () ;; For `check-exn` + (compile-syntax stx))) + + (for ([rkt (in-list TEST-CASE*)]) + (check-exn #rx"format::|Type Checker" + (format-eval rkt))) +) diff --git a/test/math-pass.rkt b/test/math-pass.rkt new file mode 100644 index 0000000..82ea069 --- /dev/null +++ b/test/math-pass.rkt @@ -0,0 +1,85 @@ +#lang typed/racket/base + +(module+ test + (require + trivial/math + typed/rackunit + ) + + ;; -- +: + (check-equal? (ann (+: 0 0) Zero) 0) + (check-equal? (ann (+: 1 0) One) 1) + (check-equal? (ann (+: 0 1) One) 1) + (check-equal? (ann (+: 3 2) 5) 5) + (check-equal? (ann (+: 3 1 1) Natural) 5) + + (check-equal? + (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) +:) Integer) + 0) + + + ;; -- -: + (check-equal? (ann (-: 0 0) Zero) 0) + (check-equal? (ann (-: 1 1) Zero) 0) + (check-equal? (ann (-: 2 2) Zero) 0) + (check-equal? (ann (-: 99 97 2) Zero) 0) + (check-equal? (ann (-: 8 1 3 16) -12) -12) + + (check-equal? + (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) -:) Integer) + 0) + + + ;; -- *: + (check-equal? (ann (*: 0 1315) Zero) 0) + (check-equal? (ann (*: 11 0) Zero) 0) + (check-equal? (ann (*: 3 1 3) 9) 9) + (check-equal? (ann (*: -1 8 4) Negative-Integer) -32) + (check-equal? (ann (*: 5 1/5 1) One) 1) + + (check-equal? + (ann ((lambda ([f : (-> Integer Integer Integer)]) (f 0 0)) *:) Integer) + 0) + + + ;; -- /: + (check-equal? (ann (/: 0 1) Zero) 0) + (check-equal? (ann (/: 0 42) Zero) 0) + (check-equal? (ann (/: 0 1 2 3 4) Zero) 0) + (check-equal? (ann (/: 9 9) One) 1) + + ;; We do not catch this statically + (check-exn exn:fail:contract? + (lambda () (/: 3 0))) + + (check-equal? + (ann ((lambda ([f : (-> Integer Integer Exact-Rational)]) (f 1 1)) /:) Real) + 1) + + + ;; -- Nested + (check-equal? + (ann (+: (+: 1 1) (+: 1 1 1) 1) Index) + 6) + (check-equal? + (ann (*: (+: 9 1) (-: 6 3 2 1) 1) Zero) + 0) + (check-equal? + (ann (/: (+: 1 2 3 4) (+: (-: 3 2) (+: 1))) Natural) + 5) + + + ;; -- Operator works, but we can't fold constants + (let ([n 0]) + (check-equal? (ann (+: n 1 2 3 4) Natural) 10) + (check-equal? (ann (-: n n) Integer) 0) + (check-equal? (ann (*: n 8 1 4 13 1) Natural) 0) + (check-equal? (ann (/: n 1) Exact-Rational) 0)) + + (check-equal? (ann (let ([n 2]) (+: n -2)) Integer) 0) + (check-equal? (ann (let ([n 5]) (*: n 1/5 1)) Exact-Rational) 1) + (check-equal? (ann (let ([n 4]) (/: n n)) Positive-Exact-Rational) 1) + (check-exn #rx"division by zero" + (lambda () (ann (/: 0 0) Zero))) ;; Same for racket/base + +)