Redex enum maintains bijection internally

Also simplified other redex enumeration internals.
This commit is contained in:
Max New 2013-05-15 09:42:53 -05:00
parent c0f45d7d99
commit 843edcc78d
2 changed files with 349 additions and 245 deletions

View File

@ -3,172 +3,270 @@
racket/list
racket/match
racket/function
racket/set
"lang-struct.rkt"
"match-a-pattern.rkt"
"enumerator.rkt")
(provide
(contract-out
[lang-enumerators (-> (listof nt?) (hash/c symbol? enum?))]
[pat-enumerator (-> (hash/c symbol? enum?)
[lang-enumerators (-> (listof nt?) lang-enum?)]
[pat-enumerator (-> lang-enum?
any/c ;; pattern
enum?)]
[enum-ith (-> enum? exact-nonnegative-integer? any/c)]
[lang-enum? (-> any/c boolean?)]
[enum? (-> any/c boolean?)]))
(define (lang-enumerators nts)
(let* ([l-enums (make-hash)]
[rec-nt-terms (find-recs nts)]
[sorted-nts (sort-nt-terms nts rec-nt-terms)])
(foldl
(λ (nt m)
(hash-set
m
(nt-name nt)
(with-handlers
([exn:fail? fail/enum])
(rec-pat/enum `(nt ,(nt-name nt))
sorted-nts
rec-nt-terms))))
(hash)
sorted-nts)))
(struct lang-enum (enums))
(struct decomposition (ctx term))
(struct hole ())
(struct named (name val))
(struct named-t (val term))
(struct name (name) #:transparent)
(struct unimplemented (msg))
(define enum-ith decode)
(struct decomposition (ctx term))
(define (pat-enumerator lang-enums pat)
(enum-names pat
(sep-names pat)
lang-enums))
(define (lang-enumerators lang)
(let ([l-enums (make-hash)])
(let-values ([(fin-lang rec-lang)
(sep-lang lang)])
(for-each
(λ (nt)
(hash-set! l-enums
(nt-name nt)
(enumerate-rhss (nt-rhs nt)
l-enums)))
fin-lang)
(for-each
(λ (nt)
(hash-set! l-enums
(nt-name nt)
(thunk/enum +inf.f
(λ ()
(enumerate-rhss (nt-rhs nt)
l-enums)))))
rec-lang))
(lang-enum l-enums)))
(define (rec-pat/enum pat nts rec-nt-terms)
(enum-names pat
nts
(sep-names pat)
rec-nt-terms))
(define (pat-enumerator l-enum pat)
(map/enum
to-term
(λ (_)
(error 'pat-enum "Enumerator is not a bijection"))
(pat/enum pat
(lang-enum-enums l-enum))))
(define (enumerate-rhss rhss l-enums)
(apply sum/enum
(map
(λ (rhs)
(pat/enum (rhs-pattern rhs)
l-enums))
rhss)))
;; find-recs : lang -> (hash symbol -o> (assoclist rhs bool))
;; Identifies which non-terminals are recursive
(define (find-recs nt-pats)
(define is-rec?
(case-lambda
[(n) (is-rec? n (hash))]
[(nt seen)
(or (seen? seen (nt-name nt))
(ormap
(λ (rhs)
(let rec ([pat (rhs-pattern rhs)])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(is-rec? (make-nt
id
(lookup nt-pats id))
(add-seen seen
(nt-name nt)))]
[`(name ,name ,pat)
;; find-edges : lang -> (hash symbol -o> (setof symbol))
(define (find-edges lang)
(foldl
(λ (nt m)
(hash-set
m (nt-name nt)
(fold-map/set
(λ (rhs)
(let loop ([pat (rhs-pattern rhs)]
[s (set)])
(match-a-pattern
pat
[`any s]
[`number s]
[`string s]
[`natural s]
[`integer s]
[`real s]
[`boolean s]
[`variable s]
[`(variable-except ,v ...) s]
[`(variable-prefix ,v) s]
[`variable-not-otherwise-mentioned s]
[`hole s]
[`(nt ,id)
(set-add s id)]
[`(name ,name ,pat)
(loop pat s)]
[`(mismatch-name ,name ,pat)
(loop pat s)]
[`(in-hole ,p1 ,p2)
(set-union (loop p1 s)
(loop p2 s))]
[`(hide-hole ,p) (loop p s)]
[`(side-condition ,p ,g ,e) ;; error
(unsupported/enum pat)]
[`(cross ,s)
(unsupported/enum pat)] ;; error
[`(list ,sub-pats ...)
(fold-map/set
(λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(loop pat s)]
[else (loop sub-pat s)]))
sub-pats)]
[(? (compose not pair?)) s])))
(nt-rhs nt))))
(hash)
lang))
;; find-cycles : (hashsymbol -o> (setof symbol)) -> (setof symbol)
(define (find-cycles edges)
(foldl
(λ (v s)
(if (let rec ([cur v]
[seen (set)])
(cond [(set-member? seen cur) #t]
[else
(ormap
(λ (next)
(rec next
(set-add seen cur)))
(set->list (hash-ref edges
cur)))]))
(set-add s v)
s))
(set)
(hash-keys edges)))
;; calls-rec? : pat (setof symbol) -> bool
(define (calls-rec? pat recs)
(let rec ([pat pat])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(set-member? recs id)]
[`(name ,name ,pat)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
[`(side-condition ,p ,g ,e) ;; error
(unsupported/enum pat)]
[`(cross ,s)
(unsupported/enum pat)] ;; error
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
[`(side-condition ,p ,g ,e) ;; error
(unsupported/enum pat)]
[`(cross ,s)
(unsupported/enum pat)] ;; error
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[else (rec sub-pat)]))
sub-pats)]
[(? (compose not pair?)) #f])))
(nt-rhs nt)))]))
(define (calls-rec? rhs recs)
(let rec ([pat (rhs-pattern rhs)])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(hash-ref recs id)]
[`(name ,name ,pat)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
[`(side-condition ,p ,g ,e) ;; error
(rec p)]
[`(cross ,s)
(unsupported/enum pat)] ;; error
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[else (rec sub-pat)]))
sub-pats)]
[(? (compose not pair?)) #f])))
(define (seen? m s)
(hash-ref m s #f))
(define (add-seen m s)
(hash-set m s #t))
(let ([recs
(foldl
(λ (nt m)
(hash-set m (nt-name nt) (is-rec? nt)))
(hash) nt-pats)])
[else (rec sub-pat)]))
sub-pats)]
[(? (compose not pair?)) #f])))
;; fold-map : (a -> setof b) (listof a) -> (setof b)
(define (fold-map/set f l)
(foldl
(λ (x s)
(set-union (f x) s))
(set)
l))
;; sep-lang : lang -> lang lang
;; topologically sorts non-terminals by dependency
;; sorts rhs's so that recursive ones go last
(define (sep-lang lang)
(define (filter-edges edges lang)
(foldl
(λ (nt m)
(let ([rhs (nt-rhs nt)])
(hash-set m (nt-name nt)
(map (λ (rhs)
(cons rhs (calls-rec? rhs recs)))
rhs))))
(let ([name (nt-name nt)])
(hash-set m name
(hash-ref edges name))))
(hash)
nt-pats)))
lang))
(let* ([edges (find-edges lang)]
[cyclic-nts (find-cycles edges)])
(let-values ([(cyclic non-cyclic)
(partition (λ (nt)
(set-member? cyclic-nts (nt-name nt)))
lang)])
(let ([sorted-left (topo-sort non-cyclic
(filter-edges edges non-cyclic))] ;; topological sort
[sorted-right (sort-nt-terms cyclic
cyclic-nts)] ;; rhs sort
)
(values sorted-left
sorted-right)))))
;; sort-nt-terms : lang (hash symbol -o> (assoclist rhs bool)) -> lang
(define (sort-nt-terms nt-pats recs)
(map
(λ (nt)
(let ([rec-nts (hash-ref recs (nt-name nt))])
(make-nt (nt-name nt)
(sort (nt-rhs nt)
(λ (r1 r2)
(and (not (cdr (assoc r1 rec-nts)))
(cdr (assoc r2 rec-nts))))))))
nt-pats))
;; recursive-rhss : lang (hash symbol -o> (setof symbol)) -> (hash symbol -o> (assoclist rhs bool))
(define (recursive-rhss lang recs)
(foldl
(λ (nt m)
(let ([rhs (nt-rhs nt)])
(hash-set m (nt-name nt)
(map (λ (rhs)
(cons rhs
(calls-rec? (rhs-pattern rhs)
recs)))
rhs))))
(hash)
lang))
;; topo-sort : lang (hash symbol -o> (setof symbol)) -> lang
(define (topo-sort lang edges)
(define (find-top rem edges)
(let find ([rem rem])
(let ([v (car rem)])
(let check ([vs (hash-keys edges)])
(cond [(empty? vs) v]
[(set-member? (hash-ref edges (car vs))
v)
(find (cdr rem))]
[else (check (cdr vs))])))))
(let loop ([rem (hash-keys edges)]
[edges edges]
[out-lang '()])
(cond [(empty? rem) out-lang]
[else
(let ([v (find-top rem edges)])
(loop (remove v rem)
(hash-remove edges v)
(cons
(findf
(λ (nt)
(eq? v (nt-name nt)))
lang)
out-lang)))])))
;; sort-nt-terms : lang (setof symbol) -> lang
(define (sort-nt-terms lang nts)
(let ([recs (recursive-rhss lang nts)])
(map
(λ (nt)
(let ([rec-nts (hash-ref recs (nt-name nt))])
(make-nt (nt-name nt)
(sort (nt-rhs nt)
(λ (r1 r2)
(and (not (cdr (assoc r1 rec-nts)))
(cdr (assoc r2 rec-nts))))))))
lang)))
(define (pat/enum pat l-enums)
(enum-names pat
(sep-names pat)
l-enums))
;; sep-names : single-pattern lang -> (assoclist symbol pattern)
(define (sep-names pat)
@ -194,8 +292,7 @@
(loop pat
(add-if-new name pat named-pats))]
[`(mismatch-name ,name ,pat)
(loop pat
(add-if-new name pat named-pats))]
(loop pat (cons (unimplemented "mismatch") named-pats))]
[`(in-hole ,p1 ,p2)
(loop p2
(loop p1 named-pats))]
@ -207,18 +304,12 @@
[`(list ,sub-pats ...)
(foldl (λ (sub-pat named-pats)
(match sub-pat
;; unnamed repeat
[`(repeat ,pat #f #f)
(loop pat named-pats)]
;; named repeat
[`(repeat ,pat ,name #f)
(loop pat
(add-if-new name 'name-r named-pats))]
;; mismatch named repeat
(loop pat (cons (unimplemented "named repeat") named-pats))]
[`(repeat ,pat #f ,mismatch)
(loop pat
(add-if-new mismatch 'mismatch-r named-pats))]
;; normal subpattern
(loop pat (cons (unimplemented "mismatch repeat") named-pats))]
[else (loop sub-pat named-pats)]))
named-pats
sub-pats)]
@ -226,82 +317,54 @@
named-pats])))
(define (add-if-new k v l)
(cond [(assoc k l) l]
[else (cons `(,k ,v) l)]))
(cond [(assoc-named k l) l]
[else (cons (named k v) l)]))
(define enum-names
(case-lambda
[(pat named-pats nts)
(enum-names-with
(λ (pat named)
(pat/enum-with-names pat nts named))
pat named-pats)]
[(pat nts named-pats rec-nt-terms)
(enum-names-with
(λ (pat named)
(pat/enum-with-names pat nts named rec-nt-terms))
pat named-pats)]))
(define (assoc-named n l)
(cond [(null? l) #f]
[else
(or (let ([cur (car l)])
(and (named? cur)
(equal? (named-name cur)
n)))
(assoc-named n (cdr l)))]))
(define (enum-names-with f pat named-pats)
(define (enum-names pat named-pats nt-enums)
(let rec ([named-pats named-pats]
[env (hash)])
(cond [(null? named-pats) (f pat env)]
(cond [(null? named-pats)
(pat/enum-with-names pat nt-enums env)]
[else
(match
(car named-pats)
;; named repeat
[`(,name name-r)
(error/enum 'unimplemented "named-repeat")]
;; mismatch repeat
[`(,name mismatch-r)
(error/enum 'unimplemented "mismatch-repeat")]
[`(,name ,pat mismatch)
(error/enum 'unimplemented "mismatch")]
;; named
[`(,name ,pat)
(map/enum ;; loses bijection
cdr
(λ (x) (cons name x))
(dep/enum
(f pat env)
(λ (term)
(rec (cdr named-pats)
(hash-set env
name
term)))))]
[else (error 'bad-assoc)])])))
(let ([cur (car named-pats)])
(cond ([named? cur]
(let ([name (named-name cur)]
[pat (named-val cur)])
(map/enum
(λ (ts)
(named name
(named-t (car ts)
(cdr ts))))
(λ (n)
(if (equal? (named-name n)
name)
(let ([val (named-val n)])
(cons (named-t-val val)
(named-t-term val)))
(error 'wrong-name
"expected ~a, got ~a"
name
(named-name n))))
(dep/enum
(pat/enum-with-names pat nt-enums env)
(λ (term)
(rec (cdr named-pats)
(hash-set env
name
term)))))))
[else (error/enum 'unimplemented
(unimplemented-msg cur))]))])))
(define pat/enum-with-names
(case-lambda
[(pat nt-enums named-terms)
(pat/enum-with-names-with
pat
named-terms
(λ (nt)
(hash-ref nt-enums nt)))]
[(pat nts named-terms rec-nt-terms)
(pat/enum-with-names-with
pat
named-terms
(λ (nt)
(let ([rhss (lookup nts nt)])
(apply sum/enum
(map
(λ (rhs)
(cond [(cdr (assoc rhs (hash-ref rec-nt-terms nt)))
(thunk/enum
+inf.f
(λ ()
(rec-pat/enum (rhs-pattern rhs)
nts
rec-nt-terms)))]
[else
(rec-pat/enum (rhs-pattern rhs)
nts
rec-nt-terms)]))
rhss)))))]))
(define (pat/enum-with-names-with pat named-terms f)
(define (pat/enum-with-names pat nt-enums named-terms)
(let loop ([pat pat])
(match-a-pattern
pat
@ -326,18 +389,19 @@
[`hole
(const/enum 'hole)]
[`(nt ,id)
(f id)]
[`(name ,name ,pat)
(const/enum (hash-ref named-terms name))]
(hash-ref nt-enums id)]
[`(name ,n ,pat)
(const/enum (name n))]
[`(mismatch-name ,name ,pat)
(error/enum 'unimplemented "mismatch-name")]
[`(in-hole ,p1 ,p2) ;; untested
(map/enum
(λ (t1-t2) ;; loses bijection
(plug-hole (car t1-t2)
(cdr t1-t2)))
(λ (plugged)
(cons 'hole plugged))
(λ (ts)
(decomposition (car ts)
(cdr ts)))
(λ (decomp)
(cons (decomposition-ctx decomp)
(decomposition-term decomp)))
(prod/enum
(loop p1)
(loop p2)))]
@ -437,11 +501,43 @@
bool/enum
var/enum))
(define (to-term aug)
(cond [(named? aug)
(rep-name aug)]
[else aug]))
(define (rep-name s)
(to-term
(let* ([n (named-name s)]
[v (named-val s)]
[val (named-t-val v)]
[term (named-t-term v)])
(let loop ([term term])
(cond [(and (name? term)
(equal? (name-name term) n))
val]
[(named? term)
(map-named loop
term)]
[else term])))))
(define (map-named f n)
(let ([v (named-val n)])
(named (named-name n)
(named-t
(named-t-val v)
(f (named-t-term v))))))
#;
(define (plug-hole ctx term)
(let loop ([ctx ctx])
(match
ctx
['hole term]
[`(,ts ...)
(map loop ts)]
[x x])))
(to-term
(let loop ([ctx ctx])
(cond [(hole? ctx) term]
[(cons? ctx) (map loop ctx)]
[(named? )])
(match
ctx
['hole term]
[`(,ts ...)
(map loop ts)]
[x x]))))

View File

@ -7,11 +7,14 @@
(syntax-case stx ()
[(_ N l p)
(with-syntax ([line (syntax-line stx)])
#'(for ([i (in-range N)])
(unless (redex-match
l p
(generate-term l p #:i-th i))
(error 'bad-term "line ~a: i=~a" line i))))]))
#'(test-begin
(for ([i (in-range N)])
(check-not-exn
(λ ()
(unless (redex-match
l p
(generate-term l p #:i-th i))
(error 'bad-term "line ~a: i=~a" line i)))))))]))
;; Repeat test
(define-language Rep
@ -38,6 +41,11 @@
(try-it 100 Named n)
(define-language not-SKI
(y x
s
k
i)
(x (variable-except s k i)))
(try-it 21 not-SKI x)
(try-it 22 not-SKI x)
(try-it 25 not-SKI y)