racket/collects/redex/private/enum.rkt
Max New b8538ec135 Added Redex enumerators.
Supports names and recursive patterns.
Limited support for repeats and mismatches.
2013-05-09 16:32:53 -05:00

528 lines
12 KiB
Racket

#lang racket/base
(require racket/contract
racket/list
racket/match
racket/function
"lang-struct.rkt"
"match-a-pattern.rkt"
"enumerator.rkt")
;; Public interface of this module; every export is guarded by a contract.
(provide
(contract-out
;; Takes a list of `nt` structures and returns a hash from non-terminal
;; name to enumerator.
[lang-enumerators (-> (listof nt?) (hash/c symbol? enum?))]
;; Takes a hash of per-non-terminal enumerators plus a pattern and returns
;; an enumerator for that pattern.
[pat-enumerator (-> (hash/c symbol? enum?)
any/c ;; pattern
enum?)]
;; Takes an enumerator and a natural-number index.
[enum-ith (-> enum? exact-nonnegative-integer? any/c)]
;; Predicate recognizing enumerators.
[enum? (-> any/c boolean?)]))
;; lang-enumerators : (listof nt) -> (hash symbol -o> enum)
;; Builds one enumerator per non-terminal of the language `nts`.
;; The right-hand sides of each non-terminal are first annotated by
;; `find-recs` and reordered by `sort-nt-terms`; the enumerator for a
;; non-terminal is then the enumerator of the pattern `(nt <name>)`.
;; Any `exn:fail` raised while building an enumerator is passed to
;; `fail/enum`, and whatever that returns is stored for the non-terminal.
(define (lang-enumerators nts)
  (let* ([rec-nt-terms (find-recs nts)]
         [sorted-nts (sort-nt-terms nts rec-nt-terms)])
    ;; Accumulate into an immutable hash, visiting non-terminals in order.
    (for/fold ([enums (hash)])
              ([nt (in-list sorted-nts)])
      (hash-set
       enums
       (nt-name nt)
       (with-handlers ([exn:fail? fail/enum])
         (rec-pat/enum `(nt ,(nt-name nt))
                       sorted-nts
                       rec-nt-terms))))))
;; enum-ith is exported as an alias for `decode` (not defined in this file).
;; Per the contract in `provide`, it takes an enumerator and a natural index.
(define enum-ith decode)
;; decomposition pairs a context (`ctx`) with a term (`term`); it is the value
;; produced for `in-hole` patterns by the four-argument case of
;; `pat/enum-with-names`.
(struct decomposition (ctx term))
;; pat-enumerator : (hash symbol -o> enum) pattern -> enum
;; Enumerates `pat`, resolving non-terminal references through the
;; already-built enumerators in `lang-enums`.
(define (pat-enumerator lang-enums pat)
  (let ([named-pats (sep-names pat)])
    (enum-names pat named-pats lang-enums)))
;; rec-pat/enum : pattern lang (hash symbol -o> (assoclist rhs bool)) -> enum
;; Enumerates `pat` against the language `nts` itself (rather than against
;; prebuilt enumerators), using `rec-nt-terms` to know which right-hand
;; sides are flagged as recursive.
(define (rec-pat/enum pat nts rec-nt-terms)
  (let ([named-pats (sep-names pat)])
    (enum-names pat nts named-pats rec-nt-terms)))
;; find-recs : lang -> (hash symbol -o> (assoclist rhs bool))
;; Identifies which non-terminals are recursive.
;; The result maps every non-terminal name to an association list that
;; pairs each of its right-hand sides with a flag: true when that
;; right-hand side mentions a non-terminal classified as recursive.
(define (find-recs nt-pats)
  ;; any-nt-ref? : pattern (symbol -> any) symbol -> any
  ;; Walks `top-pat` and ors together `(on-nt id)` for the `(nt id)`
  ;; references it contains (short-circuiting on the first true result).
  ;; `who` is the symbol used when an unsupported pattern is reported.
  (define (any-nt-ref? top-pat on-nt who)
    (let walk ([pat top-pat])
      (match-a-pattern
       pat
       [`any #f]
       [`number #f]
       [`string #f]
       [`natural #f]
       [`integer #f]
       [`real #f]
       [`boolean #f]
       [`variable #f]
       [`(variable-except ,s ...) #f]
       [`(variable-prefix ,s) #f]
       [`variable-not-otherwise-mentioned #f]
       [`hole #f]
       [`(nt ,id) (on-nt id)]
       [`(name ,name ,pat) (walk pat)]
       [`(mismatch-name ,name ,pat) (walk pat)]
       [`(in-hole ,p1 ,p2)
        (or (walk p1)
            (walk p2))]
       [`(hide-hole ,p) (walk p)]
       [`(side-condition ,p ,g ,e) ;; error
        (error who "side-condition")]
       [`(cross ,s) ;; error
        (error who "cross")]
       [`(list ,sub-pats ...)
        (ormap (λ (sub-pat)
                 (match sub-pat
                   [`(repeat ,pat ,name ,mismatch) (walk pat)]
                   [else (walk sub-pat)]))
               sub-pats)]
       [(? (compose not pair?)) #f])))

  ;; is-rec? : nt [(hash symbol -o> #t)] -> any
  ;; True when following non-terminal references from `nt` reaches a
  ;; non-terminal that is already on the current path (`seen`).
  (define (is-rec? nt [seen (hash)])
    (define name (nt-name nt))
    (or (hash-ref seen name #f)
        (let ([seen+ (hash-set seen name #t)])
          (ormap (λ (rhs)
                   (any-nt-ref?
                    (rhs-pattern rhs)
                    (λ (id)
                      (is-rec? (make-nt id (lookup nt-pats id)) seen+))
                    'unsupported))
                 (nt-rhs nt)))))

  ;; calls-rec? : rhs (hash symbol -o> any) -> any
  ;; True when `rhs` refers to a non-terminal that `recs` marks recursive.
  (define (calls-rec? rhs recs)
    (any-nt-ref? (rhs-pattern rhs)
                 (λ (id) (hash-ref recs id))
                 'no-enum))

  ;; First pass: classify every non-terminal.
  (define recs
    (for/hash ([nt (in-list nt-pats)])
      (values (nt-name nt) (is-rec? nt))))

  ;; Second pass: flag each right-hand side individually.
  (for/hash ([nt (in-list nt-pats)])
    (values (nt-name nt)
            (map (λ (rhs) (cons rhs (calls-rec? rhs recs)))
                 (nt-rhs nt)))))
;; sort-nt-terms : lang (hash symbol -o> (assoclist rhs bool)) -> lang
;; Rebuilds every non-terminal with its right-hand sides reordered so that
;; those not flagged in `recs` come before those that are flagged.
(define (sort-nt-terms nt-pats recs)
  (for/list ([nt (in-list nt-pats)])
    (define flagged (hash-ref recs (nt-name nt)))
    ;; flag stored for `rhs` in this non-terminal's association list
    (define (recursive? rhs)
      (cdr (assoc rhs flagged)))
    ;; strict "less than": only unflagged-before-flagged counts as ordered
    (define (before? a b)
      (and (not (recursive? a))
           (recursive? b)))
    (make-nt (nt-name nt)
             (sort (nt-rhs nt) before?))))
;; sep-names : single-pattern -> (assoclist symbol pattern)
;; Collects the names bound inside `pat` as a list of two-element entries.
;; `name` and `mismatch-name` patterns are recorded as (name sub-pattern);
;; named repeats are recorded with the marker symbols `name-r` /
;; `mismatch-r` in place of a pattern. A name is recorded only the first
;; time it is met, and new entries are consed onto the front of the list.
(define (sep-names pat)
  (let walk ([pat pat]
             [acc '()])
    (match-a-pattern
     pat
     [`any acc]
     [`number acc]
     [`string acc]
     [`natural acc]
     [`integer acc]
     [`real acc]
     [`boolean acc]
     [`variable acc]
     [`(variable-except ,s ...) acc]
     [`(variable-prefix ,s) acc]
     [`variable-not-otherwise-mentioned acc]
     [`hole acc]
     ;; names inside nts are separate
     [`(nt ,id) acc]
     [`(name ,name ,pat)
      (walk pat (add-if-new name pat acc))]
     [`(mismatch-name ,name ,pat)
      (walk pat (add-if-new name pat acc))]
     [`(in-hole ,p1 ,p2)
      (walk p2 (walk p1 acc))]
     [`(hide-hole ,p) (walk p acc)]
     [`(side-condition ,p ,g ,e) ;; error
      (error 'no-enum "side condition")]
     [`(cross ,s) ;; error
      (error 'no-enum "cross")]
     [`(list ,sub-pats ...)
      ;; thread the accumulator through the elements, left to right
      (for/fold ([acc acc])
                ([sub-pat (in-list sub-pats)])
        (match sub-pat
          ;; unnamed repeat
          [`(repeat ,pat #f #f)
           (walk pat acc)]
          ;; named repeat
          [`(repeat ,pat ,name #f)
           (walk pat (add-if-new name 'name-r acc))]
          ;; mismatch named repeat
          [`(repeat ,pat #f ,mismatch)
           (walk pat (add-if-new mismatch 'mismatch-r acc))]
          ;; normal subpattern
          [else (walk sub-pat acc)]))]
     [(? (compose not pair?))
      acc])))
;; add-if-new : any any (listof (list any any)) -> (listof (list any any))
;; Returns `l` unchanged when it already has an entry keyed by `k`;
;; otherwise returns `l` with the entry (k v) added at the front.
(define (add-if-new k v l)
  (if (assoc k l)
      l
      (cons (list k v) l)))
;; enum-names : pattern assoclist (hash symbol -o> enum) -> enum
;;            | pattern lang assoclist (hash symbol -o> assoclist) -> enum
;; Enumerates `pat` after enumerating the named sub-patterns in
;; `named-pats`. With three arguments, non-terminals are resolved through
;; `nts` as passed to the three-argument `pat/enum-with-names`; with four,
;; the language and `rec-nt-terms` are passed to its four-argument case.
(define enum-names
  (case-lambda
    [(pat named-pats nts)
     (let ([enum-in-env
            (λ (p env) (pat/enum-with-names p nts env))])
       (enum-names-with enum-in-env pat named-pats))]
    [(pat nts named-pats rec-nt-terms)
     (let ([enum-in-env
            (λ (p env) (pat/enum-with-names p nts env rec-nt-terms))])
       (enum-names-with enum-in-env pat named-pats))]))
;; enum-names-with : (pattern env -> enum) pattern assoclist -> enum
;; Works through `named-pats` front to back. For each (name sub-pattern)
;; entry it enumerates the sub-pattern with `f`, then continues with the
;; remaining entries in an environment where `name` maps to the chosen
;; term. Once no entries remain, `pat` itself is enumerated with `f` in
;; the accumulated environment. Repeat and mismatch entries are rejected.
(define (enum-names-with f pat named-pats)
  (define (bind-names remaining env)
    (if (null? remaining)
        (f pat env)
        (match (car remaining)
          ;; named repeat
          [`(,name name-r)
           (error 'unimplemented "named-repeat")]
          ;; mismatch repeat
          [`(,name mismatch-r)
           (error 'unimplemented "mismatch-repeat")]
          [`(,name ,sub-pat mismatch)
           (error 'unimplemented "mismatch")]
          ;; named
          [`(,name ,sub-pat)
           (map/enum ;; loses bijection
            cdr
            (λ (x) (cons name x))
            (dep/enum
             (f sub-pat env)
             (λ (term)
               (bind-names (cdr remaining)
                           (hash-set env name term)))))]
          [else (error 'bad-assoc)])))
  (bind-names named-pats (hash)))
;; list-pat/enum : (listof sub-pattern) (pattern -> enum) -> enum
;; Enumerates the terms matching `(list sub-pats ...)`.
;; Each element enumerator yields a *segment* (a list of terms): a plain
;; sub-pattern yields a one-element segment and an unnamed repeat yields a
;; segment of arbitrary length. The segments are appended to form the term.
;; Wrapping plain sub-patterns keeps a list-valued term (for example the
;; term of a nested `list` or `nt` pattern) nested as a single element,
;; instead of splicing its contents into the enclosing list.
(define (list-pat/enum sub-pats pat->enum)
  (map/enum
   (λ (segments) (apply append segments))
   ;; NOTE: as before, `identity` is not a true inverse of this step.
   identity
   (list/enum
    (map
     (λ (sub-pat)
       (match sub-pat
         [`(repeat ,pat #f #f)
          ;; pick a length n, then n terms of `pat`; keep only the terms
          (map/enum
           cdr
           (λ (ts)
             (cons (length ts)
                   ts))
           (dep/enum
            nats
            (λ (n)
              (list/enum
               (build-list n (const (pat->enum pat)))))))]
         [`(repeat ,pat ,name #f)
          (error 'unimplemented "named-repeat")]
         [`(repeat ,pat #f ,mismatch)
          (error 'unimplemented "mismatch-repeat")]
         [else
          ;; one-element segment; `car` undoes the wrapping
          (map/enum list car (pat->enum sub-pat))]))
     sub-pats))))

;; pat/enum-with-names :
;;     pattern (hash symbol -o> enum) (hash symbol -o> term) -> enum
;;   | pattern lang (hash symbol -o> term)
;;       (hash symbol -o> (assoclist rhs bool)) -> enum
;; Builds the enumerator for `pat`. `named-terms` supplies the term already
;; chosen for every `name` pattern, so a `name` pattern enumerates exactly
;; that one term.
;; Three arguments: `(nt id)` is looked up in the prebuilt `nt-enums`, and
;; `in-hole` produces the plugged term.
;; Four arguments: `(nt id)` is expanded from the language `nts`, with
;; right-hand sides flagged in `rec-nt-terms` delayed through `thunk/enum`,
;; and `in-hole` produces a `decomposition` structure.
;; Raises an error for patterns that are unsupported or unimplemented.
(define pat/enum-with-names
  (case-lambda
    [(pat nt-enums named-terms)
     (let loop ([pat pat])
       (match-a-pattern
        pat
        [`any
         (sum/enum
          any/enum
          (listof/enum any/enum))]
        [`number num/enum]
        [`string string/enum]
        [`natural natural/enum]
        [`integer integer/enum]
        [`real real/enum]
        [`boolean bool/enum]
        [`variable var/enum]
        [`(variable-except ,s ...)
         ;; todo
         (error 'unimplemented "var-except")]
        [`(variable-prefix ,s)
         ;; todo
         (error 'unimplemented "var-prefix")]
        [`variable-not-otherwise-mentioned
         (error 'unimplemented "var-not-mentioned")] ;; error
        [`hole
         (const/enum 'hole)]
        [`(nt ,id)
         (hash-ref nt-enums id)]
        [`(name ,name ,pat)
         (const/enum (hash-ref named-terms name))]
        [`(mismatch-name ,name ,pat)
         (error 'unimplemented "mismatch-name")]
        [`(in-hole ,p1 ,p2) ;; untested
         (map/enum
          (λ (t1-t2) ;; loses bijection
            (plug-hole (car t1-t2)
                       (cdr t1-t2)))
          (λ (plugged)
            (cons 'hole plugged))
          (prod/enum
           (loop p1)
           (loop p2)))]
        [`(hide-hole ,p)
         (loop p)]
        [`(side-condition ,p ,g ,e)
         (error 'no-enum "side condition")]
        [`(cross ,s)
         (error 'no-enum "cross")]
        [`(list ,sub-pats ...)
         (list-pat/enum sub-pats loop)]
        [(? (compose not pair?))
         (const/enum pat)]))]
    [(pat nts named-terms rec-nt-terms)
     (let loop ([pat pat])
       (match-a-pattern
        pat
        [`any
         (sum/enum
          any/enum
          (listof/enum any/enum))]
        [`number num/enum]
        [`string string/enum]
        [`natural natural/enum]
        [`integer integer/enum]
        [`real real/enum]
        [`boolean bool/enum]
        [`variable var/enum]
        [`(variable-except ,s ...)
         ;; todo
         (error 'unimplemented "var except")]
        [`(variable-prefix ,s)
         ;; todo
         (error 'unimplemented "var prefix")]
        [`variable-not-otherwise-mentioned
         (error 'unimplemented "var not otherwise mentioned")]
        [`hole
         (const/enum 'hole)]
        [`(nt ,id)
         (let ([rhss (lookup nts id)])
           (apply sum/enum
                  (map
                   (λ (rhs)
                     (cond [(cdr (assoc rhs (hash-ref rec-nt-terms id)))
                            ;; flagged right-hand side: build its
                            ;; enumerator only when the thunk is called
                            ;; (+inf.f is presumably the declared size)
                            (thunk/enum
                             +inf.f
                             (λ ()
                               (rec-pat/enum (rhs-pattern rhs)
                                             nts
                                             rec-nt-terms)))]
                           [else
                            (rec-pat/enum (rhs-pattern rhs)
                                          nts
                                          rec-nt-terms)]))
                   rhss)))]
        [`(name ,name ,pat)
         (const/enum (hash-ref named-terms name))]
        [`(mismatch-name ,name ,pat)
         (error 'unimplemented "mismatch-name")]
        [`(in-hole ,p1 ,p2) ;; untested
         (map/enum
          (λ (t1-t2)
            (decomposition (car t1-t2)
                           (cdr t1-t2)))
          (λ (decomp)
            (cons (decomposition-ctx decomp)
                  (decomposition-term decomp)))
          (prod/enum
           (loop p1)
           (loop p2)))]
        [`(hide-hole ,p)
         ;; todo
         (loop p)]
        [`(side-condition ,p ,g ,e)
         (error 'no-enum "side-condition")]
        [`(cross ,s)
         (error 'no-enum "cross")]
        [`(list ,sub-pats ...)
         (list-pat/enum sub-pats loop)]
        [(? (compose not pair?))
         (const/enum pat)]))]))
;; flatten-1 : list -> list
;; Removes one level of nesting: elements that are pairs or the empty list
;; are spliced into the result; every other element is kept as is.
(define (flatten-1 xs)
  (define (as-segment x)
    (cond [(pair? x) x]
          [(null? x) x]
          [else (list x)]))
  (append-map as-segment xs))
;; lookup : lang symbol -> (listof rhs)
;; Returns the right-hand sides of the first non-terminal in `nts` whose
;; name is `eq?` to `name`.
;; Raises an error naming the missing non-terminal when there is none.
(define (lookup nts name)
  (define found
    (findf (λ (nt) (eq? name (nt-name nt))) nts))
  (unless found
    (error 'lookup "unknown non-terminal: ~s" name))
  (nt-rhs found))
;; Enumerator used for the `natural` pattern: `nats` itself.
(define natural/enum nats)

;; Characters whose code points come from the range enumerator bounded by
;; the code points of #\a and #\z.
(define char/enum
  (map/enum integer->char
            char->integer
            (range/enum (char->integer #\a) (char->integer #\z))))

;; Strings, built from lists of `char/enum` characters.
(define string/enum
  (map/enum list->string
            string->list
            (listof/enum char/enum)))

;; Integers: `nats` summed with the negatives. Natural n corresponds to the
;; negative integer -1 - n, and the same formula maps it back.
(define integer/enum
  (sum/enum nats
            (map/enum (λ (n) (- -1 n))
                      (λ (z) (- -1 z))
                      nats)))

;; Reals: a fixed three-element sample.
(define real/enum
  (from-list/enum (list 0.0 1.5 123.112354)))
;; Numbers: integers summed with the sample reals.
;; `natural/enum` is deliberately not listed as its own summand:
;; `integer/enum` already has `nats` as its first summand, so including the
;; naturals again would list every natural twice.
(define num/enum
  (sum/enum integer/enum
            real/enum))
;; Booleans: #t first, then #f.
(define bool/enum
  (from-list/enum (list #t #f)))

;; Variables: one-character symbols, one per `char/enum` character.
(define var/enum
  (map/enum
   ;; character -> symbol spelled with just that character
   (λ (c) (string->symbol (list->string (list c))))
   ;; symbol -> its first character
   (λ (sym) (car (string->list (symbol->string sym))))
   char/enum))

;; Base values used for the `any` pattern: numbers, strings, booleans and
;; variables.
(define any/enum
  (sum/enum num/enum
            string/enum
            bool/enum
            var/enum))
;; plug-hole : term term -> term
;; Returns `ctx` with every occurrence of the symbol `hole` replaced by
;; `term`, descending through proper lists only; anything else is left
;; untouched.
(define (plug-hole ctx term)
  (define (plug c)
    (cond [(eq? c 'hole) term]
          [(list? c) (map plug c)]
          [else c]))
  (plug ctx))
(module+ test
  (require rackunit)

  ;; A one-non-terminal language: r ::= (variable variable ...)
  (define repeat-lang
    (list (make-nt 'r
                   (list (make-rhs '(list variable
                                          (repeat variable #f #f)))))))
  (define r-enum (hash-ref (lang-enumerators repeat-lang) 'r))
  (test-begin
   (check-equal? (enum-ith r-enum 0) '(a))
   (check-equal? (size r-enum) +inf.f))

  ;; A lambda-calculus-like language with a recursive non-terminal e.
  (define λc
    (list (make-nt 'e
                   (list (make-rhs '(list (repeat variable #f #f)))
                         (make-rhs '(list λ variable (nt e)))
                         (make-rhs '(list (nt e) (nt e)))))))
  (define λc-enums (lang-enumerators λc))
  (define e-enum (hash-ref λc-enums 'e))
  (check-equal? (size e-enum) +inf.f))