Clean up cons/e enumerator.

Isolates enumerator performance bottleneck.
This commit is contained in:
Max New 2013-12-05 22:51:16 -06:00
parent 78820bda45
commit 18e9bc1132
2 changed files with 83 additions and 109 deletions

View File

@ -83,8 +83,10 @@
(define (map/e f inv-f e . es)
(cond [(empty? es)
(enum (size e)
(compose f (enum-from e))
(compose (enum-to e) inv-f))]
(λ (x)
(f (decode e x)))
(λ (n)
(encode e (inv-f n))))]
[else
(define es/e (list/e (cons e es)))
(map/e
@ -193,14 +195,14 @@
(hash-ref rev-map x)))))
(define nats/e
(enum +inf.f
(enum +inf.0
identity
(λ (n)
(unless (>= n 0)
(redex-error 'encode "Not a natural"))
n)))
(define ints/e
(enum +inf.f
(enum +inf.0
(λ (n)
(if (even? n)
(* -1 (/ n 2))
@ -353,84 +355,58 @@
(cons empty/e (λ (_) #f))
(cons e-p e-ps))))
;; cons/e : enum a, enum b -> enum (cons a b)
(define cons/e
(case-lambda
[(e) e]
[(e1 e2)
(cond [(or (= 0 (size e1))
(= 0 (size e2))) empty/e]
[(not (infinite? (enum-size e1)))
(cond [(not (infinite? (enum-size e2)))
(define size (* (enum-size e1)
(enum-size e2)))
(enum size
(λ (n) ;; bijection from n -> axb
(if (> n size)
(redex-error 'decode "out of range")
(call-with-values
(λ ()
(quotient/remainder n (enum-size e2)))
(λ (q r)
(cons ((enum-from e1) q)
((enum-from e2) r))))))
(λ (xs)
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(define q (encode e1 (car xs)))
(define r (encode e2 (cdr xs)))
(+ (* (enum-size e2) q) r)))]
(define (foldr1 f l)
(match l
[(cons x '()) x]
[(cons x xs) (f x (foldr1 f xs))]))
;; cons/e : enum a, enum b ... -> enum (cons a b ...)
(define (cons/e e . es)
(define (cons/e2 e1 e2)
(define s1 (enum-size e1))
(define s2 (enum-size e2))
(define size (* s1 s2))
(cond [(zero? size) empty/e]
[(or (not (infinite? s1))
(not (infinite? s2)))
(define fst-finite? (not (infinite? s1)))
(define fin-size
(cond [fst-finite? s1]
[else s2]))
(define (dec n)
(define-values (q r)
(quotient/remainder n fin-size))
(define x1 (decode e1 (if fst-finite? r q)))
(define x2 (decode e2 (if fst-finite? q r)))
(cons x1 x2))
(define/match (enc p)
[((cons x1 x2))
(define n1 (encode e1 x1))
(define n2 (encode e2 x2))
(define q (if fst-finite? n2 n1))
(define r (if fst-finite? n1 n2))
(+ (* fin-size q)
r)])
(enum size dec enc)]
[else
(enum +inf.f
(λ (n)
(call-with-values
(λ ()
(quotient/remainder n (enum-size e1)))
(λ (q r)
(cons ((enum-from e1) r)
((enum-from e2) q)))))
(λ (xs)
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(+ ((enum-to e1) (car xs))
(* (enum-size e1)
((enum-to e2) (cdr xs))))))])]
[(not (infinite? (enum-size e2)))
(enum +inf.f
(λ (n)
(call-with-values
(λ ()
(quotient/remainder n (enum-size e2)))
(λ (q r)
(cons ((enum-from e1) q)
((enum-from e2) r)))))
(λ (xs)
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(+ (* (enum-size e2)
((enum-to e1) (car xs)))
((enum-to e2) (cdr xs)))))]
[else
(enum (* (enum-size e1)
(enum-size e2))
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
[l (- n t)]
[m (- k l)])
(cons ((enum-from e1) l)
((enum-from e2) m))))
(λ (xs) ;; bijection from nxn -> n, inverse of previous
;; (n,m) -> (n+m)(n+m+1)/2 + n
(unless (pair? xs)
(redex-error 'encode "not a pair"))
(let ([l ((enum-to e1) (car xs))]
[m ((enum-to e2) (cdr xs))])
(+ (/ (* (+ l m) (+ l m 1))
(define (dec n)
(define k (floor-untri n))
(define t (tri k))
(define l (- n t))
(define m (- k l))
(define x1 (decode e1 l))
(define x2 (decode e2 m))
(cons x1 x2))
(define/match (enc p)
[((cons x1 x2))
(define l (encode e1 x1))
(define m (encode e2 x2))
(+ (/ (* (+ l m)
(+ l m 1))
2)
l))))])]
[(a b c . rest)
(cons/e a (apply cons/e b c rest))]))
l)])
(enum size dec enc)]))
(foldr1 cons/e2 (cons e es)))
;; Traversal (maybe come up with a better name
;; traverse/e : (a -> enum b), (listof a) -> enum (listof b)
@ -503,7 +479,7 @@
;; sizes : gvector int
(let ([sizes (gvector first)])
(enum (if (infinite? (size e))
+inf.f
+inf.0
(foldl
(λ (curSize acc)
(let ([sum (+ curSize acc)])
@ -540,7 +516,7 @@
(+ sizeUpTo
(encode ei b))))))]
[(not (infinite? (size e)))
(enum +inf.f
(enum +inf.0
(λ (n)
(call-with-values
(λ ()
@ -552,7 +528,7 @@
(+ (* (size e) (encode (f (car ab)) (cdr ab)))
(encode e (car ab)))))]
[else ;; both infinite, same as cons/e
(enum +inf.f
(enum +inf.0
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
@ -628,7 +604,7 @@
(let* ([first (size (f (decode e 0)))]
[sizes (gvector first)])
(enum (if (infinite? (size e))
+inf.f
+inf.0
(foldl
(λ (curSize acc)
(let ([sum (+ curSize acc)])
@ -665,7 +641,7 @@
(+ sizeUpTo
(encode ei b))))))]
[(not (infinite? (size e)))
(enum +inf.f
(enum +inf.0
(λ (n)
(call-with-values
(λ ()
@ -677,7 +653,7 @@
(+ (* (size e) (encode (f (car ab)) (cdr ab)))
(encode e (car ab)))))]
[else ;; both infinite, same as cons/e
(enum +inf.f
(enum +inf.0
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
@ -774,7 +750,7 @@
(define fix-size
(if (= 0 (size e))
0
+inf.f))
+inf.0))
(fix/e fix-size
(λ (self)
(disj-sum/e #:alternate? #t
@ -878,8 +854,6 @@
(cons (map/e - - from-1/e)
(λ (n) (< n 0)))))
;; The last 3 here are -inf.0, +inf.0 and +nan.0
;; Consider moving those to the beginning
(define weird-flonums/e-p
(cons (from-list/e '(+inf.0 -inf.0 +nan.0))
(λ (n)
@ -937,7 +911,7 @@
(cons var/e symbol?)))
(define any/e
(fix/e +inf.f
(fix/e +inf.0
(λ (any/e)
(disj-sum/e #:alternate? #t
(cons base/e (negate pair?))

View File

@ -40,7 +40,7 @@
(define nats+1 (nats+/e 1))
(test-begin
(check-equal? (size nats+1) +inf.f)
(check-equal? (size nats+1) +inf.0)
(check-equal? (decode nats+1 0) 1)
(check-equal? (decode nats+1 1) 2)
(check-bijection? nats+1))
@ -63,7 +63,7 @@
;; sum tests
(define evens/e
(enum +inf.f
(enum +inf.0
(λ (n)
(* 2 n))
(λ (n)
@ -73,7 +73,7 @@
(error 'even)))))
(define odds/e
(enum +inf.f
(enum +inf.0
(λ (n)
(+ (* 2 n) 1))
(λ (n)
@ -106,13 +106,13 @@
(check-bijection? bool-or-num)
(check-equal? (size bool-or-nat)
+inf.f)
+inf.0)
(check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) 0)
(check-bijection? bool-or-nat)
(check-equal? (size odd-or-even)
+inf.f)
+inf.0)
(check-equal? (decode odd-or-even 0) 0)
(check-equal? (decode odd-or-even 1) 1)
(check-equal? (decode odd-or-even 2) 2)
@ -160,13 +160,13 @@
(check-bijection? bool-or-num)
(check-equal? (size bool-or-nat)
+inf.f)
+inf.0)
(check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) 0)
(check-bijection? bool-or-nat)
(check-equal? (size odd-or-even)
+inf.f)
+inf.0)
(check-equal? (decode odd-or-even 0) 0)
(check-equal? (decode odd-or-even 1) 1)
(check-equal? (decode odd-or-even 2) 2)
@ -205,7 +205,7 @@
(check-bijection? bool-or-num)
(check-equal? (size bool-or-nat)
+inf.f)
+inf.0)
(check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 1) #f)
(check-equal? (decode bool-or-nat 2) 0)
@ -236,14 +236,14 @@
(check-equal? (decode bool*bool 0)
(cons #t #t))
(check-equal? (decode bool*bool 1)
(cons #t #f))
(check-equal? (decode bool*bool 2)
(cons #f #t))
(check-equal? (decode bool*bool 2)
(cons #t #f))
(check-equal? (decode bool*bool 3)
(cons #f #f))
(check-bijection? bool*bool)
(check-equal? (size bool*nats) +inf.f)
(check-equal? (size bool*nats) +inf.0)
(check-equal? (decode bool*nats 0)
(cons #t 0))
(check-equal? (decode bool*nats 1)
@ -254,7 +254,7 @@
(cons #f 1))
(check-bijection? bool*nats)
(check-equal? (size nats*bool) +inf.f)
(check-equal? (size nats*bool) +inf.0)
(check-equal? (decode nats*bool 0)
(cons 0 #t))
(check-equal? (decode nats*bool 1)
@ -265,7 +265,7 @@
(cons 1 #f))
(check-bijection? nats*bool)
(check-equal? (size nats*nats) +inf.f)
(check-equal? (size nats*nats) +inf.0)
(check ns-equal?
(decode nats*nats 0)
(cons 0 0))
@ -325,7 +325,7 @@
(check-equal? (decode 3-up 5) (cons 2 2))
(check-bijection? 3-up)
(check-equal? (size from-3) +inf.f)
(check-equal? (size from-3) +inf.0)
(check-equal? (decode from-3 0) (cons 0 0))
(check-equal? (decode from-3 1) (cons 1 1))
(check-equal? (decode from-3 2) (cons 2 2))
@ -335,7 +335,7 @@
(check-equal? (decode from-3 6) (cons 0 2))
(check-bijection? from-3)
(check-equal? (size nats-to) +inf.f)
(check-equal? (size nats-to) +inf.0)
(check-equal? (decode nats-to 0) (cons 0 0))
(check-equal? (decode nats-to 1) (cons 1 0))
(check-equal? (decode nats-to 2) (cons 1 1))
@ -345,7 +345,7 @@
(check-equal? (decode nats-to 6) (cons 3 0))
(check-bijection? nats-to)
(check-equal? (size nats-up) +inf.f)
(check-equal? (size nats-up) +inf.0)
(check-equal? (decode nats-up 0) (cons 0 0))
(check-equal? (decode nats-up 1) (cons 0 1))
(check-equal? (decode nats-up 2) (cons 1 1))
@ -391,7 +391,7 @@
(check-equal? (encode 3-up-2 (cons 1 1)) 2)
(check-equal? (encode 3-up-2 (cons 2 0)) 3)
(check-equal? (size nats-to-2) +inf.f)
(check-equal? (size nats-to-2) +inf.0)
(check-equal? (encode nats-to-2 (cons 0 0)) 0)
(check-equal? (encode nats-to-2 (cons 1 0)) 1)
(check-equal? (encode nats-to-2 (cons 1 1)) 2)