From b8538ec135b0b8cc4aa4818169c1e8f04872ed3a Mon Sep 17 00:00:00 2001 From: Max New Date: Thu, 9 May 2013 15:02:03 -0500 Subject: [PATCH] Added Redex enumerators. Supports names and recursive patterns. Limited support for repeats and mismatches. --- collects/redex/private/enum.rkt | 522 +++++++++++++++- collects/redex/private/enumerator.rkt | 861 ++++++++++++++++++++++++++ 2 files changed, 1375 insertions(+), 8 deletions(-) create mode 100644 collects/redex/private/enumerator.rkt diff --git a/collects/redex/private/enum.rkt b/collects/redex/private/enum.rkt index 709b15cdc6..eef1fc58da 100644 --- a/collects/redex/private/enum.rkt +++ b/collects/redex/private/enum.rkt @@ -1,6 +1,11 @@ #lang racket/base (require racket/contract - "lang-struct.rkt") + racket/list + racket/match + racket/function + "lang-struct.rkt" + "match-a-pattern.rkt" + "enumerator.rkt") (provide (contract-out @@ -11,11 +16,512 @@ [enum-ith (-> enum? exact-nonnegative-integer? any/c)] [enum? (-> any/c boolean?)])) -(struct enum ()) -(define (lang-enumerators nts) (make-hash)) -(define (pat-enumerator li p) - (unless (equal? p '(name natural natural)) - (error 'enum.rkt "not yet implemented")) - (enum)) -(define (enum-ith e i) i) +(define (lang-enumerators nts) + (let* ([l-enums (make-hash)] + [rec-nt-terms (find-recs nts)] + [sorted-nts (sort-nt-terms nts rec-nt-terms)]) + (foldl + (λ (nt m) + (hash-set + m + (nt-name nt) + (with-handlers + ([exn:fail? fail/enum]) + (rec-pat/enum `(nt ,(nt-name nt)) + sorted-nts + rec-nt-terms)))) + (hash) + sorted-nts))) + +(define enum-ith decode) +(struct decomposition (ctx term)) + +(define (pat-enumerator lang-enums pat) + (enum-names pat + (sep-names pat) + lang-enums)) + +(define (rec-pat/enum pat nts rec-nt-terms) + (enum-names pat + nts + (sep-names pat) + rec-nt-terms)) + + +;; find-recs : lang -> (hash symbol -o> (assoclist rhs bool)) +;; Identifies which non-terminals are recursive +(define (find-recs nt-pats) + (define is-rec? + (case-lambda + [(n) (is-rec? n (hash))] + [(nt seen) + (or (seen? seen (nt-name nt)) + (ormap + (λ (rhs) + (let rec ([pat (rhs-pattern rhs)]) + (match-a-pattern + pat + [`any #f] + [`number #f] + [`string #f] + [`natural #f] + [`integer #f] + [`real #f] + [`boolean #f] + [`variable #f] + [`(variable-except ,s ...) #f] + [`(variable-prefix ,s) #f] + [`variable-not-otherwise-mentioned #f] + [`hole #f] + [`(nt ,id) + (is-rec? (make-nt + id + (lookup nt-pats id)) + (add-seen seen + (nt-name nt)))] + [`(name ,name ,pat) + (rec pat)] + [`(mismatch-name ,name ,pat) + (rec pat)] + [`(in-hole ,p1 ,p2) + (or (rec p1) + (rec p2))] + [`(hide-hole ,p) (rec p)] + [`(side-condition ,p ,g ,e) ;; error + (error 'unsupported "side-condition")] + [`(cross ,s) + (error 'unsupported "cross")] ;; error + [`(list ,sub-pats ...) + (ormap (λ (sub-pat) + (match sub-pat + [`(repeat ,pat ,name ,mismatch) + (rec pat)] + [else (rec sub-pat)])) + sub-pats)] + [(? (compose not pair?)) #f]))) + (nt-rhs nt)))])) + (define (calls-rec? rhs recs) + (let rec ([pat (rhs-pattern rhs)]) + (match-a-pattern + pat + [`any #f] + [`number #f] + [`string #f] + [`natural #f] + [`integer #f] + [`real #f] + [`boolean #f] + [`variable #f] + [`(variable-except ,s ...) #f] + [`(variable-prefix ,s) #f] + [`variable-not-otherwise-mentioned #f] + [`hole #f] + [`(nt ,id) + (hash-ref recs id)] + [`(name ,name ,pat) + (rec pat)] + [`(mismatch-name ,name ,pat) + (rec pat)] + [`(in-hole ,p1 ,p2) + (or (rec p1) + (rec p2))] + [`(hide-hole ,p) (rec p)] + [`(side-condition ,p ,g ,e) ;; error + (error 'no-enum "side-condition")] + [`(cross ,s) + (error 'no-enum "cross")] ;; error + [`(list ,sub-pats ...) + (ormap (λ (sub-pat) + (match sub-pat + [`(repeat ,pat ,name ,mismatch) + (rec pat)] + [else (rec sub-pat)])) + sub-pats)] + [(? (compose not pair?)) #f]))) + (define (seen? m s) + (hash-ref m s #f)) + (define (add-seen m s) + (hash-set m s #t)) + (let ([recs + (foldl + (λ (nt m) + (hash-set m (nt-name nt) (is-rec? nt))) + (hash) nt-pats)]) + (foldl + (λ (nt m) + (let ([rhs (nt-rhs nt)]) + (hash-set m (nt-name nt) + (map (λ (rhs) + (cons rhs (calls-rec? rhs recs))) + rhs)))) + (hash) + nt-pats))) + +;; sort-nt-terms : lang (hash symbol -o> (assoclist rhs bool)) -> lang +(define (sort-nt-terms nt-pats recs) + (map + (λ (nt) + (let ([rec-nts (hash-ref recs (nt-name nt))]) + (make-nt (nt-name nt) + (sort (nt-rhs nt) + (λ (r1 r2) + (and (not (cdr (assoc r1 rec-nts))) + (cdr (assoc r2 rec-nts)))))))) + nt-pats)) + +;; sep-names : single-pattern lang -> (assoclist symbol pattern) +(define (sep-names pat) + (let loop ([pat pat] + [named-pats '()]) + (match-a-pattern + pat + [`any named-pats] + [`number named-pats] + [`string named-pats] + [`natural named-pats] + [`integer named-pats] + [`real named-pats] + [`boolean named-pats] + [`variable named-pats] + [`(variable-except ,s ...) named-pats] + [`(variable-prefix ,s) named-pats] + [`variable-not-otherwise-mentioned named-pats] + [`hole named-pats] + ;; names inside nts are separate + [`(nt ,id) named-pats] + [`(name ,name ,pat) + (loop pat + (add-if-new name pat named-pats))] + [`(mismatch-name ,name ,pat) + (loop pat + (add-if-new name pat named-pats))] + [`(in-hole ,p1 ,p2) + (loop p2 + (loop p1 named-pats))] + [`(hide-hole ,p) (loop p named-pats)] + [`(side-condition ,p ,g ,e) ;; error + (error 'no-enum "side condition")] + [`(cross ,s) + (error 'no-enum "cross")] ;; error + [`(list ,sub-pats ...) + (foldl (λ (sub-pat named-pats) + (match sub-pat + ;; unnamed repeat + [`(repeat ,pat #f #f) + (loop pat named-pats)] + ;; named repeat + [`(repeat ,pat ,name #f) + (loop pat + (add-if-new name 'name-r named-pats))] + ;; mismatch named repeat + [`(repeat ,pat #f ,mismatch) + (loop pat + (add-if-new mismatch 'mismatch-r named-pats))] + ;; normal subpattern + [else (loop sub-pat named-pats)])) + named-pats + sub-pats)] + [(? (compose not pair?)) + named-pats]))) + +(define (add-if-new k v l) + (cond [(assoc k l) l] + [else (cons `(,k ,v) l)])) + +(define enum-names + (case-lambda + [(pat named-pats nts) + (enum-names-with + (λ (pat named) + (pat/enum-with-names pat nts named)) + pat named-pats)] + [(pat nts named-pats rec-nt-terms) + (enum-names-with + (λ (pat named) + (pat/enum-with-names pat nts named rec-nt-terms)) + pat named-pats)])) + +(define (enum-names-with f pat named-pats) + (let rec ([named-pats named-pats] + [env (hash)]) + (cond [(null? named-pats) (f pat env)] + [else + (match + (car named-pats) + ;; named repeat + [`(,name name-r) + (error 'unimplemented "named-repeat")] + ;; mismatch repeat + [`(,name mismatch-r) + (error 'unimplemented "mismatch-repeat")] + [`(,name ,pat mismatch) + (error 'unimplemented "mismatch")] + ;; named + [`(,name ,pat) + (map/enum ;; loses bijection + cdr + (λ (x) (cons name x)) + (dep/enum + (f pat env) + (λ (term) + (rec (cdr named-pats) + (hash-set env + name + term)))))] + [else (error 'bad-assoc)])]))) + +(define pat/enum-with-names + (case-lambda + [(pat nt-enums named-terms) + (let loop ([pat pat]) + (match-a-pattern + pat + [`any + (sum/enum + any/enum + (listof/enum any/enum))] + [`number num/enum] + [`string string/enum] + [`natural natural/enum] + [`integer integer/enum] + [`real real/enum] + [`boolean bool/enum] + [`variable var/enum] + [`(variable-except ,s ...) + ;; todo + (error 'unimplemented "var-except")] + [`(variable-prefix ,s) + ;; todo + (error 'unimplemented "var-prefix")] + [`variable-not-otherwise-mentioned + (error 'unimplemented "var-not-mentioned")] ;; error + [`hole + (const/enum 'hole)] + [`(nt ,id) + (hash-ref nt-enums id)] + [`(name ,name ,pat) + (const/enum (hash-ref named-terms name))] + [`(mismatch-name ,name ,pat) + (error 'unimplemented "mismatch-name")] + [`(in-hole ,p1 ,p2) ;; untested + (map/enum + (λ (t1-t2) ;; loses bijection + (plug-hole (car t1-t2) + (cdr t1-t2))) + (λ (plugged) + (cons 'hole plugged)) + (prod/enum + (loop p1) + (loop p2)))] + [`(hide-hole ,p) + (loop p)] + [`(side-condition ,p ,g ,e) + (error 'no-enum "side condition")] + [`(cross ,s) + (error 'no-enum "cross")] + [`(list ,sub-pats ...) + ;; enum-list + (map/enum + flatten-1 + identity + (list/enum + (map + (λ (sub-pat) + (match sub-pat + [`(repeat ,pat #f #f) + (map/enum + cdr + (λ (ts) + (cons (length ts) + ts)) + (dep/enum + nats + (λ (n) + (list/enum + (build-list n (const (loop pat)))))))] + [`(repeat ,pat ,name #f) + (error 'unimplemented "named-repeat")] + [`(repeat ,pat #f ,mismatch) + (error 'unimplemented "mismatch-repeat")] + [else (loop sub-pat)])) + sub-pats)))] + [(? (compose not pair?)) + (const/enum pat)]))] + [(pat nts named-terms rec-nt-terms) + (let loop ([pat pat]) + (match-a-pattern + pat + [`any + (sum/enum + any/enum + (listof/enum any/enum))] + [`number num/enum] + [`string string/enum] + [`natural natural/enum] + [`integer integer/enum] + [`real real/enum] + [`boolean bool/enum] + [`variable var/enum] + [`(variable-except ,s ...) + ;; todo + (error 'unimplemented "var except")] + [`(variable-prefix ,s) + ;; todo + (error 'unimplemented "var prefix")] + [`variable-not-otherwise-mentioned + (error 'unimplemented "var not otherwise mentioned")] + [`hole + (const/enum 'hole)] + [`(nt ,id) + (let ([rhss (lookup nts id)]) + (apply sum/enum + (map + (λ (rhs) + (cond [(cdr (assoc rhs (hash-ref rec-nt-terms id))) + (thunk/enum + +inf.f + (λ () + (rec-pat/enum (rhs-pattern rhs) + nts + rec-nt-terms)))] + [else + (rec-pat/enum (rhs-pattern rhs) + nts + rec-nt-terms)])) + rhss)))] + [`(name ,name ,pat) + (const/enum (hash-ref named-terms name))] + [`(mismatch-name ,name ,pat) + (error 'unimplemented "mismatch-name")] + [`(in-hole ,p1 ,p2) ;; untested + (map/enum + (λ (t1-t2) + (decomposition (car t1-t2) + (cdr t1-t2))) + (λ (decomp) + (cons (decomposition-ctx decomp) + (decomposition-term decomp))) + (prod/enum + (loop p1) + (loop p2)))] + [`(hide-hole ,p) + ;; todo + (loop p)] + [`(side-condition ,p ,g ,e) + (error 'no-enum "side-condition")] + [`(cross ,s) + (error 'no-enum "cross")] + [`(list ,sub-pats ...) + ;; enum-list + (map/enum + flatten-1 + identity + (list/enum + (map + (λ (sub-pat) + (match sub-pat + [`(repeat ,pat #f #f) + (map/enum + cdr + (λ (ts) + (cons (length ts) + ts)) + (dep/enum + nats + (λ (n) + (list/enum + (build-list n (const (loop pat)))))))] + [`(repeat ,pat ,name #f) + (error 'unimplemented "named-repeat")] + [`(repeat ,pat #f ,mismatch) + (error 'unimplemented "mismatch-repeat")] + [else (loop sub-pat)])) + sub-pats)))] + [(? (compose not pair?)) + (const/enum pat)]))])) + +(define (flatten-1 xs) + (append-map + (λ (x) + (if (or (pair? x) + (null? x)) + x + (list x))) + xs)) + +;; lookup : lang symbol -> (listof rhs) +(define (lookup nts name) + (let rec ([nts nts]) + (cond [(null? nts) (error 'unkown-nt)] + [(eq? name (nt-name (car nts))) + (nt-rhs (car nts))] + [else (rec (cdr nts))]))) + +(define natural/enum nats) + +(define char/enum + (map/enum + integer->char + char->integer + (range/enum #x61 #x7a))) + +(define string/enum + (map/enum + list->string + string->list + (listof/enum char/enum))) + +(define integer/enum + (sum/enum nats + (map/enum (λ (n) (- (+ n 1))) + (λ (n) (- (- n) 1)) + nats))) + +(define real/enum (from-list/enum '(0.0 1.5 123.112354))) +(define num/enum + (sum/enum natural/enum + integer/enum + real/enum)) + +(define bool/enum + (from-list/enum '(#t #f))) + +(define var/enum + (map/enum + (compose string->symbol list->string list) + (compose car string->list symbol->string) + char/enum)) + +(define any/enum + (sum/enum num/enum + string/enum + bool/enum + var/enum)) + +(define (plug-hole ctx term) + (let loop ([ctx ctx]) + (match + ctx + ['hole term] + [`(,ts ...) + (map loop ts)] + [x x]))) + +(module+ test + (require rackunit) + + (define rep `(,(make-nt 'r + `(,(make-rhs `(list variable + (repeat variable #f #f))))))) + (define rs (hash-ref (lang-enumerators rep) 'r)) + (test-begin + (check-equal? (enum-ith rs 0) '(a)) + (check-equal? (size rs) +inf.f)) + (define λc `(,(make-nt 'e + `(,(make-rhs `(list (repeat variable #f #f))) + ,(make-rhs `(list λ variable (nt e))) + ,(make-rhs `(list (nt e) (nt e))))))) + (define les (lang-enumerators λc)) + (define es (hash-ref les 'e)) + (check-equal? (size es) +inf.f)) diff --git a/collects/redex/private/enumerator.rkt b/collects/redex/private/enumerator.rkt new file mode 100644 index 0000000000..38fed5534b --- /dev/null +++ b/collects/redex/private/enumerator.rkt @@ -0,0 +1,861 @@ +#lang racket/base +(require racket/math + racket/list + racket/function) + +(provide enum + enum? + size + encode + decode + empty/enum + const/enum + from-list/enum + sum/enum + prod/enum + dep/enum + dep2/enum ;; doesn't require size + map/enum + filter/enum ;; very bad, only use for small enums + except/enum + thunk/enum + listof/enum + list/enum + fail/enum + + to-list + take/enum + drop/enum + foldl-enum + display-enum + + nats + range/enum + nats+/enum) + +;; an enum a is a struct of < Nat or +Inf, Nat -> a, a -> Nat > +(struct enum + (size from to) + #:prefab) + +;; size : enum a -> Nat or +Inf +(define (size e) + (enum-size e)) + +;; decode : enum a, Nat -> a +(define (decode e n) + (if (and (< n (enum-size e)) + (>= n 0)) + ((enum-from e) n) + (error 'out-of-range))) + +;; encode : enum a, a -> Nat +(define (encode e a) + ((enum-to e) a)) + +;; Helper functions +;; map/enum : (a -> b), (b -> a), enum a -> enum b +(define (map/enum f inv-f e) + (enum (size e) + (compose f (enum-from e)) + (compose (enum-to e) inv-f))) + +;; filter/enum : enum a, (a -> bool) -> enum a +;; size won't be accurate! +;; encode is not accurate right now! +(define (filter/enum e p) + (enum (size e) + (λ (n) + (let loop ([i 0] + [seen 0]) + (let ([a (decode e i)]) + (if (p a) + (if (= seen n) + a + (loop (+ i 1) (+ seen 1))) + (loop (+ i 1) seen))))) + (λ (x) (encode e x)))) + +;; except/enum : enum a, a -> enum a +(define (except/enum e a) + (unless (> (size e) 0) + (error 'empty-enum)) + (let ([m (encode e a)]) + (enum (- (size e) 1) + (λ (n) + (if (< n m) + (decode e n) + (decode e (+ n 1)))) + (λ (x) + (let ([n (encode e x)]) + (cond [(< n m) n] + [(> n m) (- n 1)] + [else (error 'excepted)])))))) + +;; to-list : enum a -> listof a +;; better be finite +(define (to-list e) + (when (infinite? (size e)) + (error 'too-big)) + (map (enum-from e) + (build-list (size e) + identity))) + +;; take/enum : enum a, Nat -> enum a +;; returns an enum of the first n parts of e +;; n must be less than or equal to size e +(define (take/enum e n) + (unless (or (infinite? (size e)) + (<= n (size e))) + (error 'too-big)) + (enum n + (λ (k) + (unless (< k n) + (error 'out-of-range)) + (decode e k)) + (λ (x) + (let ([k (encode e x)]) + (unless (< k n) + (error 'out-of-range)) + k)))) + +;; drop/enum : enum a, Nat -> enum a +(define (drop/enum e n) + (unless (or (infinite? (size e)) + (<= n (size e))) + (error 'too-big)) + (enum (- (size e) n) + (λ (m) + (decode e (+ n m))) + (λ (x) + (- (encode e x) n)))) + +;; foldl-enum : enum a, b, (a,b -> b) -> b +;; better be a finite enum +(define (foldl-enum f id e) + (foldl f id (to-list e))) + +;; display-enum : enum a, Nat -> void +(define (display-enum e n) + (for ([i (range n)]) + (display (decode e i)) + (newline) (newline))) + +(define empty/enum + (enum 0 + (λ (n) + (error 'empty)) + (λ (x) + (error 'empty)))) + +(define (const/enum c) + (enum 1 + (λ (n) + c) + (λ (x) + (if (equal? c x) + 0 + (error 'bad-val))))) + +;; from-list/enum :: Listof a -> Gen a +;; input list should not contain duplicates +(define (from-list/enum l) + (if (empty? l) + empty/enum + (enum (length l) + (λ (n) + (list-ref l n)) + (λ (x) + (length (take-while l (λ (y) + (not (eq? x y))))))))) + +;; take-while : Listof a, (a -> bool) -> Listof a +(define (take-while l pred) + (cond [(empty? l) (error 'empty)] + [(not (pred (car l))) '()] + [else + (cons (car l) + (take-while (cdr l) pred))])) + +(define bools + (from-list/enum (list #t #f))) +(define nats + (enum +inf.f + identity + (λ (n) + (unless (>= n 0) + (error 'out-of-range)) + n))) +(define ints + (enum +inf.f + (λ (n) + (if (even? n) + (* -1 (/ n 2)) + (/ (+ n 1) 2))) + (λ (n) + (if (> n 0) + (- (* 2 n) 1) + (* 2 (abs n)))))) + +;; sum :: enum a, enum b -> enum (a or b) +(define sum/enum + (case-lambda + [(e) e] + [(e1 e2) + (cond + [(= 0 (size e1)) e2] + [(= 0 (size e2)) e1] + [(not (infinite? (enum-size e1))) + (enum (+ (enum-size e1) + (enum-size e2)) + (λ (n) + (if (< n (enum-size e1)) + ((enum-from e1) n) + ((enum-from e2) (- n (enum-size e1))))) + (λ (x) + (with-handlers ([exn:fail? (λ (_) + (+ (enum-size e1) + ((enum-to e2) x)))]) + ((enum-to e1) x))))] + [(not (infinite? (enum-size e2))) + (sum/enum e2 e1)] + [else ;; both infinite, interleave them + (enum +inf.f + (λ (n) + (if (even? n) + ((enum-from e1) (/ n 2)) + ((enum-from e2) (/ (- n 1) 2)))) + (λ (x) + (with-handlers ([exn:fail? + (λ (_) + (+ (* ((enum-to e2) x) 2) + 1))]) + (* ((enum-to e1) x) 2))))])] + [(a b c . rest) + (sum/enum a (apply sum/enum b c rest))])) + +(define odds + (enum +inf.f + (λ (n) + (+ (* 2 n) 1)) + (λ (n) + (if (and (not (zero? (modulo n 2))) + (>= n 0)) + (/ (- n 1) 2) + (error 'odd))))) + +(define evens + (enum +inf.f + (λ (n) + (* 2 n)) + (λ (n) + (if (and (zero? (modulo n 2)) + (>= n 0)) + (/ n 2) + (error 'even))))) + +(define n*n + (enum +inf.f + (λ (n) + ;; calculate the k for which (tri k) is the greatest + ;; triangle number <= n + (let* ([k (floor-untri n)] + [t (tri k)] + [l (- n t)] + [m (- k l)]) + (cons l m))) + (λ (ns) + (unless (pair? ns) + (error "not a list")) + (let ([l (car ns)] + [m (cdr ns)]) + (+ (/ (* (+ l m) (+ l m 1)) + 2) + l))) ;; (n,m) -> (n+m)(n+m+1)/2 + n + )) + +;; prod/enum : enum a, enum b -> enum (a,b) +(define prod/enum + (case-lambda + [(e) e] + [(e1 e2) + (cond [(or (= 0 (size e1)) + (= 0 (size e2))) empty/enum] + [(not (infinite? (enum-size e1))) + (cond [(not (infinite? (enum-size e2))) + (let [(size (* (enum-size e1) + (enum-size e2)))] + (enum size + (λ (n) ;; bijection from n -> axb + (if (> n size) + (error "out of range") + (call-with-values + (λ () + (quotient/remainder n (enum-size e2))) + (λ (q r) + (cons ((enum-from e1) q) + ((enum-from e2) r)))))) + (λ (xs) + (unless (pair? xs) + (error "not a pair")) + (+ (* (enum-size e1) + ((enum-to e1) (car xs))) + ((enum-to e2) (cdr xs))))))] + [else + (enum +inf.f + (λ (n) + (call-with-values + (λ () + (quotient/remainder n (enum-size e1))) + (λ (q r) + (cons ((enum-from e1) r) + ((enum-from e2) q))))) + (λ (xs) + (unless (pair? xs) + (error "not a pair")) + (+ ((enum-to e1) (car xs)) + (* (enum-size e1) + ((enum-to e2) (cdr xs))))))])] + [(not (infinite? (enum-size e2))) + (enum +inf.f + (λ (n) + (call-with-values + (λ () + (quotient/remainder n (enum-size e2))) + (λ (q r) + (cons ((enum-from e1) q) + ((enum-from e2) r))))) + (λ (xs) + (unless (pair? xs) + (error "not a pair")) + (+ (* (enum-size e2) + ((enum-to e1) (car xs))) + ((enum-to e2) (cdr xs)))))] + [else + (enum (* (enum-size e1) + (enum-size e2)) + (λ (n) + (let* ([k (floor-untri n)] + [t (tri k)] + [l (- n t)] + [m (- k l)]) + (cons ((enum-from e1) l) + ((enum-from e2) m)))) + (λ (xs) ;; bijection from nxn -> n, inverse of previous + ;; (n,m) -> (n+m)(n+m+1)/2 + n + (unless (pair? xs) + (error "not a pair")) + (let ([l ((enum-to e1) (car xs))] + [m ((enum-to e2) (cdr xs))]) + (+ (/ (* (+ l m) (+ l m 1)) + 2) + l))))])] + [(a b c . rest) + (prod/enum a (apply prod/enum b c rest))])) + +;; the nth triangle number +(define (tri n) + (/ (* n (+ n 1)) + 2)) + +;; the floor of the inverse of tri +;; returns the largest triangle number less than k +;; always returns an integer +(define (floor-untri k) + (let ([n (integer-sqrt (+ 1 (* 8 k)))]) + (/ (- n + (if (even? n) + 2 + 1)) + 2))) + + +;; dep/enum : enum a (a -> enum b) -> enum (a,b) +(define (dep/enum e f) + (cond [(= 0 (size e)) empty/enum] + [(not (infinite? (size (f (decode e 0))))) + (enum (if (infinite? (size e)) + +inf.f + (foldl + 0 (map (compose size f) (to-list e)))) + (λ (n) ;; n -> axb + (let loop ([ei 0] + [seen 0]) + (let* ([a (decode e ei)] + [e2 (f a)]) + (if (< (- n seen) + (size e2)) + (cons a (decode e2 (- n seen))) + (loop (+ ei 1) + (+ seen (size e2))))))) + (λ (ab) ;; axb -> n + (let ([ai (encode e (car ab))]) + (+ (let loop ([i 0] + [sum 0]) + (if (>= i ai) + sum + (loop (+ i 1) + (+ sum + (size (f (decode e i))))))) + (encode (f (car ab)) + (cdr ab))))))] + [(not (infinite? (size e))) + (enum +inf.f + (λ (n) + (call-with-values + (λ () + (quotient/remainder n (size e))) + (λ (q r) + (cons (decode e r) + (decode (f (decode e r)) q))))) + (λ (ab) + (+ (* (size e) (encode (f (car ab)) (cdr ab))) + (encode e (car ab)))))] + [else ;; both infinite, same as prod/enum + (enum +inf.f + (λ (n) + (let* ([k (floor-untri n)] + [t (tri k)] + [l (- n t)] + [m (- k l)] + [a (decode e l)]) + (cons a + (decode (f a) m)))) + (λ (xs) ;; bijection from nxn -> n, inverse of previous + ;; (n,m) -> (n+m)(n+m+1)/2 + n + (unless (pair? xs) + (error "not a pair")) + (let ([l (encode e (car xs))] + [m (encode (f (car xs)) (cdr xs))]) + (+ (/ (* (+ l m) (+ l m 1)) + 2) + l))))])) + +;; dep2 : enum a (a -> enum b) -> enum (a,b) +(define (dep2/enum e f) + (cond [(= 0 (size e)) empty/enum] + [(not (infinite? (size (f (decode e 0))))) + ;; memoize tab : boxof (hash nat -o> (nat . nat)) + ;; maps an index into the dep/enum to the 2 indices that we need + (let ([tab (box (hash))]) + (enum (if (infinite? (size e)) + +inf.f + (foldl + 0 (map (compose size f) (to-list e)))) + (λ (n) ;; n -> axb + (call-with-values + (λ () + (letrec + ;; go : nat -> nat nat + ([go + (λ (n) + (cond [(hash-has-key? (unbox tab) n) + (let ([ij (hash-ref (unbox tab) n)]) + (values (car ij) (cdr ij)))] + [(= n 0) ;; find the first element + (find 0 0 0)] + [else ;; recur + (call-with-values + (λ () (go (- n 1))) + (λ (ai bi) + (find ai (- n bi 1) n)))]))] + ;; find : nat nat nat -> nat + [find + ;; start is our starting eindex + ;; seen is how many we've already seen + (λ (start seen n) + (let loop ([ai start] + [seen seen]) + (let* ([a (decode e ai)] + [bs (f a)]) + (cond [(< (- n seen) + (size bs)) + (let ([bi (- n seen)]) + (begin + (set-box! tab + (hash-set (unbox tab) + n + (cons ai bi))) + (values ai bi)))] + [else + (loop (+ ai 1) + (+ seen (size bs)))]))))]) + (go n))) + (λ (ai bi) + (let ([a (decode e ai)]) + (cons a + (decode (f a) bi)))))) + ;; todo: memoize encode + (λ (ab) ;; axb -> n + (let ([ai (encode e (car ab))]) + (+ (let loop ([i 0] + [sum 0]) + (if (>= i ai) + sum + (loop (+ i 1) + (+ sum + (size (f (decode e i))))))) + (encode (f (car ab)) + (cdr ab)))))))] + [else ;; both infinite, same as prod/enum + (dep/enum e f)])) + + + +;; more utility enums +;; nats of course +(define (range/enum low high) + (cond [(> low high) (error 'bad-range)] + [(infinite? high) + (if (infinite? low) + ints + (map/enum + (λ (n) + (+ n low)) + (λ (n) + (- n low)) + nats))] + [(infinite? low) + (map/enum + (λ (n) + (- high n)) + (λ (n) + (+ high n)) + nats)] + [else + (map/enum (λ (n) (+ n low)) + (λ (n) (- n low)) + (take/enum nats (+ 1 (- high low))))])) + +;; thunk/enum : Nat or +-Inf, ( -> enum a) -> enum a +(define (thunk/enum s thunk) + (enum s + (λ (n) + (decode (thunk) n)) + (λ (x) + (encode (thunk) x)))) + +;; listof/enum : enum a -> enum (listof a) +(define (listof/enum e) + (thunk/enum + (if (= 0 (size e)) + 0 + +inf.f) + (λ () + (sum/enum + (const/enum '()) + (prod/enum e (listof/enum e)))))) + +;; list/enum : listof (enum any) -> enum (listof any) +(define (list/enum es) + (apply prod/enum (append es `(,(const/enum '()))))) + +(define (nats+/enum n) + (map/enum (λ (k) + (+ k n)) + (λ (k) + (- k n)) + nats)) + +;; fail/enum : exn -> enum () +;; returns an enum that calls a thunk +(define (fail/enum e) + (let ([t + (λ (_) + (raise e))]) + (enum 1 + t + t))) + +(module+ + test + (require rackunit) + (provide check-bijection?) + (define confidence 1000) + (define nums (build-list confidence identity)) + (define-simple-check (check-bijection? e) + (let ([nums (build-list (if (<= (enum-size e) confidence) + (enum-size e) + confidence) + identity)]) + (andmap = + nums + (map (λ (n) + (encode e (decode e n))) + nums)))) + + ;; const/enum tests + (let [(e (const/enum 17))] + (test-begin + (check-eq? (decode e 0) 17) + (check-exn exn:fail? + (λ () + (decode e 1))) + (check-eq? (encode e 17) 0) + (check-exn exn:fail? + (λ () + (encode e 0))) + (check-bijection? e))) + + ;; from-list/enum tests + (let [(e (from-list/enum '(5 4 1 8)))] + (test-begin + (check-eq? (decode e 0) 5) + (check-eq? (decode e 3) 8) + (check-exn exn:fail? + (λ () (decode e 4))) + (check-eq? (encode e 5) 0) + (check-eq? (encode e 8) 3) + (check-exn exn:fail? + (λ () + (encode e 17))) + (check-bijection? e))) + + ;; map test + (define nats+1 (nats+/enum 1)) + + (test-begin + (check-equal? (size nats+1) +inf.f) + (check-equal? (decode nats+1 0) 1) + (check-equal? (decode nats+1 1) 2) + (check-bijection? nats+1)) + ;; encode check + (test-begin + (check-exn exn:fail? + (λ () + (decode nats -1)))) + + ;; ints checks + (test-begin + (check-eq? (decode ints 0) 0) ; 0 -> 0 + (check-eq? (decode ints 1) 1) ; 1 -> 1 + (check-eq? (decode ints 2) -1) ; 2 -> 1 + (check-eq? (encode ints 0) 0) + (check-eq? (encode ints 1) 1) + (check-eq? (encode ints -1) 2) + (check-bijection? ints)) ; -1 -> 2, -3 -> 4 + + ;; sum tests + (test-begin + (let [(bool-or-num (sum/enum bools + (from-list/enum '(0 1 2)))) + (bool-or-nat (sum/enum bools + nats)) + (nat-or-bool (sum/enum nats + bools)) + (odd-or-even (sum/enum evens + odds))] + (check-equal? (enum-size bool-or-num) + 5) + (check-equal? (decode bool-or-num 0) #t) + (check-equal? (decode bool-or-num 1) #f) + (check-equal? (decode bool-or-num 2) 0) + (check-exn exn:fail? + (λ () + (decode bool-or-num 5))) + (check-equal? (encode bool-or-num #f) 1) + (check-equal? (encode bool-or-num 2) 4) + (check-bijection? bool-or-num) + + (check-equal? (enum-size bool-or-nat) + +inf.f) + (check-equal? (decode bool-or-nat 0) #t) + (check-equal? (decode bool-or-nat 2) 0) + (check-bijection? bool-or-nat) + + (check-equal? (encode bool-or-num #f) 1) + (check-equal? (encode bool-or-num 2) 4) + + (check-equal? (enum-size odd-or-even) + +inf.f) + (check-equal? (decode odd-or-even 0) 0) + (check-equal? (decode odd-or-even 1) 1) + (check-equal? (decode odd-or-even 2) 2) + (check-exn exn:fail? + (λ () + (decode odd-or-even -1))) + (check-equal? (encode odd-or-even 0) 0) + (check-equal? (encode odd-or-even 1) 1) + (check-equal? (encode odd-or-even 2) 2) + (check-equal? (encode odd-or-even 3) 3) + (check-bijection? odd-or-even))) + + ;; prod/enum tests + (define bool*bool (prod/enum bools bools)) + (define 1*b (prod/enum (const/enum 1) bools)) + (define bool*nats (prod/enum bools nats)) + (define nats*bool (prod/enum nats bools)) + (define nats*nats (prod/enum nats nats)) + (define ns-equal? (λ (ns ms) + (and (= (car ns) + (car ms)) + (= (cdr ns) + (cdr ms))))) + + ;; prod tests + (test-begin + + (check-equal? (size 1*b) 2) + (check-equal? (decode 1*b 0) (cons 1 #t)) + (check-equal? (decode 1*b 1) (cons 1 #f)) + (check-bijection? 1*b) + (check-equal? (enum-size bool*bool) 4) + (check-equal? (decode bool*bool 0) + (cons #t #t)) + (check-equal? (decode bool*bool 1) + (cons #t #f)) + (check-equal? (decode bool*bool 2) + (cons #f #t)) + (check-equal? (decode bool*bool 3) + (cons #f #f)) + (check-bijection? bool*bool) + + (check-equal? (enum-size bool*nats) +inf.f) + (check-equal? (decode bool*nats 0) + (cons #t 0)) + (check-equal? (decode bool*nats 1) + (cons #f 0)) + (check-equal? (decode bool*nats 2) + (cons #t 1)) + (check-equal? (decode bool*nats 3) + (cons #f 1)) + (check-bijection? bool*nats) + + (check-equal? (enum-size nats*bool) +inf.f) + (check-equal? (decode nats*bool 0) + (cons 0 #t)) + (check-equal? (decode nats*bool 1) + (cons 0 #f)) + (check-equal? (decode nats*bool 2) + (cons 1 #t)) + (check-equal? (decode nats*bool 3) + (cons 1 #f)) + (check-bijection? nats*bool) + + (check-equal? (enum-size nats*nats) +inf.f) + (check ns-equal? + (decode nats*nats 0) + (cons 0 0)) + (check ns-equal? + (decode nats*nats 1) + (cons 0 1)) + (check ns-equal? + (decode nats*nats 2) + (cons 1 0)) + (check ns-equal? + (decode nats*nats 3) + (cons 0 2)) + (check ns-equal? + (decode nats*nats 4) + (cons 1 1)) + (check-bijection? nats*nats)) + + + ;; dep/enum tests + (define (up-to n) + (take/enum nats (+ n 1))) + + (define 3-up + (dep/enum + (from-list/enum '(0 1 2)) + up-to)) + + (define from-3 + (dep/enum + (from-list/enum '(0 1 2)) + nats+/enum)) + + (define nats-to + (dep/enum nats up-to)) + + (define nats-up + (dep/enum nats nats+/enum)) + + (test-begin + (check-equal? (size 3-up) 6) + (check-equal? (decode 3-up 0) (cons 0 0)) + (check-equal? (decode 3-up 1) (cons 1 0)) + (check-equal? (decode 3-up 2) (cons 1 1)) + (check-equal? (decode 3-up 3) (cons 2 0)) + (check-equal? (decode 3-up 4) (cons 2 1)) + (check-equal? (decode 3-up 5) (cons 2 2)) + (check-bijection? 3-up) + + (check-equal? (size from-3) +inf.f) + (check-equal? (decode from-3 0) (cons 0 0)) + (check-equal? (decode from-3 1) (cons 1 1)) + (check-equal? (decode from-3 2) (cons 2 2)) + (check-equal? (decode from-3 3) (cons 0 1)) + (check-equal? (decode from-3 4) (cons 1 2)) + (check-equal? (decode from-3 5) (cons 2 3)) + (check-equal? (decode from-3 6) (cons 0 2)) + (check-bijection? from-3) + + (check-equal? (size nats-to) +inf.f) + (check-equal? (decode nats-to 0) (cons 0 0)) + (check-equal? (decode nats-to 1) (cons 1 0)) + (check-equal? (decode nats-to 2) (cons 1 1)) + (check-equal? (decode nats-to 3) (cons 2 0)) + (check-equal? (decode nats-to 4) (cons 2 1)) + (check-equal? (decode nats-to 5) (cons 2 2)) + (check-equal? (decode nats-to 6) (cons 3 0)) + (check-bijection? nats-to) + + (check-equal? (size nats-up) +inf.f) + (check-equal? (decode nats-up 0) (cons 0 0)) + (check-equal? (decode nats-up 1) (cons 0 1)) + (check-equal? (decode nats-up 2) (cons 1 1)) + (check-equal? (decode nats-up 3) (cons 0 2)) + (check-equal? (decode nats-up 4) (cons 1 2)) + (check-equal? (decode nats-up 5) (cons 2 2)) + (check-equal? (decode nats-up 6) (cons 0 3)) + (check-equal? (decode nats-up 7) (cons 1 3)) + + (check-bijection? nats-up)) + + ;; dep2/enum tests + ;; same as dep unless the right side is finite + (define 3-up-2 + (dep2/enum + (from-list/enum '(0 1 2)) + up-to)) + + (define nats-to-2 + (dep2/enum nats up-to)) + + + (test-begin + (check-equal? (size 3-up-2) 6) + (check-equal? (decode 3-up-2 0) (cons 0 0)) + (check-equal? (decode 3-up-2 1) (cons 1 0)) + (check-equal? (decode 3-up-2 2) (cons 1 1)) + (check-equal? (decode 3-up-2 3) (cons 2 0)) + (check-equal? (decode 3-up-2 4) (cons 2 1)) + (check-equal? (decode 3-up-2 5) (cons 2 2)) + (check-bijection? 3-up-2) + + (check-equal? (size nats-to-2) +inf.f) + (check-equal? (decode nats-to-2 0) (cons 0 0)) + (check-equal? (decode nats-to-2 1) (cons 1 0)) + (check-equal? (decode nats-to-2 2) (cons 1 1)) + (check-equal? (decode nats-to-2 3) (cons 2 0)) + (check-equal? (decode nats-to-2 4) (cons 2 1)) + (check-equal? (decode nats-to-2 5) (cons 2 2)) + (check-equal? (decode nats-to-2 6) (cons 3 0)) + (check-bijection? nats-to-2) + ) + + + ;; take/enum test + (define to-2 (up-to 2)) + (test-begin + (check-equal? (size to-2) 3) + (check-equal? (decode to-2 0) 0) + (check-equal? (decode to-2 1) 1) + (check-equal? (decode to-2 2) 2) + (check-bijection? to-2)) + + ;; to-list, foldl test + (test-begin + (check-equal? (to-list (up-to 3)) + '(0 1 2 3)) + (check-equal? (foldl-enum cons '() (up-to 3)) + '(3 2 1 0))))