Added Redex enumerators.

Supports names and recursive patterns.
Limited support for repeats and mismatches.
This commit is contained in:
Max New 2013-05-09 15:02:03 -05:00
parent 1085045973
commit b8538ec135
2 changed files with 1375 additions and 8 deletions

View File

@ -1,6 +1,11 @@
#lang racket/base
(require racket/contract
"lang-struct.rkt")
racket/list
racket/match
racket/function
"lang-struct.rkt"
"match-a-pattern.rkt"
"enumerator.rkt")
(provide
(contract-out
@ -11,11 +16,512 @@
[enum-ith (-> enum? exact-nonnegative-integer? any/c)]
[enum? (-> any/c boolean?)]))
(struct enum ())
(define (lang-enumerators nts) (make-hash))
(define (pat-enumerator li p)
(unless (equal? p '(name natural natural))
(error 'enum.rkt "not yet implemented"))
(enum))
(define (enum-ith e i) i)
(define (lang-enumerators nts)
(let* ([l-enums (make-hash)]
[rec-nt-terms (find-recs nts)]
[sorted-nts (sort-nt-terms nts rec-nt-terms)])
(foldl
(λ (nt m)
(hash-set
m
(nt-name nt)
(with-handlers
([exn:fail? fail/enum])
(rec-pat/enum `(nt ,(nt-name nt))
sorted-nts
rec-nt-terms))))
(hash)
sorted-nts)))
(define enum-ith decode)
(struct decomposition (ctx term))
(define (pat-enumerator lang-enums pat)
(enum-names pat
(sep-names pat)
lang-enums))
(define (rec-pat/enum pat nts rec-nt-terms)
(enum-names pat
nts
(sep-names pat)
rec-nt-terms))
;; find-recs : lang -> (hash symbol -o> (assoclist rhs bool))
;; Identifies which non-terminals are recursive
(define (find-recs nt-pats)
(define is-rec?
(case-lambda
[(n) (is-rec? n (hash))]
[(nt seen)
(or (seen? seen (nt-name nt))
(ormap
(λ (rhs)
(let rec ([pat (rhs-pattern rhs)])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(is-rec? (make-nt
id
(lookup nt-pats id))
(add-seen seen
(nt-name nt)))]
[`(name ,name ,pat)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
[`(side-condition ,p ,g ,e) ;; error
(error 'unsupported "side-condition")]
[`(cross ,s)
(error 'unsupported "cross")] ;; error
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[else (rec sub-pat)]))
sub-pats)]
[(? (compose not pair?)) #f])))
(nt-rhs nt)))]))
(define (calls-rec? rhs recs)
(let rec ([pat (rhs-pattern rhs)])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(hash-ref recs id)]
[`(name ,name ,pat)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
[`(side-condition ,p ,g ,e) ;; error
(error 'no-enum "side-condition")]
[`(cross ,s)
(error 'no-enum "cross")] ;; error
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[else (rec sub-pat)]))
sub-pats)]
[(? (compose not pair?)) #f])))
(define (seen? m s)
(hash-ref m s #f))
(define (add-seen m s)
(hash-set m s #t))
(let ([recs
(foldl
(λ (nt m)
(hash-set m (nt-name nt) (is-rec? nt)))
(hash) nt-pats)])
(foldl
(λ (nt m)
(let ([rhs (nt-rhs nt)])
(hash-set m (nt-name nt)
(map (λ (rhs)
(cons rhs (calls-rec? rhs recs)))
rhs))))
(hash)
nt-pats)))
;; sort-nt-terms : lang (hash symbol -o> (assoclist rhs bool)) -> lang
(define (sort-nt-terms nt-pats recs)
(map
(λ (nt)
(let ([rec-nts (hash-ref recs (nt-name nt))])
(make-nt (nt-name nt)
(sort (nt-rhs nt)
(λ (r1 r2)
(and (not (cdr (assoc r1 rec-nts)))
(cdr (assoc r2 rec-nts))))))))
nt-pats))
;; sep-names : single-pattern lang -> (assoclist symbol pattern)
(define (sep-names pat)
(let loop ([pat pat]
[named-pats '()])
(match-a-pattern
pat
[`any named-pats]
[`number named-pats]
[`string named-pats]
[`natural named-pats]
[`integer named-pats]
[`real named-pats]
[`boolean named-pats]
[`variable named-pats]
[`(variable-except ,s ...) named-pats]
[`(variable-prefix ,s) named-pats]
[`variable-not-otherwise-mentioned named-pats]
[`hole named-pats]
;; names inside nts are separate
[`(nt ,id) named-pats]
[`(name ,name ,pat)
(loop pat
(add-if-new name pat named-pats))]
[`(mismatch-name ,name ,pat)
(loop pat
(add-if-new name pat named-pats))]
[`(in-hole ,p1 ,p2)
(loop p2
(loop p1 named-pats))]
[`(hide-hole ,p) (loop p named-pats)]
[`(side-condition ,p ,g ,e) ;; error
(error 'no-enum "side condition")]
[`(cross ,s)
(error 'no-enum "cross")] ;; error
[`(list ,sub-pats ...)
(foldl (λ (sub-pat named-pats)
(match sub-pat
;; unnamed repeat
[`(repeat ,pat #f #f)
(loop pat named-pats)]
;; named repeat
[`(repeat ,pat ,name #f)
(loop pat
(add-if-new name 'name-r named-pats))]
;; mismatch named repeat
[`(repeat ,pat #f ,mismatch)
(loop pat
(add-if-new mismatch 'mismatch-r named-pats))]
;; normal subpattern
[else (loop sub-pat named-pats)]))
named-pats
sub-pats)]
[(? (compose not pair?))
named-pats])))
(define (add-if-new k v l)
(cond [(assoc k l) l]
[else (cons `(,k ,v) l)]))
(define enum-names
(case-lambda
[(pat named-pats nts)
(enum-names-with
(λ (pat named)
(pat/enum-with-names pat nts named))
pat named-pats)]
[(pat nts named-pats rec-nt-terms)
(enum-names-with
(λ (pat named)
(pat/enum-with-names pat nts named rec-nt-terms))
pat named-pats)]))
(define (enum-names-with f pat named-pats)
(let rec ([named-pats named-pats]
[env (hash)])
(cond [(null? named-pats) (f pat env)]
[else
(match
(car named-pats)
;; named repeat
[`(,name name-r)
(error 'unimplemented "named-repeat")]
;; mismatch repeat
[`(,name mismatch-r)
(error 'unimplemented "mismatch-repeat")]
[`(,name ,pat mismatch)
(error 'unimplemented "mismatch")]
;; named
[`(,name ,pat)
(map/enum ;; loses bijection
cdr
(λ (x) (cons name x))
(dep/enum
(f pat env)
(λ (term)
(rec (cdr named-pats)
(hash-set env
name
term)))))]
[else (error 'bad-assoc)])])))
(define pat/enum-with-names
(case-lambda
[(pat nt-enums named-terms)
(let loop ([pat pat])
(match-a-pattern
pat
[`any
(sum/enum
any/enum
(listof/enum any/enum))]
[`number num/enum]
[`string string/enum]
[`natural natural/enum]
[`integer integer/enum]
[`real real/enum]
[`boolean bool/enum]
[`variable var/enum]
[`(variable-except ,s ...)
;; todo
(error 'unimplemented "var-except")]
[`(variable-prefix ,s)
;; todo
(error 'unimplemented "var-prefix")]
[`variable-not-otherwise-mentioned
(error 'unimplemented "var-not-mentioned")] ;; error
[`hole
(const/enum 'hole)]
[`(nt ,id)
(hash-ref nt-enums id)]
[`(name ,name ,pat)
(const/enum (hash-ref named-terms name))]
[`(mismatch-name ,name ,pat)
(error 'unimplemented "mismatch-name")]
[`(in-hole ,p1 ,p2) ;; untested
(map/enum
(λ (t1-t2) ;; loses bijection
(plug-hole (car t1-t2)
(cdr t1-t2)))
(λ (plugged)
(cons 'hole plugged))
(prod/enum
(loop p1)
(loop p2)))]
[`(hide-hole ,p)
(loop p)]
[`(side-condition ,p ,g ,e)
(error 'no-enum "side condition")]
[`(cross ,s)
(error 'no-enum "cross")]
[`(list ,sub-pats ...)
;; enum-list
(map/enum
flatten-1
identity
(list/enum
(map
(λ (sub-pat)
(match sub-pat
[`(repeat ,pat #f #f)
(map/enum
cdr
(λ (ts)
(cons (length ts)
ts))
(dep/enum
nats
(λ (n)
(list/enum
(build-list n (const (loop pat)))))))]
[`(repeat ,pat ,name #f)
(error 'unimplemented "named-repeat")]
[`(repeat ,pat #f ,mismatch)
(error 'unimplemented "mismatch-repeat")]
[else (loop sub-pat)]))
sub-pats)))]
[(? (compose not pair?))
(const/enum pat)]))]
[(pat nts named-terms rec-nt-terms)
(let loop ([pat pat])
(match-a-pattern
pat
[`any
(sum/enum
any/enum
(listof/enum any/enum))]
[`number num/enum]
[`string string/enum]
[`natural natural/enum]
[`integer integer/enum]
[`real real/enum]
[`boolean bool/enum]
[`variable var/enum]
[`(variable-except ,s ...)
;; todo
(error 'unimplemented "var except")]
[`(variable-prefix ,s)
;; todo
(error 'unimplemented "var prefix")]
[`variable-not-otherwise-mentioned
(error 'unimplemented "var not otherwise mentioned")]
[`hole
(const/enum 'hole)]
[`(nt ,id)
(let ([rhss (lookup nts id)])
(apply sum/enum
(map
(λ (rhs)
(cond [(cdr (assoc rhs (hash-ref rec-nt-terms id)))
(thunk/enum
+inf.f
(λ ()
(rec-pat/enum (rhs-pattern rhs)
nts
rec-nt-terms)))]
[else
(rec-pat/enum (rhs-pattern rhs)
nts
rec-nt-terms)]))
rhss)))]
[`(name ,name ,pat)
(const/enum (hash-ref named-terms name))]
[`(mismatch-name ,name ,pat)
(error 'unimplemented "mismatch-name")]
[`(in-hole ,p1 ,p2) ;; untested
(map/enum
(λ (t1-t2)
(decomposition (car t1-t2)
(cdr t1-t2)))
(λ (decomp)
(cons (decomposition-ctx decomp)
(decomposition-term decomp)))
(prod/enum
(loop p1)
(loop p2)))]
[`(hide-hole ,p)
;; todo
(loop p)]
[`(side-condition ,p ,g ,e)
(error 'no-enum "side-condition")]
[`(cross ,s)
(error 'no-enum "cross")]
[`(list ,sub-pats ...)
;; enum-list
(map/enum
flatten-1
identity
(list/enum
(map
(λ (sub-pat)
(match sub-pat
[`(repeat ,pat #f #f)
(map/enum
cdr
(λ (ts)
(cons (length ts)
ts))
(dep/enum
nats
(λ (n)
(list/enum
(build-list n (const (loop pat)))))))]
[`(repeat ,pat ,name #f)
(error 'unimplemented "named-repeat")]
[`(repeat ,pat #f ,mismatch)
(error 'unimplemented "mismatch-repeat")]
[else (loop sub-pat)]))
sub-pats)))]
[(? (compose not pair?))
(const/enum pat)]))]))
(define (flatten-1 xs)
(append-map
(λ (x)
(if (or (pair? x)
(null? x))
x
(list x)))
xs))
;; lookup : lang symbol -> (listof rhs)
(define (lookup nts name)
(let rec ([nts nts])
(cond [(null? nts) (error 'unkown-nt)]
[(eq? name (nt-name (car nts)))
(nt-rhs (car nts))]
[else (rec (cdr nts))])))
(define natural/enum nats)
(define char/enum
(map/enum
integer->char
char->integer
(range/enum #x61 #x7a)))
(define string/enum
(map/enum
list->string
string->list
(listof/enum char/enum)))
(define integer/enum
(sum/enum nats
(map/enum (λ (n) (- (+ n 1)))
(λ (n) (- (- n) 1))
nats)))
(define real/enum (from-list/enum '(0.0 1.5 123.112354)))
(define num/enum
(sum/enum natural/enum
integer/enum
real/enum))
(define bool/enum
(from-list/enum '(#t #f)))
(define var/enum
(map/enum
(compose string->symbol list->string list)
(compose car string->list symbol->string)
char/enum))
(define any/enum
(sum/enum num/enum
string/enum
bool/enum
var/enum))
(define (plug-hole ctx term)
(let loop ([ctx ctx])
(match
ctx
['hole term]
[`(,ts ...)
(map loop ts)]
[x x])))
(module+ test
(require rackunit)
(define rep `(,(make-nt 'r
`(,(make-rhs `(list variable
(repeat variable #f #f)))))))
(define rs (hash-ref (lang-enumerators rep) 'r))
(test-begin
(check-equal? (enum-ith rs 0) '(a))
(check-equal? (size rs) +inf.f))
(define λc `(,(make-nt 'e
`(,(make-rhs `(list (repeat variable #f #f)))
,(make-rhs `(list λ variable (nt e)))
,(make-rhs `(list (nt e) (nt e)))))))
(define les (lang-enumerators λc))
(define es (hash-ref les 'e))
(check-equal? (size es) +inf.f))

View File

@ -0,0 +1,861 @@
#lang racket/base
(require racket/math
racket/list
racket/function)
(provide enum
enum?
size
encode
decode
empty/enum
const/enum
from-list/enum
sum/enum
prod/enum
dep/enum
dep2/enum ;; doesn't require size
map/enum
filter/enum ;; very bad, only use for small enums
except/enum
thunk/enum
listof/enum
list/enum
fail/enum
to-list
take/enum
drop/enum
foldl-enum
display-enum
nats
range/enum
nats+/enum)
;; an enum a is a struct of < Nat or +Inf, Nat -> a, a -> Nat >
(struct enum
(size from to)
#:prefab)
;; size : enum a -> Nat or +Inf
(define (size e)
(enum-size e))
;; decode : enum a, Nat -> a
(define (decode e n)
(if (and (< n (enum-size e))
(>= n 0))
((enum-from e) n)
(error 'out-of-range)))
;; encode : enum a, a -> Nat
(define (encode e a)
((enum-to e) a))
;; Helper functions
;; map/enum : (a -> b), (b -> a), enum a -> enum b
(define (map/enum f inv-f e)
(enum (size e)
(compose f (enum-from e))
(compose (enum-to e) inv-f)))
;; filter/enum : enum a, (a -> bool) -> enum a
;; size won't be accurate!
;; encode is not accurate right now!
(define (filter/enum e p)
(enum (size e)
(λ (n)
(let loop ([i 0]
[seen 0])
(let ([a (decode e i)])
(if (p a)
(if (= seen n)
a
(loop (+ i 1) (+ seen 1)))
(loop (+ i 1) seen)))))
(λ (x) (encode e x))))
;; except/enum : enum a, a -> enum a
(define (except/enum e a)
(unless (> (size e) 0)
(error 'empty-enum))
(let ([m (encode e a)])
(enum (- (size e) 1)
(λ (n)
(if (< n m)
(decode e n)
(decode e (+ n 1))))
(λ (x)
(let ([n (encode e x)])
(cond [(< n m) n]
[(> n m) (- n 1)]
[else (error 'excepted)]))))))
;; to-list : enum a -> listof a
;; better be finite
(define (to-list e)
(when (infinite? (size e))
(error 'too-big))
(map (enum-from e)
(build-list (size e)
identity)))
;; take/enum : enum a, Nat -> enum a
;; returns an enum of the first n parts of e
;; n must be less than or equal to size e
(define (take/enum e n)
(unless (or (infinite? (size e))
(<= n (size e)))
(error 'too-big))
(enum n
(λ (k)
(unless (< k n)
(error 'out-of-range))
(decode e k))
(λ (x)
(let ([k (encode e x)])
(unless (< k n)
(error 'out-of-range))
k))))
;; drop/enum : enum a, Nat -> enum a
(define (drop/enum e n)
(unless (or (infinite? (size e))
(<= n (size e)))
(error 'too-big))
(enum (- (size e) n)
(λ (m)
(decode e (+ n m)))
(λ (x)
(- (encode e x) n))))
;; foldl-enum : enum a, b, (a,b -> b) -> b
;; better be a finite enum
(define (foldl-enum f id e)
(foldl f id (to-list e)))
;; display-enum : enum a, Nat -> void
(define (display-enum e n)
(for ([i (range n)])
(display (decode e i))
(newline) (newline)))
(define empty/enum
(enum 0
(λ (n)
(error 'empty))
(λ (x)
(error 'empty))))
(define (const/enum c)
(enum 1
(λ (n)
c)
(λ (x)
(if (equal? c x)
0
(error 'bad-val)))))
;; from-list/enum :: Listof a -> Gen a
;; input list should not contain duplicates
(define (from-list/enum l)
(if (empty? l)
empty/enum
(enum (length l)
(λ (n)
(list-ref l n))
(λ (x)
(length (take-while l (λ (y)
(not (eq? x y)))))))))
;; take-while : Listof a, (a -> bool) -> Listof a
(define (take-while l pred)
(cond [(empty? l) (error 'empty)]
[(not (pred (car l))) '()]
[else
(cons (car l)
(take-while (cdr l) pred))]))
(define bools
(from-list/enum (list #t #f)))
(define nats
(enum +inf.f
identity
(λ (n)
(unless (>= n 0)
(error 'out-of-range))
n)))
(define ints
(enum +inf.f
(λ (n)
(if (even? n)
(* -1 (/ n 2))
(/ (+ n 1) 2)))
(λ (n)
(if (> n 0)
(- (* 2 n) 1)
(* 2 (abs n))))))
;; sum :: enum a, enum b -> enum (a or b)
(define sum/enum
(case-lambda
[(e) e]
[(e1 e2)
(cond
[(= 0 (size e1)) e2]
[(= 0 (size e2)) e1]
[(not (infinite? (enum-size e1)))
(enum (+ (enum-size e1)
(enum-size e2))
(λ (n)
(if (< n (enum-size e1))
((enum-from e1) n)
((enum-from e2) (- n (enum-size e1)))))
(λ (x)
(with-handlers ([exn:fail? (λ (_)
(+ (enum-size e1)
((enum-to e2) x)))])
((enum-to e1) x))))]
[(not (infinite? (enum-size e2)))
(sum/enum e2 e1)]
[else ;; both infinite, interleave them
(enum +inf.f
(λ (n)
(if (even? n)
((enum-from e1) (/ n 2))
((enum-from e2) (/ (- n 1) 2))))
(λ (x)
(with-handlers ([exn:fail?
(λ (_)
(+ (* ((enum-to e2) x) 2)
1))])
(* ((enum-to e1) x) 2))))])]
[(a b c . rest)
(sum/enum a (apply sum/enum b c rest))]))
(define odds
(enum +inf.f
(λ (n)
(+ (* 2 n) 1))
(λ (n)
(if (and (not (zero? (modulo n 2)))
(>= n 0))
(/ (- n 1) 2)
(error 'odd)))))
(define evens
(enum +inf.f
(λ (n)
(* 2 n))
(λ (n)
(if (and (zero? (modulo n 2))
(>= n 0))
(/ n 2)
(error 'even)))))
(define n*n
(enum +inf.f
(λ (n)
;; calculate the k for which (tri k) is the greatest
;; triangle number <= n
(let* ([k (floor-untri n)]
[t (tri k)]
[l (- n t)]
[m (- k l)])
(cons l m)))
(λ (ns)
(unless (pair? ns)
(error "not a list"))
(let ([l (car ns)]
[m (cdr ns)])
(+ (/ (* (+ l m) (+ l m 1))
2)
l))) ;; (n,m) -> (n+m)(n+m+1)/2 + n
))
;; prod/enum : enum a, enum b -> enum (a,b)
(define prod/enum
(case-lambda
[(e) e]
[(e1 e2)
(cond [(or (= 0 (size e1))
(= 0 (size e2))) empty/enum]
[(not (infinite? (enum-size e1)))
(cond [(not (infinite? (enum-size e2)))
(let [(size (* (enum-size e1)
(enum-size e2)))]
(enum size
(λ (n) ;; bijection from n -> axb
(if (> n size)
(error "out of range")
(call-with-values
(λ ()
(quotient/remainder n (enum-size e2)))
(λ (q r)
(cons ((enum-from e1) q)
((enum-from e2) r))))))
(λ (xs)
(unless (pair? xs)
(error "not a pair"))
(+ (* (enum-size e1)
((enum-to e1) (car xs)))
((enum-to e2) (cdr xs))))))]
[else
(enum +inf.f
(λ (n)
(call-with-values
(λ ()
(quotient/remainder n (enum-size e1)))
(λ (q r)
(cons ((enum-from e1) r)
((enum-from e2) q)))))
(λ (xs)
(unless (pair? xs)
(error "not a pair"))
(+ ((enum-to e1) (car xs))
(* (enum-size e1)
((enum-to e2) (cdr xs))))))])]
[(not (infinite? (enum-size e2)))
(enum +inf.f
(λ (n)
(call-with-values
(λ ()
(quotient/remainder n (enum-size e2)))
(λ (q r)
(cons ((enum-from e1) q)
((enum-from e2) r)))))
(λ (xs)
(unless (pair? xs)
(error "not a pair"))
(+ (* (enum-size e2)
((enum-to e1) (car xs)))
((enum-to e2) (cdr xs)))))]
[else
(enum (* (enum-size e1)
(enum-size e2))
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
[l (- n t)]
[m (- k l)])
(cons ((enum-from e1) l)
((enum-from e2) m))))
(λ (xs) ;; bijection from nxn -> n, inverse of previous
;; (n,m) -> (n+m)(n+m+1)/2 + n
(unless (pair? xs)
(error "not a pair"))
(let ([l ((enum-to e1) (car xs))]
[m ((enum-to e2) (cdr xs))])
(+ (/ (* (+ l m) (+ l m 1))
2)
l))))])]
[(a b c . rest)
(prod/enum a (apply prod/enum b c rest))]))
;; the nth triangle number
(define (tri n)
(/ (* n (+ n 1))
2))
;; the floor of the inverse of tri
;; returns the largest triangle number less than k
;; always returns an integer
(define (floor-untri k)
(let ([n (integer-sqrt (+ 1 (* 8 k)))])
(/ (- n
(if (even? n)
2
1))
2)))
;; dep/enum : enum a (a -> enum b) -> enum (a,b)
(define (dep/enum e f)
(cond [(= 0 (size e)) empty/enum]
[(not (infinite? (size (f (decode e 0)))))
(enum (if (infinite? (size e))
+inf.f
(foldl + 0 (map (compose size f) (to-list e))))
(λ (n) ;; n -> axb
(let loop ([ei 0]
[seen 0])
(let* ([a (decode e ei)]
[e2 (f a)])
(if (< (- n seen)
(size e2))
(cons a (decode e2 (- n seen)))
(loop (+ ei 1)
(+ seen (size e2)))))))
(λ (ab) ;; axb -> n
(let ([ai (encode e (car ab))])
(+ (let loop ([i 0]
[sum 0])
(if (>= i ai)
sum
(loop (+ i 1)
(+ sum
(size (f (decode e i)))))))
(encode (f (car ab))
(cdr ab))))))]
[(not (infinite? (size e)))
(enum +inf.f
(λ (n)
(call-with-values
(λ ()
(quotient/remainder n (size e)))
(λ (q r)
(cons (decode e r)
(decode (f (decode e r)) q)))))
(λ (ab)
(+ (* (size e) (encode (f (car ab)) (cdr ab)))
(encode e (car ab)))))]
[else ;; both infinite, same as prod/enum
(enum +inf.f
(λ (n)
(let* ([k (floor-untri n)]
[t (tri k)]
[l (- n t)]
[m (- k l)]
[a (decode e l)])
(cons a
(decode (f a) m))))
(λ (xs) ;; bijection from nxn -> n, inverse of previous
;; (n,m) -> (n+m)(n+m+1)/2 + n
(unless (pair? xs)
(error "not a pair"))
(let ([l (encode e (car xs))]
[m (encode (f (car xs)) (cdr xs))])
(+ (/ (* (+ l m) (+ l m 1))
2)
l))))]))
;; dep2 : enum a (a -> enum b) -> enum (a,b)
(define (dep2/enum e f)
(cond [(= 0 (size e)) empty/enum]
[(not (infinite? (size (f (decode e 0)))))
;; memoize tab : boxof (hash nat -o> (nat . nat))
;; maps an index into the dep/enum to the 2 indices that we need
(let ([tab (box (hash))])
(enum (if (infinite? (size e))
+inf.f
(foldl + 0 (map (compose size f) (to-list e))))
(λ (n) ;; n -> axb
(call-with-values
(λ ()
(letrec
;; go : nat -> nat nat
([go
(λ (n)
(cond [(hash-has-key? (unbox tab) n)
(let ([ij (hash-ref (unbox tab) n)])
(values (car ij) (cdr ij)))]
[(= n 0) ;; find the first element
(find 0 0 0)]
[else ;; recur
(call-with-values
(λ () (go (- n 1)))
(λ (ai bi)
(find ai (- n bi 1) n)))]))]
;; find : nat nat nat -> nat
[find
;; start is our starting eindex
;; seen is how many we've already seen
(λ (start seen n)
(let loop ([ai start]
[seen seen])
(let* ([a (decode e ai)]
[bs (f a)])
(cond [(< (- n seen)
(size bs))
(let ([bi (- n seen)])
(begin
(set-box! tab
(hash-set (unbox tab)
n
(cons ai bi)))
(values ai bi)))]
[else
(loop (+ ai 1)
(+ seen (size bs)))]))))])
(go n)))
(λ (ai bi)
(let ([a (decode e ai)])
(cons a
(decode (f a) bi))))))
;; todo: memoize encode
(λ (ab) ;; axb -> n
(let ([ai (encode e (car ab))])
(+ (let loop ([i 0]
[sum 0])
(if (>= i ai)
sum
(loop (+ i 1)
(+ sum
(size (f (decode e i)))))))
(encode (f (car ab))
(cdr ab)))))))]
[else ;; both infinite, same as prod/enum
(dep/enum e f)]))
;; more utility enums
;; nats of course
(define (range/enum low high)
(cond [(> low high) (error 'bad-range)]
[(infinite? high)
(if (infinite? low)
ints
(map/enum
(λ (n)
(+ n low))
(λ (n)
(- n low))
nats))]
[(infinite? low)
(map/enum
(λ (n)
(- high n))
(λ (n)
(+ high n))
nats)]
[else
(map/enum (λ (n) (+ n low))
(λ (n) (- n low))
(take/enum nats (+ 1 (- high low))))]))
;; thunk/enum : Nat or +-Inf, ( -> enum a) -> enum a
(define (thunk/enum s thunk)
(enum s
(λ (n)
(decode (thunk) n))
(λ (x)
(encode (thunk) x))))
;; listof/enum : enum a -> enum (listof a)
(define (listof/enum e)
(thunk/enum
(if (= 0 (size e))
0
+inf.f)
(λ ()
(sum/enum
(const/enum '())
(prod/enum e (listof/enum e))))))
;; list/enum : listof (enum any) -> enum (listof any)
(define (list/enum es)
(apply prod/enum (append es `(,(const/enum '())))))
(define (nats+/enum n)
(map/enum (λ (k)
(+ k n))
(λ (k)
(- k n))
nats))
;; fail/enum : exn -> enum ()
;; returns an enum that calls a thunk
(define (fail/enum e)
(let ([t
(λ (_)
(raise e))])
(enum 1
t
t)))
(module+
test
(require rackunit)
(provide check-bijection?)
(define confidence 1000)
(define nums (build-list confidence identity))
(define-simple-check (check-bijection? e)
(let ([nums (build-list (if (<= (enum-size e) confidence)
(enum-size e)
confidence)
identity)])
(andmap =
nums
(map (λ (n)
(encode e (decode e n)))
nums))))
;; const/enum tests
(let [(e (const/enum 17))]
(test-begin
(check-eq? (decode e 0) 17)
(check-exn exn:fail?
(λ ()
(decode e 1)))
(check-eq? (encode e 17) 0)
(check-exn exn:fail?
(λ ()
(encode e 0)))
(check-bijection? e)))
;; from-list/enum tests
(let [(e (from-list/enum '(5 4 1 8)))]
(test-begin
(check-eq? (decode e 0) 5)
(check-eq? (decode e 3) 8)
(check-exn exn:fail?
(λ () (decode e 4)))
(check-eq? (encode e 5) 0)
(check-eq? (encode e 8) 3)
(check-exn exn:fail?
(λ ()
(encode e 17)))
(check-bijection? e)))
;; map test
(define nats+1 (nats+/enum 1))
(test-begin
(check-equal? (size nats+1) +inf.f)
(check-equal? (decode nats+1 0) 1)
(check-equal? (decode nats+1 1) 2)
(check-bijection? nats+1))
;; encode check
(test-begin
(check-exn exn:fail?
(λ ()
(decode nats -1))))
;; ints checks
(test-begin
(check-eq? (decode ints 0) 0) ; 0 -> 0
(check-eq? (decode ints 1) 1) ; 1 -> 1
(check-eq? (decode ints 2) -1) ; 2 -> 1
(check-eq? (encode ints 0) 0)
(check-eq? (encode ints 1) 1)
(check-eq? (encode ints -1) 2)
(check-bijection? ints)) ; -1 -> 2, -3 -> 4
;; sum tests
(test-begin
(let [(bool-or-num (sum/enum bools
(from-list/enum '(0 1 2))))
(bool-or-nat (sum/enum bools
nats))
(nat-or-bool (sum/enum nats
bools))
(odd-or-even (sum/enum evens
odds))]
(check-equal? (enum-size bool-or-num)
5)
(check-equal? (decode bool-or-num 0) #t)
(check-equal? (decode bool-or-num 1) #f)
(check-equal? (decode bool-or-num 2) 0)
(check-exn exn:fail?
(λ ()
(decode bool-or-num 5)))
(check-equal? (encode bool-or-num #f) 1)
(check-equal? (encode bool-or-num 2) 4)
(check-bijection? bool-or-num)
(check-equal? (enum-size bool-or-nat)
+inf.f)
(check-equal? (decode bool-or-nat 0) #t)
(check-equal? (decode bool-or-nat 2) 0)
(check-bijection? bool-or-nat)
(check-equal? (encode bool-or-num #f) 1)
(check-equal? (encode bool-or-num 2) 4)
(check-equal? (enum-size odd-or-even)
+inf.f)
(check-equal? (decode odd-or-even 0) 0)
(check-equal? (decode odd-or-even 1) 1)
(check-equal? (decode odd-or-even 2) 2)
(check-exn exn:fail?
(λ ()
(decode odd-or-even -1)))
(check-equal? (encode odd-or-even 0) 0)
(check-equal? (encode odd-or-even 1) 1)
(check-equal? (encode odd-or-even 2) 2)
(check-equal? (encode odd-or-even 3) 3)
(check-bijection? odd-or-even)))
;; prod/enum tests
(define bool*bool (prod/enum bools bools))
(define 1*b (prod/enum (const/enum 1) bools))
(define bool*nats (prod/enum bools nats))
(define nats*bool (prod/enum nats bools))
(define nats*nats (prod/enum nats nats))
(define ns-equal? (λ (ns ms)
(and (= (car ns)
(car ms))
(= (cdr ns)
(cdr ms)))))
;; prod tests
(test-begin
(check-equal? (size 1*b) 2)
(check-equal? (decode 1*b 0) (cons 1 #t))
(check-equal? (decode 1*b 1) (cons 1 #f))
(check-bijection? 1*b)
(check-equal? (enum-size bool*bool) 4)
(check-equal? (decode bool*bool 0)
(cons #t #t))
(check-equal? (decode bool*bool 1)
(cons #t #f))
(check-equal? (decode bool*bool 2)
(cons #f #t))
(check-equal? (decode bool*bool 3)
(cons #f #f))
(check-bijection? bool*bool)
(check-equal? (enum-size bool*nats) +inf.f)
(check-equal? (decode bool*nats 0)
(cons #t 0))
(check-equal? (decode bool*nats 1)
(cons #f 0))
(check-equal? (decode bool*nats 2)
(cons #t 1))
(check-equal? (decode bool*nats 3)
(cons #f 1))
(check-bijection? bool*nats)
(check-equal? (enum-size nats*bool) +inf.f)
(check-equal? (decode nats*bool 0)
(cons 0 #t))
(check-equal? (decode nats*bool 1)
(cons 0 #f))
(check-equal? (decode nats*bool 2)
(cons 1 #t))
(check-equal? (decode nats*bool 3)
(cons 1 #f))
(check-bijection? nats*bool)
(check-equal? (enum-size nats*nats) +inf.f)
(check ns-equal?
(decode nats*nats 0)
(cons 0 0))
(check ns-equal?
(decode nats*nats 1)
(cons 0 1))
(check ns-equal?
(decode nats*nats 2)
(cons 1 0))
(check ns-equal?
(decode nats*nats 3)
(cons 0 2))
(check ns-equal?
(decode nats*nats 4)
(cons 1 1))
(check-bijection? nats*nats))
;; dep/enum tests
(define (up-to n)
(take/enum nats (+ n 1)))
(define 3-up
(dep/enum
(from-list/enum '(0 1 2))
up-to))
(define from-3
(dep/enum
(from-list/enum '(0 1 2))
nats+/enum))
(define nats-to
(dep/enum nats up-to))
(define nats-up
(dep/enum nats nats+/enum))
(test-begin
(check-equal? (size 3-up) 6)
(check-equal? (decode 3-up 0) (cons 0 0))
(check-equal? (decode 3-up 1) (cons 1 0))
(check-equal? (decode 3-up 2) (cons 1 1))
(check-equal? (decode 3-up 3) (cons 2 0))
(check-equal? (decode 3-up 4) (cons 2 1))
(check-equal? (decode 3-up 5) (cons 2 2))
(check-bijection? 3-up)
(check-equal? (size from-3) +inf.f)
(check-equal? (decode from-3 0) (cons 0 0))
(check-equal? (decode from-3 1) (cons 1 1))
(check-equal? (decode from-3 2) (cons 2 2))
(check-equal? (decode from-3 3) (cons 0 1))
(check-equal? (decode from-3 4) (cons 1 2))
(check-equal? (decode from-3 5) (cons 2 3))
(check-equal? (decode from-3 6) (cons 0 2))
(check-bijection? from-3)
(check-equal? (size nats-to) +inf.f)
(check-equal? (decode nats-to 0) (cons 0 0))
(check-equal? (decode nats-to 1) (cons 1 0))
(check-equal? (decode nats-to 2) (cons 1 1))
(check-equal? (decode nats-to 3) (cons 2 0))
(check-equal? (decode nats-to 4) (cons 2 1))
(check-equal? (decode nats-to 5) (cons 2 2))
(check-equal? (decode nats-to 6) (cons 3 0))
(check-bijection? nats-to)
(check-equal? (size nats-up) +inf.f)
(check-equal? (decode nats-up 0) (cons 0 0))
(check-equal? (decode nats-up 1) (cons 0 1))
(check-equal? (decode nats-up 2) (cons 1 1))
(check-equal? (decode nats-up 3) (cons 0 2))
(check-equal? (decode nats-up 4) (cons 1 2))
(check-equal? (decode nats-up 5) (cons 2 2))
(check-equal? (decode nats-up 6) (cons 0 3))
(check-equal? (decode nats-up 7) (cons 1 3))
(check-bijection? nats-up))
;; dep2/enum tests
;; same as dep unless the right side is finite
(define 3-up-2
(dep2/enum
(from-list/enum '(0 1 2))
up-to))
(define nats-to-2
(dep2/enum nats up-to))
(test-begin
(check-equal? (size 3-up-2) 6)
(check-equal? (decode 3-up-2 0) (cons 0 0))
(check-equal? (decode 3-up-2 1) (cons 1 0))
(check-equal? (decode 3-up-2 2) (cons 1 1))
(check-equal? (decode 3-up-2 3) (cons 2 0))
(check-equal? (decode 3-up-2 4) (cons 2 1))
(check-equal? (decode 3-up-2 5) (cons 2 2))
(check-bijection? 3-up-2)
(check-equal? (size nats-to-2) +inf.f)
(check-equal? (decode nats-to-2 0) (cons 0 0))
(check-equal? (decode nats-to-2 1) (cons 1 0))
(check-equal? (decode nats-to-2 2) (cons 1 1))
(check-equal? (decode nats-to-2 3) (cons 2 0))
(check-equal? (decode nats-to-2 4) (cons 2 1))
(check-equal? (decode nats-to-2 5) (cons 2 2))
(check-equal? (decode nats-to-2 6) (cons 3 0))
(check-bijection? nats-to-2)
)
;; take/enum test
(define to-2 (up-to 2))
(test-begin
(check-equal? (size to-2) 3)
(check-equal? (decode to-2 0) 0)
(check-equal? (decode to-2 1) 1)
(check-equal? (decode to-2 2) 2)
(check-bijection? to-2))
;; to-list, foldl test
(test-begin
(check-equal? (to-list (up-to 3))
'(0 1 2 3))
(check-equal? (foldl-enum cons '() (up-to 3))
'(3 2 1 0))))