diff --git a/pkgs/redex-pkgs/redex-lib/redex/private/enumerator.rkt b/pkgs/redex-pkgs/redex-lib/redex/private/enumerator.rkt index 9b90eaa4e5..a4fc88e1d4 100644 --- a/pkgs/redex-pkgs/redex-lib/redex/private/enumerator.rkt +++ b/pkgs/redex-pkgs/redex-lib/redex/private/enumerator.rkt @@ -62,10 +62,19 @@ ;; Helper functions ;; map/e : (a -> b), (b -> a), enum a -> enum b -(define (map/e f inv-f e) - (enum (size e) - (compose f (enum-from e)) - (compose (enum-to e) inv-f))) +(define (map/e f inv-f e . es) + (cond [(empty? es) + (enum (size e) + (compose f (enum-from e)) + (compose (enum-to e) inv-f))] + [else + (define es/e (list/e (cons e es))) + (map/e + (λ (xs) + (apply f xs)) + (λ (ys) + (call-with-values (λ () (inv-f ys)) list)) + es/e)])) ;; filter/e : enum a, (a -> bool) -> enum a ;; size won't be accurate! @@ -298,24 +307,24 @@ (= 0 (size e2))) empty/e] [(not (infinite? (enum-size e1))) (cond [(not (infinite? (enum-size e2))) - (let ([size (* (enum-size e1) - (enum-size e2))]) - (enum size - (λ (n) ;; bijection from n -> axb - (if (> n size) - (error "out of range") - (call-with-values - (λ () - (quotient/remainder n (enum-size e2))) - (λ (q r) - (cons ((enum-from e1) q) - ((enum-from e2) r)))))) - (λ (xs) - (unless (pair? xs) - (error "not a pair")) - (+ (* (enum-size e1) - ((enum-to e1) (car xs))) - ((enum-to e2) (cdr xs))))))] + (define size (* (enum-size e1) + (enum-size e2))) + (enum size + (λ (n) ;; bijection from n -> axb + (if (> n size) + (error "out of range") + (call-with-values + (λ () + (quotient/remainder n (enum-size e2))) + (λ (q r) + (cons ((enum-from e1) q) + ((enum-from e2) r)))))) + (λ (xs) + (unless (pair? xs) + (error "not a pair")) + (define q (encode e1 (car xs))) + (define r (encode e2 (cdr xs))) + (+ (* (enum-size e2) q) r)))] [else (enum +inf.f (λ (n) diff --git a/pkgs/redex-pkgs/redex-test/redex/tests/enumerator-test.rkt b/pkgs/redex-pkgs/redex-test/redex/tests/enumerator-test.rkt index 6459e5e102..1cf66953fd 100644 --- a/pkgs/redex-pkgs/redex-test/redex/tests/enumerator-test.rkt +++ b/pkgs/redex-pkgs/redex-test/redex/tests/enumerator-test.rkt @@ -50,13 +50,6 @@ (λ () (decode nats -1)))) -#; -(define (nats+/e n) - (map/e (λ (k) - (+ k n)) - (λ (k) - (- k n)))) - ;; ints checks (test-begin (check-eq? (decode ints/e 0) 0) ; 0 -> 0 @@ -135,6 +128,7 @@ ;; cons/e tests (define bool*bool (cons/e bools/e bools/e)) (define 1*b (cons/e (const/e 1) bools/e)) +(define b*1 (cons/e bools/e (const/e 1))) (define bool*nats (cons/e bools/e nats)) (define nats*bool (cons/e nats bools/e)) (define nats*nats (cons/e nats nats)) @@ -151,6 +145,7 @@ (check-equal? (decode 1*b 0) (cons 1 #t)) (check-equal? (decode 1*b 1) (cons 1 #f)) (check-bijection? 1*b) + (check-bijection? b*1) (check-equal? (size bool*bool) 4) (check-equal? (decode bool*bool 0) (cons #t #t)) @@ -202,6 +197,17 @@ (cons 1 1)) (check-bijection? nats*nats)) +;; multi-arg map/e test +(define sums/e + (map/e + cons + (λ (x-y) + (values (car x-y) (cdr x-y))) + (from-list/e '(1 2)) + (from-list/e '(3 4)))) + +(test-begin + (check-bijection? sums/e)) ;; dep/e tests (define (up-to n) @@ -285,7 +291,6 @@ (define nats-to-2 (dep/e nats up-to)) - (test-begin (check-equal? (size 3-up-2) 6) (check-equal? (decode 3-up-2 0) (cons 0 0))