racket/collects/redex/private/enum.rkt
2013-05-20 19:25:07 -05:00

551 lines
16 KiB
Racket

#lang racket/base
(require racket/contract
racket/list
racket/match
racket/function
racket/set
"lang-struct.rkt"
"match-a-pattern.rkt"
"enumerator.rkt")
;; Public interface of this module.
(provide
(contract-out
;; build one enumerator per non-terminal of a language
[lang-enumerators (-> (listof nt?) lang-enum?)]
;; enumerator for the terms matching a single pattern of that language
[pat-enumerator (-> lang-enum?
any/c ;; pattern
enum?)]
;; i-th element of an enumerator (an alias for decode, defined below)
[enum-ith (-> enum? exact-nonnegative-integer? any/c)]
[lang-enum? (-> any/c boolean?)]
[enum? (-> any/c boolean?)]))
;; lang-enum: wraps the table built by lang-enumerators, a mutable hash
;; from non-terminal name to that non-terminal's enumerator.
(struct lang-enum (enums))
;; decomposition: an enumerated in-hole pattern before plugging; ctx is
;; the context and term is the term that goes into its hole.
(struct decomposition (ctx term))
;; named: a value produced for a named pattern; name is the bound symbol
;; and val is a named-t.
(struct named (name val))
;; named-t: val is the value chosen for the name, term is the rest of the
;; enumerated result (which may still contain (name ...) placeholders).
(struct named-t (val term))
;; name: placeholder left where a (name n pat) pattern occurs; rep-name
;; later replaces it with the value bound to n.
(struct name (name) #:transparent)
;; unimplemented: marks a pattern feature the enumerator does not handle
;; yet; msg says which one.
(struct unimplemented (msg))
;; enum-ith : enum natural -> value
;; decode is presumably provided by enumerator.rkt (it is not defined here).
(define enum-ith decode)
;; lang-enumerators : (listof nt) -> lang-enum
;; Builds one enumerator per non-terminal of `lang`, stored in a mutable
;; hash keyed by non-terminal name.
;;  - Non-terminals from which no cycle of the dependency graph can be
;;    reached (the first result of sep-lang) come in dependency order, so
;;    every (nt id) they mention is already in the table; they are built
;;    eagerly.
;;  - The remaining (recursive) non-terminals refer to one another, so each
;;    is wrapped in thunk/enum.  Its real enumerator is built the first time
;;    the thunk is forced and cached afterwards, instead of being rebuilt
;;    from scratch on every force.
(define (lang-enumerators lang)
  (define l-enums (make-hash))
  (define-values (fin-lang rec-lang) (sep-lang lang))
  (for ([an-nt (in-list fin-lang)])
    (hash-set! l-enums
               (nt-name an-nt)
               (enumerate-rhss (nt-rhs an-nt) l-enums)))
  (for ([an-nt (in-list rec-lang)])
    (let ([rhss (nt-rhs an-nt)]
          [cached #f])
      (hash-set! l-enums
                 (nt-name an-nt)
                 (thunk/enum +inf.f
                             (λ ()
                               ;; l-enums is complete by the time any
                               ;; thunk is forced, so caching is safe
                               (unless cached
                                 (set! cached
                                       (enumerate-rhss rhss l-enums)))
                               cached)))))
  (lang-enum l-enums))
;; pat-enumerator : lang-enum pattern -> enum
;; Enumerator for the terms matching `pat` in the language whose
;; non-terminal enumerators are held by `l-enum`.  Decoded values are
;; turned into plain terms by to-term; encoding is not supported and
;; always raises.
(define (pat-enumerator l-enum pat)
  (define (refuse-encode term)
    (error 'pat-enum "Enumerator is not a bijection"))
  (define raw-enum
    (pat/enum pat (lang-enum-enums l-enum)))
  (map/enum to-term refuse-encode raw-enum))
;; enumerate-rhss : (listof rhs) hash -> enum
;; Enumerator for a non-terminal: the sum/enum of the enumerators of its
;; right-hand-side patterns, kept in the given order.
(define (enumerate-rhss rhss l-enums)
  (define alternatives
    (for/list ([r (in-list rhss)])
      (pat/enum (rhs-pattern r) l-enums)))
  (apply sum/enum alternatives))
;; find-edges : lang -> (hash symbol -o> (setof symbol))
;; Builds the non-terminal dependency graph of `lang`: each nt name maps to
;; the set of nt names mentioned in any of its right-hand sides.
;; Names mentioned only inside side-condition or cross patterns are not
;; recorded.
(define (find-edges lang)
  (foldl
   (λ (nt m)
     (hash-set
      m (nt-name nt)
      (fold-map/set
       (λ (rhs)
         (let loop ([pat (rhs-pattern rhs)]
                    [s (set)])
           (match-a-pattern
            pat
            [`any s]
            [`number s]
            [`string s]
            [`natural s]
            [`integer s]
            [`real s]
            [`boolean s]
            [`variable s]
            [`(variable-except ,v ...) s]
            [`(variable-prefix ,v) s]
            [`variable-not-otherwise-mentioned s]
            [`hole s]
            [`(nt ,id)
             (set-add s id)]
            [`(name ,name ,pat)
             (loop pat s)]
            [`(mismatch-name ,name ,pat)
             (loop pat s)]
            [`(in-hole ,p1 ,p2)
             (set-union (loop p1 s)
                        (loop p2 s))]
            [`(hide-hole ,p) (loop p s)]
            [`(side-condition ,p ,g ,e) s]
            ;; Bind the cross name to its own variable: binding it to `s`
            ;; shadowed the accumulator and returned a symbol where a set
            ;; is expected.
            [`(cross ,cross-nt) s]
            [`(list ,sub-pats ...)
             (fold-map/set
              (λ (sub-pat)
                (match sub-pat
                  [`(repeat ,pat ,name ,mismatch)
                   (loop pat s)]
                  [_ (loop sub-pat s)]))
              sub-pats)]
            [(? (compose not pair?)) s])))
       (nt-rhs nt))))
   (hash)
   lang))
;; find-cycles : (hash symbol -o> (setof symbol)) -> (setof symbol)
;; Returns every node from which a cycle of `edges` can be reached, i.e.
;; the nodes on a cycle plus the nodes that (transitively) lead into one.
;; Each node is checked by a depth-first walk that remembers the nodes on
;; the current path; reaching one of them again means a cycle was found.
;; The walk explores every path, so it is not linear in the graph size.
;; Every node reached must be a key of `edges` (hash-ref raises otherwise).
(define (find-cycles edges)
  (define (reaches-cycle? node on-path)
    (or (set-member? on-path node)
        (for/or ([succ (in-list (set->list (hash-ref edges node)))])
          (reaches-cycle? succ (set-add on-path node)))))
  (for/fold ([cyclic (set)])
            ([v (in-list (hash-keys edges))])
    (if (reaches-cycle? v (set))
        (set-add cyclic v)
        cyclic)))
;; calls-rec? : pat (setof symbol) -> any/c (used as a boolean)
;; Reports whether `pat` mentions, anywhere inside it, a non-terminal whose
;; name is in `recs`.  recursive-rhss uses this to flag which right-hand
;; sides of a recursive non-terminal are themselves recursive.
(define (calls-rec? pat recs)
(let rec ([pat pat])
(match-a-pattern
pat
[`any #f]
[`number #f]
[`string #f]
[`natural #f]
[`integer #f]
[`real #f]
[`boolean #f]
[`variable #f]
[`(variable-except ,s ...) #f]
[`(variable-prefix ,s) #f]
[`variable-not-otherwise-mentioned #f]
[`hole #f]
[`(nt ,id)
(set-member? recs id)]
[`(name ,name ,pat)
(rec pat)]
[`(mismatch-name ,name ,pat)
(rec pat)]
[`(in-hole ,p1 ,p2)
(or (rec p1)
(rec p2))]
[`(hide-hole ,p) (rec p)]
;; NOTE(review): the next two clauses return unsupported/enum's result
;; instead of a boolean.  Presumably that is a non-#f enum value, so such
;; rhs's count as recursive and sort-nt-terms puts them last.  Confirm
;; this is intended.
[`(side-condition ,p ,g ,e) ;; error
(unsupported/enum pat)]
[`(cross ,s)
(unsupported/enum pat)] ;; error
;; a list is recursive if any element (or repeated element) is
[`(list ,sub-pats ...)
(ormap (λ (sub-pat)
(match sub-pat
[`(repeat ,pat ,name ,mismatch)
(rec pat)]
[else (rec sub-pat)]))
sub-pats)]
;; literal (non-pair) patterns
[(? (compose not pair?)) #f])))
;; fold-map/set : (a -> (setof b)) (listof a) -> (setof b)
;; Applies `f` to every element of `l` and returns the union of the
;; resulting sets (the empty set for an empty list).
(define (fold-map/set f l)
  (for/fold ([acc (set)])
            ([x (in-list l)])
    (set-union (f x) acc)))
;; sep-lang : lang -> (values lang lang)
;; Splits the non-terminals of `lang` into two lists:
;;  - first, those from which no cycle of the dependency graph can be
;;    reached, topologically sorted so each comes after every non-terminal
;;    it mentions;
;;  - second, the rest (the recursive ones), each with its rhs's reordered
;;    so that rhs's mentioning a recursive non-terminal come last.
(define (sep-lang lang)
  (define all-edges (find-edges lang))
  (define recursive-names (find-cycles all-edges))
  (define-values (recursive-nts plain-nts)
    (partition (λ (an-nt) (set-member? recursive-names (nt-name an-nt)))
               lang))
  ;; the dependency graph restricted to the non-recursive non-terminals
  (define plain-edges
    (for/fold ([g (hash)])
              ([an-nt (in-list plain-nts)])
      (let ([nm (nt-name an-nt)])
        (hash-set g nm (hash-ref all-edges nm)))))
  (values (topo-sort plain-nts plain-edges)
          (sort-nt-terms recursive-nts recursive-names)))
;; recursive-rhss : lang (setof symbol) -> (hash symbol -o> (assoclist rhs bool))
;; For every non-terminal in `lang`, pairs each of its rhs's with the
;; result of calls-rec?, i.e. whether that rhs mentions a non-terminal
;; whose name is in `recs`.
(define (recursive-rhss lang recs)
  (for/fold ([table (hash)])
            ([an-nt (in-list lang)])
    (hash-set table
              (nt-name an-nt)
              (for/list ([r (in-list (nt-rhs an-nt))])
                (cons r (calls-rec? (rhs-pattern r) recs))))))
;; topo-sort : lang (hash symbol -o> (setof symbol)) -> lang
;; Orders the non-terminals of `lang` so that each comes after every
;; non-terminal it references according to `edges` (dependencies first).
;; It repeatedly removes a node that no remaining node references and
;; conses that node's nt onto the result.  The graph must be acyclic:
;; otherwise no such node exists and car is applied to '().
(define (topo-sort lang edges)
  ;; is `v` referenced by any node still present in graph `g`?
  (define (referenced? v g)
    (for/or ([k (in-list (hash-keys g))])
      (set-member? (hash-ref g k) v)))
  ;; the first candidate in `rem` that nothing in `g` references
  (define (find-top rem g)
    (let scan ([cands rem])
      (let ([v (car cands)])
        (if (referenced? v g)
            (scan (cdr cands))
            v))))
  (let loop ([rem (hash-keys edges)]
             [g edges]
             [acc '()])
    (if (empty? rem)
        acc
        (let* ([v (find-top rem g)]
               [v-nt (findf (λ (an-nt) (eq? v (nt-name an-nt))) lang)])
          (loop (remove v rem)
                (hash-remove g v)
                (cons v-nt acc))))))
;; sort-nt-terms : lang (setof symbol) -> lang
;; Rebuilds each non-terminal of `lang` with its rhs's stably reordered so
;; that rhs's not mentioning any non-terminal in `nts` come first.
(define (sort-nt-terms lang nts)
  (define rec-table (recursive-rhss lang nts))
  (for/list ([the-nt (in-list lang)])
    (let ([flags (hash-ref rec-table (nt-name the-nt))])
      (make-nt (nt-name the-nt)
               (sort (nt-rhs the-nt)
                     ;; non-recursive (#f) sorts before recursive
                     (λ (a-rec? b-rec?) (and (not a-rec?) b-rec?))
                     #:key (λ (r) (cdr (assoc r flags))))))))
;; pat/enum : pattern hash -> enum
;; Enumerator for `pat`, with (nt ...) references looked up in `l-enums`.
;; The name bindings of `pat` are collected by sep-names and handled by
;; enum-names.
(define (pat/enum pat l-enums)
  (define bindings (sep-names pat))
  (enum-names pat bindings l-enums))
;; sep-names : pattern -> (listof (or/c named? unimplemented?))
;; Collects the (name n p) bindings occurring in `pat` as `named` records
;; pairing n with its sub-pattern p.  Only the first binding seen for a
;; name is kept (see add-if-new).  Features that are not supported yet
;; (mismatch names, named or mismatched repeats) add an `unimplemented`
;; record instead.  Records are consed on, so later discoveries come
;; first in the result.
(define (sep-names pat)
(let loop ([pat pat]
[named-pats '()])
(match-a-pattern
pat
[`any named-pats]
[`number named-pats]
[`string named-pats]
[`natural named-pats]
[`integer named-pats]
[`real named-pats]
[`boolean named-pats]
[`variable named-pats]
[`(variable-except ,s ...) named-pats]
[`(variable-prefix ,s) named-pats]
[`variable-not-otherwise-mentioned named-pats]
[`hole named-pats]
;; names inside nts are separate
[`(nt ,id) named-pats]
;; record the name, then keep looking inside its sub-pattern
[`(name ,name ,pat)
(loop pat
(add-if-new name pat named-pats))]
[`(mismatch-name ,name ,pat)
(loop pat (cons (unimplemented "mismatch") named-pats))]
;; context first, then the term in its hole
[`(in-hole ,p1 ,p2)
(loop p2
(loop p1 named-pats))]
[`(hide-hole ,p) (loop p named-pats)]
;; names inside these are ignored
[`(side-condition ,p ,g ,e) ;; not supported
named-pats]
[`(cross ,s)
named-pats] ;; not supported
[`(list ,sub-pats ...)
(foldl (λ (sub-pat named-pats)
(match sub-pat
;; plain repeat: just scan the repeated pattern
[`(repeat ,pat #f #f)
(loop pat named-pats)]
[`(repeat ,pat ,name ,mismatch)
(loop pat (cons (unimplemented "named/mismatched repeat") named-pats))]
[else (loop sub-pat named-pats)]))
named-pats
sub-pats)]
[(? (compose not pair?))
named-pats])))
;; add-if-new : symbol pattern (listof (or/c named? unimplemented?)) -> (listof ...)
;; Conses a `named` record binding k to v onto l, unless l already holds a
;; named record for k; in that case l is returned unchanged.
(define (add-if-new k v l)
  (if (assoc-named k l)
      l
      (cons (named k v) l)))
;; assoc-named : symbol (listof any/c) -> boolean
;; #t when some element of l is a `named` record whose name is equal? to n;
;; elements that are not named records (e.g. unimplemented) are skipped.
(define (assoc-named n l)
  (ormap (λ (entry)
           (and (named? entry)
                (equal? (named-name entry) n)))
         l))
;; enum-names : pattern (listof (or/c named? unimplemented?)) hash -> enum
;; Builds the enumerator for `pat` from the bindings collected by
;; sep-names.  The named entries are handled one at a time with dep/enum:
;; a value for that name is chosen first, then the remaining names (and
;; finally `pat` itself) are enumerated with the value recorded in `env`.
;; Each result is wrapped as (named name (named-t value rest)), so that
;; to-term/rep-name can later substitute the value for the (name ...)
;; placeholders.  An `unimplemented` entry turns the result into an
;; error/enum.
;; NOTE(review): env is passed to pat/enum-with-names, but that function
;; never reads its named-terms argument, so the recorded values do not
;; influence enumeration.
(define (enum-names pat named-pats nt-enums)
(let rec ([named-pats named-pats]
[env (hash)])
(cond [(null? named-pats)
;; all names handled: `pat` here is enum-names' own argument (the
;; inner let's `pat` below is not in scope in this clause)
(pat/enum-with-names pat nt-enums env)]
[else
(let ([cur (car named-pats)])
(cond ([named? cur]
(let ([name (named-name cur)]
[pat (named-val cur)])
(map/enum
;; presumably dep/enum yields (value . rest) pairs
(λ (ts)
(named name
(named-t (car ts)
(cdr ts))))
(λ (n)
(if (equal? (named-name n)
name)
(let ([val (named-val n)])
(cons (named-t-val val)
(named-t-term val)))
(error 'wrong-name
"expected ~a, got ~a"
name
(named-name n))))
(dep/enum
;; enumerator for the value bound to this name
(pat/enum-with-names pat nt-enums env)
(λ (term)
(rec (cdr named-pats)
(hash-set env
name
term)))))))
[else (error/enum 'unimplemented
(unimplemented-msg cur))]))])))
;; pat/enum-with-names : pattern hash hash -> enum
;; Structural enumerator for a single pattern.  `nt-enums` maps
;; non-terminal names to their enumerators.  A (name n pat) pattern
;; enumerates only the placeholder (name n); the value for n comes from
;; enum-names and is substituted later by rep-name.
;; NOTE(review): the `named-terms` argument is not referenced below.
(define (pat/enum-with-names pat nt-enums named-terms)
(let loop ([pat pat])
(match-a-pattern
pat
;; either one atom from any/enum or a list of such atoms
[`any
(sum/enum
any/enum
(listof/enum any/enum))]
[`number num/enum]
[`string string/enum]
[`natural natural/enum]
[`integer integer/enum]
[`real real/enum]
[`boolean bool/enum]
[`variable var/enum]
[`(variable-except ,s ...)
(apply except/enum var/enum s)]
[`(variable-prefix ,s)
;; todo
(error/enum 'unimplemented "var-prefix")]
[`variable-not-otherwise-mentioned
(error/enum 'unimplemented "var-not-mentioned")] ;; error
[`hole
(const/enum the-hole)]
;; hash-ref raises here if id has no enumerator in nt-enums
[`(nt ,id)
(hash-ref nt-enums id)]
[`(name ,n ,pat)
(const/enum (name n))]
[`(mismatch-name ,name ,pat)
(error/enum 'unimplemented "mismatch-name")]
;; (context . term) pairs wrapped in a decomposition, to be plugged
;; later by to-term/plug-hole
[`(in-hole ,p1 ,p2) ;; untested
(map/enum
(λ (ts)
(decomposition (car ts)
(cdr ts)))
(λ (decomp)
(cons (decomposition-ctx decomp)
(decomposition-term decomp)))
(prod/enum
(loop p1)
(loop p2)))]
[`(hide-hole ,p)
(loop p)]
[`(side-condition ,p ,g ,e)
(unsupported/enum pat)]
[`(cross ,s)
(unsupported/enum pat)]
[`(list ,sub-pats ...)
;; enum-list
;; Each sub-pattern yields a list of terms (a singleton for an ordinary
;; element, n terms for a plain repeat) and flatten-1 splices them.
;; NOTE(review): identity is not the inverse of flatten-1, so encoding
;; through this enumerator is not meaningful.
(map/enum
flatten-1
identity
(list/enum
(map
(λ (sub-pat)
(match sub-pat
;; plain repeat: pick a length n from nats, then n terms
[`(repeat ,pat #f #f)
(map/enum
cdr
(λ (ts)
(cons (length ts)
ts))
(dep/enum
nats
(λ (n)
(list/enum
(build-list n (const (loop pat)))))))]
[`(repeat ,pat ,name #f)
(error/enum 'unimplemented "named-repeat")]
[`(repeat ,pat #f ,mismatch)
(error/enum 'unimplemented "mismatch-repeat")]
;; NOTE(review): a repeat whose name AND mismatch are both non-#f
;; also lands here, and loop is then applied to the repeat form
;; itself; confirm how match-a-pattern reacts to that.
[else (map/enum
list
car
(loop sub-pat))]))
sub-pats)))]
;; any other atom is a literal that enumerates only itself
[(? (compose not pair?))
(const/enum pat)])))
;; flatten-1 : (listof any/c) -> (listof any/c)
;; Removes one level of nesting: list elements (including '()) are spliced
;; into the result, and any other element is kept as a single item.
(define (flatten-1 xs)
  (define (as-list x)
    (if (or (pair? x) (null? x))
        x
        (list x)))
  (apply append (map as-list xs)))
;; lookup : lang symbol -> (listof rhs)
;; Returns the right-hand sides of the non-terminal called `name` in `nts`.
;; Raises exn:fail naming the missing non-terminal when there is none.
;; (The previous message was the misspelled bare symbol 'unkown-nt, which
;; did not say which non-terminal was missing.)
(define (lookup nts name)
  (cond
    [(findf (λ (candidate) (eq? name (nt-name candidate))) nts)
     => nt-rhs]
    [else
     (error 'lookup "unknown non-terminal: ~a" name)]))
;; Base enumerators used by pat/enum-with-names.  These are module-level
;; values evaluated in order, so each one must come after the enumerators
;; it is built from.
(define natural/enum nats)
;; characters with codes #x61 (a) .. #x7a (z); whether the upper bound is
;; included depends on range/enum -- TODO confirm
(define char/enum
(map/enum
integer->char
char->integer
(range/enum #x61 #x7a)))
;; strings made of those characters
(define string/enum
(map/enum
list->string
string->list
(listof/enum char/enum)))
;; naturals summed with the negatives: n <-> -(n+1), inverse m <-> -m-1
(define integer/enum
(sum/enum nats
(map/enum (λ (n) (- (+ n 1)))
(λ (n) (- (- n) 1))
nats)))
;; only three fixed reals
(define real/enum (from-list/enum '(0.0 1.5 123.112354)))
;; NOTE(review): naturals are produced by both natural/enum and
;; integer/enum, so this sum presumably repeats values
(define num/enum
(sum/enum natural/enum
integer/enum
real/enum))
(define bool/enum
(from-list/enum '(#t #f)))
;; one-character symbols built from char/enum
(define var/enum
(map/enum
(compose string->symbol list->string list)
(compose car string->list symbol->string)
char/enum))
;; atoms only; the `any` pattern adds lists of these on top
(define any/enum
(sum/enum num/enum
string/enum
bool/enum
var/enum))
;; to-term : augmented-term -> term
;; Converts a value produced by pat/enum into a plain term by resolving
;; every `named` wrapper (rep-name) and `decomposition` (plug-hole) in it.
;; These can occur below the top level, e.g. inside a list, or as the value
;; of an (nt ...) sub-pattern (each non-terminal's enumerator wraps its own
;; names), so pairs are traversed recursively.  Car and cdr are walked
;; separately, so improper lists are handled too.
(define (to-term aug)
  (cond [(named? aug) (rep-name aug)]
        [(decomposition? aug) (plug-hole aug)]
        [(pair? aug) (cons (to-term (car aug))
                           (to-term (cdr aug)))]
        [else aug]))
;; rep-name : named -> term
;; Substitutes the value bound by the named wrapper `s` for every
;; (name n) placeholder (n = s's name) in its body, then converts the
;; result with to-term.
;; Substitution descends through pairs, decompositions and nested named
;; wrappers.  It stops at a nested named wrapper that binds the same name
;; n, because that wrapper comes from a separate scope (e.g. a
;; non-terminal's own enumerator) and its (name n) placeholders must get
;; that inner value, not ours.
(define (rep-name s)
  (define n (named-name s))
  (define binding (named-val s))
  (define val (named-t-val binding))
  (define (subst t)
    (cond [(name? t)
           (if (equal? (name-name t) n) val t)]
          [(pair? t)
           (cons (subst (car t)) (subst (cdr t)))]
          [(named? t)
           (if (equal? (named-name t) n)
               t
               (map-named subst t))]
          [(decomposition? t)
           (map-decomp subst t)]
          [else t]))
  (to-term (subst (named-t-term binding))))
;; plug-hole : decomposition -> term
;;           | term term -> term
;; Puts a term into the hole(s) of a context.
;; With one argument (as to-term calls it) the argument is a decomposition
;; and its ctx and term fields are used.
;; Context and filler are first converted with to-term.  A context that is
;; itself an in-hole decomposition is therefore plugged before our term
;; goes in, and holes inside the filler are left untouched.  hole? is the
;; predicate the original code used, presumably recognizing the-hole
;; produced for `hole` patterns.
;; (Previously this took two required arguments while to-term passed one,
;; and its body discarded the cond result and contained an arity error.)
(define plug-hole
  (case-lambda
    [(dcmp)
     (plug-hole (decomposition-ctx dcmp)
                (decomposition-term dcmp))]
    [(ctx term)
     (define filler (to-term term))
     (let fill ([c (to-term ctx)])
       (cond [(hole? c) filler]
             [(pair? c) (cons (fill (car c)) (fill (cdr c)))]
             ;; wrappers still nested deeper inside the context
             [(or (named? c) (decomposition? c)) (fill (to-term c))]
             [else c]))]))
;; map-decomp : (term -> term) decomposition -> decomposition
;; Returns a new decomposition whose context and term are `f` applied to
;; the originals (context first, then term).
(define (map-decomp f dcmp)
  (define c (decomposition-ctx dcmp))
  (define t (decomposition-term dcmp))
  (decomposition (f c) (f t)))
;; map-named : (term -> term) named -> named
;; Returns a named wrapper with the same name whose bound value and body
;; term are both `f` applied to the originals.
;; The value must be mapped as well: in (name x (name y number)), the value
;; recorded for x is the placeholder (name y), and it only gets replaced by
;; y's value if substitution reaches into named-t-val.
(define (map-named f n)
  (let ([v (named-val n)])
    (named (named-name n)
           (named-t
            (f (named-t-val v))
            (f (named-t-term v))))))