Clean up cons/e enumerator.

Isolates enumerator performance bottleneck.
This commit is contained in:
Max New 2013-12-05 22:51:16 -06:00
parent 78820bda45
commit 18e9bc1132
2 changed files with 83 additions and 109 deletions

View File

@ -83,8 +83,10 @@
(define (map/e f inv-f e . es) (define (map/e f inv-f e . es)
(cond [(empty? es) (cond [(empty? es)
(enum (size e) (enum (size e)
(compose f (enum-from e)) (λ (x)
(compose (enum-to e) inv-f))] (f (decode e x)))
(λ (n)
(encode e (inv-f n))))]
[else [else
(define es/e (list/e (cons e es))) (define es/e (list/e (cons e es)))
(map/e (map/e
@ -193,14 +195,14 @@
(hash-ref rev-map x))))) (hash-ref rev-map x)))))
(define nats/e (define nats/e
(enum +inf.f (enum +inf.0
identity identity
(λ (n) (λ (n)
(unless (>= n 0) (unless (>= n 0)
(redex-error 'encode "Not a natural")) (redex-error 'encode "Not a natural"))
n))) n)))
(define ints/e (define ints/e
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(if (even? n) (if (even? n)
(* -1 (/ n 2)) (* -1 (/ n 2))
@ -353,84 +355,58 @@
(cons empty/e (λ (_) #f)) (cons empty/e (λ (_) #f))
(cons e-p e-ps)))) (cons e-p e-ps))))
;; cons/e : enum a, enum b -> enum (cons a b) (define (foldr1 f l)
(define cons/e (match l
(case-lambda [(cons x '()) x]
[(e) e] [(cons x xs) (f x (foldr1 f xs))]))
[(e1 e2)
(cond [(or (= 0 (size e1)) ;; cons/e : enum a, enum b ... -> enum (cons a b ...)
(= 0 (size e2))) empty/e] (define (cons/e e . es)
[(not (infinite? (enum-size e1))) (define (cons/e2 e1 e2)
(cond [(not (infinite? (enum-size e2))) (define s1 (enum-size e1))
(define size (* (enum-size e1) (define s2 (enum-size e2))
(enum-size e2))) (define size (* s1 s2))
(enum size (cond [(zero? size) empty/e]
(λ (n) ;; bijection from n -> axb [(or (not (infinite? s1))
(if (> n size) (not (infinite? s2)))
(redex-error 'decode "out of range") (define fst-finite? (not (infinite? s1)))
(call-with-values (define fin-size
(λ () (cond [fst-finite? s1]
(quotient/remainder n (enum-size e2))) [else s2]))
(λ (q r) (define (dec n)
(cons ((enum-from e1) q) (define-values (q r)
((enum-from e2) r)))))) (quotient/remainder n fin-size))
(λ (xs) (define x1 (decode e1 (if fst-finite? r q)))
(unless (pair? xs) (define x2 (decode e2 (if fst-finite? q r)))
(redex-error 'encode "not a pair")) (cons x1 x2))
(define q (encode e1 (car xs))) (define/match (enc p)
(define r (encode e2 (cdr xs))) [((cons x1 x2))
(+ (* (enum-size e2) q) r)))] (define n1 (encode e1 x1))
[else (define n2 (encode e2 x2))
(enum +inf.f (define q (if fst-finite? n2 n1))
(λ (n) (define r (if fst-finite? n1 n2))
(call-with-values (+ (* fin-size q)
(λ () r)])
(quotient/remainder n (enum-size e1))) (enum size dec enc)]
(λ (q r) [else
(cons ((enum-from e1) r) (define (dec n)
((enum-from e2) q))))) (define k (floor-untri n))
(λ (xs) (define t (tri k))
(unless (pair? xs) (define l (- n t))
(redex-error 'encode "not a pair")) (define m (- k l))
(+ ((enum-to e1) (car xs)) (define x1 (decode e1 l))
(* (enum-size e1) (define x2 (decode e2 m))
((enum-to e2) (cdr xs))))))])] (cons x1 x2))
[(not (infinite? (enum-size e2))) (define/match (enc p)
(enum +inf.f [((cons x1 x2))
(λ (n) (define l (encode e1 x1))
(call-with-values (define m (encode e2 x2))
(λ () (+ (/ (* (+ l m)
(quotient/remainder n (enum-size e2))) (+ l m 1))
(λ (q r) 2)
(cons ((enum-from e1) q) l)])
((enum-from e2) r))))) (enum size dec enc)]))
(λ (xs) (foldr1 cons/e2 (cons e es)))
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(+ (* (enum-size e2)
((enum-to e1) (car xs)))
((enum-to e2) (cdr xs)))))]
[else
(enum (* (enum-size e1)
(enum-size e2))
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
[l (- n t)]
[m (- k l)])
(cons ((enum-from e1) l)
((enum-from e2) m))))
(λ (xs) ;; bijection from nxn -> n, inverse of previous
;; (n,m) -> (n+m)(n+m+1)/2 + n
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(let ([l ((enum-to e1) (car xs))]
[m ((enum-to e2) (cdr xs))])
(+ (/ (* (+ l m) (+ l m 1))
2)
l))))])]
[(a b c . rest)
(cons/e a (apply cons/e b c rest))]))
;; Traversal (maybe come up with a better name ;; Traversal (maybe come up with a better name
;; traverse/e : (a -> enum b), (listof a) -> enum (listof b) ;; traverse/e : (a -> enum b), (listof a) -> enum (listof b)
@ -503,7 +479,7 @@
;; sizes : gvector int ;; sizes : gvector int
(let ([sizes (gvector first)]) (let ([sizes (gvector first)])
(enum (if (infinite? (size e)) (enum (if (infinite? (size e))
+inf.f +inf.0
(foldl (foldl
(λ (curSize acc) (λ (curSize acc)
(let ([sum (+ curSize acc)]) (let ([sum (+ curSize acc)])
@ -540,7 +516,7 @@
(+ sizeUpTo (+ sizeUpTo
(encode ei b))))))] (encode ei b))))))]
[(not (infinite? (size e))) [(not (infinite? (size e)))
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(call-with-values (call-with-values
(λ () (λ ()
@ -552,7 +528,7 @@
(+ (* (size e) (encode (f (car ab)) (cdr ab))) (+ (* (size e) (encode (f (car ab)) (cdr ab)))
(encode e (car ab)))))] (encode e (car ab)))))]
[else ;; both infinite, same as cons/e [else ;; both infinite, same as cons/e
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(let* ([k (floor-untri n)] (let* ([k (floor-untri n)]
[t (tri k)] [t (tri k)]
@ -628,7 +604,7 @@
(let* ([first (size (f (decode e 0)))] (let* ([first (size (f (decode e 0)))]
[sizes (gvector first)]) [sizes (gvector first)])
(enum (if (infinite? (size e)) (enum (if (infinite? (size e))
+inf.f +inf.0
(foldl (foldl
(λ (curSize acc) (λ (curSize acc)
(let ([sum (+ curSize acc)]) (let ([sum (+ curSize acc)])
@ -665,7 +641,7 @@
(+ sizeUpTo (+ sizeUpTo
(encode ei b))))))] (encode ei b))))))]
[(not (infinite? (size e))) [(not (infinite? (size e)))
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(call-with-values (call-with-values
(λ () (λ ()
@ -677,7 +653,7 @@
(+ (* (size e) (encode (f (car ab)) (cdr ab))) (+ (* (size e) (encode (f (car ab)) (cdr ab)))
(encode e (car ab)))))] (encode e (car ab)))))]
[else ;; both infinite, same as cons/e [else ;; both infinite, same as cons/e
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(let* ([k (floor-untri n)] (let* ([k (floor-untri n)]
[t (tri k)] [t (tri k)]
@ -774,7 +750,7 @@
(define fix-size (define fix-size
(if (= 0 (size e)) (if (= 0 (size e))
0 0
+inf.f)) +inf.0))
(fix/e fix-size (fix/e fix-size
(λ (self) (λ (self)
(disj-sum/e #:alternate? #t (disj-sum/e #:alternate? #t
@ -878,8 +854,6 @@
(cons (map/e - - from-1/e) (cons (map/e - - from-1/e)
(λ (n) (< n 0))))) (λ (n) (< n 0)))))
;; The last 3 here are -inf.0, +inf.0 and +nan.0
;; Consider moving those to the beginning
(define weird-flonums/e-p (define weird-flonums/e-p
(cons (from-list/e '(+inf.0 -inf.0 +nan.0)) (cons (from-list/e '(+inf.0 -inf.0 +nan.0))
(λ (n) (λ (n)
@ -937,7 +911,7 @@
(cons var/e symbol?))) (cons var/e symbol?)))
(define any/e (define any/e
(fix/e +inf.f (fix/e +inf.0
(λ (any/e) (λ (any/e)
(disj-sum/e #:alternate? #t (disj-sum/e #:alternate? #t
(cons base/e (negate pair?)) (cons base/e (negate pair?))

View File

@ -40,7 +40,7 @@
(define nats+1 (nats+/e 1)) (define nats+1 (nats+/e 1))
(test-begin (test-begin
(check-equal? (size nats+1) +inf.f) (check-equal? (size nats+1) +inf.0)
(check-equal? (decode nats+1 0) 1) (check-equal? (decode nats+1 0) 1)
(check-equal? (decode nats+1 1) 2) (check-equal? (decode nats+1 1) 2)
(check-bijection? nats+1)) (check-bijection? nats+1))
@ -63,7 +63,7 @@
;; sum tests ;; sum tests
(define evens/e (define evens/e
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(* 2 n)) (* 2 n))
(λ (n) (λ (n)
@ -73,7 +73,7 @@
(error 'even))))) (error 'even)))))
(define odds/e (define odds/e
(enum +inf.f (enum +inf.0
(λ (n) (λ (n)
(+ (* 2 n) 1)) (+ (* 2 n) 1))
(λ (n) (λ (n)
@ -106,13 +106,13 @@
(check-bijection? bool-or-num) (check-bijection? bool-or-num)
(check-equal? (size bool-or-nat) (check-equal? (size bool-or-nat)
+inf.f) +inf.0)
(check-equal? (decode bool-or-nat 0) #t) (check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) 0) (check-equal? (decode bool-or-nat 1) 0)
(check-bijection? bool-or-nat) (check-bijection? bool-or-nat)
(check-equal? (size odd-or-even) (check-equal? (size odd-or-even)
+inf.f) +inf.0)
(check-equal? (decode odd-or-even 0) 0) (check-equal? (decode odd-or-even 0) 0)
(check-equal? (decode odd-or-even 1) 1) (check-equal? (decode odd-or-even 1) 1)
(check-equal? (decode odd-or-even 2) 2) (check-equal? (decode odd-or-even 2) 2)
@ -160,13 +160,13 @@
(check-bijection? bool-or-num) (check-bijection? bool-or-num)
(check-equal? (size bool-or-nat) (check-equal? (size bool-or-nat)
+inf.f) +inf.0)
(check-equal? (decode bool-or-nat 0) #t) (check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) 0) (check-equal? (decode bool-or-nat 1) 0)
(check-bijection? bool-or-nat) (check-bijection? bool-or-nat)
(check-equal? (size odd-or-even) (check-equal? (size odd-or-even)
+inf.f) +inf.0)
(check-equal? (decode odd-or-even 0) 0) (check-equal? (decode odd-or-even 0) 0)
(check-equal? (decode odd-or-even 1) 1) (check-equal? (decode odd-or-even 1) 1)
(check-equal? (decode odd-or-even 2) 2) (check-equal? (decode odd-or-even 2) 2)
@ -205,7 +205,7 @@
(check-bijection? bool-or-num) (check-bijection? bool-or-num)
(check-equal? (size bool-or-nat) (check-equal? (size bool-or-nat)
+inf.f) +inf.0)
(check-equal? (decode bool-or-nat 0) #t) (check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) #f) (check-equal? (decode bool-or-nat 1) #f)
(check-equal? (decode bool-or-nat 2) 0) (check-equal? (decode bool-or-nat 2) 0)
@ -236,14 +236,14 @@
(check-equal? (decode bool*bool 0) (check-equal? (decode bool*bool 0)
(cons #t #t)) (cons #t #t))
(check-equal? (decode bool*bool 1) (check-equal? (decode bool*bool 1)
(cons #t #f))
(check-equal? (decode bool*bool 2)
(cons #f #t)) (cons #f #t))
(check-equal? (decode bool*bool 2)
(cons #t #f))
(check-equal? (decode bool*bool 3) (check-equal? (decode bool*bool 3)
(cons #f #f)) (cons #f #f))
(check-bijection? bool*bool) (check-bijection? bool*bool)
(check-equal? (size bool*nats) +inf.f) (check-equal? (size bool*nats) +inf.0)
(check-equal? (decode bool*nats 0) (check-equal? (decode bool*nats 0)
(cons #t 0)) (cons #t 0))
(check-equal? (decode bool*nats 1) (check-equal? (decode bool*nats 1)
@ -254,7 +254,7 @@
(cons #f 1)) (cons #f 1))
(check-bijection? bool*nats) (check-bijection? bool*nats)
(check-equal? (size nats*bool) +inf.f) (check-equal? (size nats*bool) +inf.0)
(check-equal? (decode nats*bool 0) (check-equal? (decode nats*bool 0)
(cons 0 #t)) (cons 0 #t))
(check-equal? (decode nats*bool 1) (check-equal? (decode nats*bool 1)
@ -265,7 +265,7 @@
(cons 1 #f)) (cons 1 #f))
(check-bijection? nats*bool) (check-bijection? nats*bool)
(check-equal? (size nats*nats) +inf.f) (check-equal? (size nats*nats) +inf.0)
(check ns-equal? (check ns-equal?
(decode nats*nats 0) (decode nats*nats 0)
(cons 0 0)) (cons 0 0))
@ -325,7 +325,7 @@
(check-equal? (decode 3-up 5) (cons 2 2)) (check-equal? (decode 3-up 5) (cons 2 2))
(check-bijection? 3-up) (check-bijection? 3-up)
(check-equal? (size from-3) +inf.f) (check-equal? (size from-3) +inf.0)
(check-equal? (decode from-3 0) (cons 0 0)) (check-equal? (decode from-3 0) (cons 0 0))
(check-equal? (decode from-3 1) (cons 1 1)) (check-equal? (decode from-3 1) (cons 1 1))
(check-equal? (decode from-3 2) (cons 2 2)) (check-equal? (decode from-3 2) (cons 2 2))
@ -335,7 +335,7 @@
(check-equal? (decode from-3 6) (cons 0 2)) (check-equal? (decode from-3 6) (cons 0 2))
(check-bijection? from-3) (check-bijection? from-3)
(check-equal? (size nats-to) +inf.f) (check-equal? (size nats-to) +inf.0)
(check-equal? (decode nats-to 0) (cons 0 0)) (check-equal? (decode nats-to 0) (cons 0 0))
(check-equal? (decode nats-to 1) (cons 1 0)) (check-equal? (decode nats-to 1) (cons 1 0))
(check-equal? (decode nats-to 2) (cons 1 1)) (check-equal? (decode nats-to 2) (cons 1 1))
@ -345,7 +345,7 @@
(check-equal? (decode nats-to 6) (cons 3 0)) (check-equal? (decode nats-to 6) (cons 3 0))
(check-bijection? nats-to) (check-bijection? nats-to)
(check-equal? (size nats-up) +inf.f) (check-equal? (size nats-up) +inf.0)
(check-equal? (decode nats-up 0) (cons 0 0)) (check-equal? (decode nats-up 0) (cons 0 0))
(check-equal? (decode nats-up 1) (cons 0 1)) (check-equal? (decode nats-up 1) (cons 0 1))
(check-equal? (decode nats-up 2) (cons 1 1)) (check-equal? (decode nats-up 2) (cons 1 1))
@ -391,7 +391,7 @@
(check-equal? (encode 3-up-2 (cons 1 1)) 2) (check-equal? (encode 3-up-2 (cons 1 1)) 2)
(check-equal? (encode 3-up-2 (cons 2 0)) 3) (check-equal? (encode 3-up-2 (cons 2 0)) 3)
(check-equal? (size nats-to-2) +inf.f) (check-equal? (size nats-to-2) +inf.0)
(check-equal? (encode nats-to-2 (cons 0 0)) 0) (check-equal? (encode nats-to-2 (cons 0 0)) 0)
(check-equal? (encode nats-to-2 (cons 1 0)) 1) (check-equal? (encode nats-to-2 (cons 1 0)) 1)
(check-equal? (encode nats-to-2 (cons 1 1)) 2) (check-equal? (encode nats-to-2 (cons 1 1)) 2)