Fix pattern unification to deal with #f correctly

This commit is contained in:
Burke Fetscher 2013-03-05 21:52:56 -06:00
parent 18ad15c659
commit edcb13e206
3 changed files with 166 additions and 131 deletions

View File

@ -4,10 +4,11 @@
racket/contract
racket/set
racket/match
(for-syntax "rewrite-side-conditions.rkt")
"match-a-pattern.rkt"
"matcher.rkt"
"extract-conditions.rkt")
"extract-conditions.rkt"
(for-syntax "rewrite-side-conditions.rkt"
racket/base))
(provide unify
unify*
@ -21,8 +22,8 @@
empty-env
pat*-clause-p?s
bind-names
remove-empty-dqs)
remove-empty-dqs
unif-fail)
;;
;; atom := `any | `number | `string | `integer | `boolean | `real | `variable | `variable-not-otherwise-mentioned
;; var := symbol?
@ -50,6 +51,18 @@
;; env: an environment record with two fields, eqs and dqs.
;; empty-env (below) uses an immutable hash for eqs and a list for dqs.
;; NOTE(review): eqs/dqs presumably hold equations and disequations -- confirm against callers.
(struct env (eqs dqs) #:transparent)
;; The environment with an empty eqs table and no dqs.
(define empty-env (env (hash) '()))
;; unif-fail: a field-less struct used as the explicit failure value of
;; unification, so that #f can be an ordinary (successful) pattern result.
;; It is #:transparent, so two instances compare equal? to each other.
(struct unif-fail () #:transparent)
;; not-failed? : any -> boolean
;; Returns #t unless maybe-failed is a unif-fail struct.
;; Note that #f counts as "not failed": it is a legitimate value here,
;; not a failure marker.
(define (not-failed? maybe-failed)
(not (unif-fail? maybe-failed)))
;; (and/fail cond ... res)
;; Like `and`, but signals failure with (unif-fail) instead of #f:
;; the conds are combined with `and`; if they are all true, res is evaluated
;; and returned as-is (even when res is #f), otherwise a fresh (unif-fail)
;; is returned and res is not evaluated.
;; The conds are tested for ordinary truthiness, so a cond whose value is a
;; unif-fail struct would count as true -- wrap such values in not-failed?.
(define-syntax (and/fail stx)
(syntax-case stx ()
;; the last subform is the result; everything before it is a condition
[(_ conds ... res)
#'(if (and conds ...)
res
(unif-fail))]))
;; predef-pats: the set of built-in pattern symbols tested by predef-pat?.
(define predef-pats (set 'any 'number 'string 'integer 'boolean 'real 'variable 'natural 'variable-not-otherwise-mentioned))
;; predef-pat? : any -> boolean
;; True when a is a member of the predef-pats set.
(define (predef-pat? a)
(set-member? predef-pats a))
@ -134,13 +147,14 @@
;; pat pat env -> (or/c p*e #f)
(define (unify t u e L)
;(-> pat? pat? env/c compiled-lang? (or/c p*e/c #f))
;(printf "u: ~s ~s ~s\n\n" t u e)
(parameterize ([dqs-found (make-hash)])
(define eqs (hash-copy (env-eqs e)))
(define t* (bind-names t eqs L))
(define u* (bind-names u eqs L))
(define res (and t* u* (unify* t* u* eqs L)))
(and res
(define res (and/fail (not-failed? t*)
(not-failed? u*)
(unify* t* u* eqs L)))
(and (not-failed? res)
(let* ([static-eqs (hash/mut->imm eqs)]
[found-pre-dqs
(apply set-union (set)
@ -169,13 +183,12 @@
;; pat pat env lang -> (or/c env #f)
(define (disunify t u e L)
;(-> pat? pat? env/c any/c (or/c env/c #f))
;(printf "du: ~s ~s\n\n" t u)
(parameterize ([new-eqs (make-hash)])
(define eqs (hash-copy (env-eqs e)))
(define t* (bind-names t eqs L))
(define u* (bind-names u eqs L))
(cond
[(or (not t*) (not u*))
[(or (unif-fail? t*) (unif-fail? u*))
e]
[else
(define bn-eqs (hash-copy eqs))
@ -216,7 +229,6 @@
;; eqs dqs -> dqs or #f
;; simplified - first element in lhs of all inequations is a var not occurring in lhs of eqns
(define (check-and-resimplify eqs dqs L)
;(printf "c-a-r: ~s\n~s\n" dqs eqs)
(define-values (dqs-notok dqs-ok)
(partition (λ (dq)
(hash-has-key?
@ -242,30 +254,14 @@
;; disunify* pat* pat* eqs lang -> dq or boolean (dq is a pat*)
(define (disunify* u* t* eqs L)
;(printf "du*: ~s ~s ~s\n" t* u* eqs)
(parameterize ([new-eqs (make-hash)])
(let ([res (unify* u* t* eqs L)])
(cond
[(not res) #t]
[(unif-fail? res) #t]
[(empty? (hash-keys (new-eqs))) #f]
[else
(extend-dq (new-eqs) base-dq)]))))
(define (update-env e new-eqs the-dqs)
(env (for/fold ([eqs (env-eqs e)])
([(k v) (in-hash new-eqs)])
(hash-set eqs k v))
the-dqs))
(define (update-ineqs e new-es)
(struct-copy env e
[dqs (cons (for/fold ([dq '((list)(list))])
([(l r) (in-hash new-es)])
(match dq
[`((,vars ...) (,rhss ...))
`((,vars ... (name ,l ,(bound))) (,rhss ... ,r))]))
(env-dqs e))]))
;; the "root" pats will be pats without names,
;; which match both pat and pat*...
@ -278,38 +274,41 @@
(error 'bind-names "pat*, not a pat: ~s\n" pat)]
[`(name ,name ,pat)
(define b-pat (bind-names pat e L))
(and b-pat
(let recur ([id name])
(define res (hash-ref e (lvar id) (uninstantiated)))
(match res
[(uninstantiated)
(when (equal? b-pat (bound))
(error 'bind-names "tried to set something to bound"))
(and (not (occurs?* id b-pat e L))
(hash-set! e (lvar id) b-pat)
;; here we only bind to things in the LOCAL pattern
;; so don't update new-eqs
`(name ,name ,(bound)))]
[(lvar id)
(define next (recur id))
(match next
[`(name ,id-new ,(bound))
(unless (eq? id id-new)
;; path compression: don't update new-eqs here
(hash-set! e (lvar id) (lvar id-new)))]
[_ (void)])
next]
[else ;; some pat* (res is already bound)
(and (unify-update* id b-pat res e L)
`(name ,id ,(bound)))])))]
(and/fail (not-failed? b-pat)
(let recur ([id name])
(define res (hash-ref e (lvar id) (uninstantiated)))
(match res
[(uninstantiated)
(when (equal? b-pat (bound))
(error 'bind-names "tried to set something to bound"))
(and/fail (not (occurs?* id b-pat e L))
(hash-set! e (lvar id) b-pat)
;; here we only bind to things in the LOCAL pattern
;; so don't update new-eqs
`(name ,name ,(bound)))]
[(lvar id)
(define next (recur id))
(match next
[`(name ,id-new ,(bound))
(unless (eq? id id-new)
;; path compression: don't update new-eqs here
(hash-set! e (lvar id) (lvar id-new)))]
[_ (void)])
next]
[else ;; some pat* (res is already bound)
(and/fail (not-failed? (unify-update* id b-pat res e L))
`(name ,id ,(bound)))])))]
[`(list ,pats ...)
(let/ec fail
`(list ,@(for/list ([p pats])
(or (bind-names p e L) (fail #f)))))]
(define res (bind-names p e L))
(if (not-failed? res)
res
(fail (unif-fail))))))]
[`(mismatch-name ,name ,p)
(define b-pat (bind-names p e L))
(and b-pat
`(mismatch-name ,name ,(bind-names p e L)))]
(and/fail (not-failed? b-pat)
`(mismatch-name ,name ,(bind-names p e L)))]
[_ pat]))
@ -317,7 +316,6 @@
(define (unify* t0 u0 e L)
(define t (resolve t0 e))
(define u (resolve u0 e))
;(printf "unify*: ~s ~s\n" t u)
(match* (t u)
;; mismatch patterns
[(`(mismatch-name ,name ,t-pat) u)
@ -334,28 +332,25 @@
[(_ `(name ,name ,(bound)))
(unify* u t e L)]
;; cstrs
#;[(`(nt ,n) `(cstr (,nts ...) ,p)) ;; remove ?? put back?
`(cstr ,(sort (remove-duplicates (cons n nts))
symbol<?) ,p)]
[(`(cstr (,nts1 ...) ,p1) `(cstr (,nts2 ...) ,p2))
(let ([res (unify* p1 p2 e L)])
(and res
(when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p1 p2 e))
`(cstr ,(merge-ids/sorted nts1 nts2) ,res)))]
(and/fail (not-failed? res)
(when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p1 p2 e))
`(cstr ,(merge-ids/sorted nts1 nts2) ,res)))]
[(`(cstr ,nts ,p) _)
(let ([res (unify* p u e L)])
(and res
(match res
[(lvar id)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p u e)]
[`(nt ,nt)
`(cstr ,(merge-ids/sorted (list nt) nts)
,p)]
[`(cstr ,nts2 ,new-p)
`(cstr ,(merge-ids/sorted nts nts2) ,new-p)]
[else
`(cstr ,nts ,res)])))]
(and/fail (not-failed? res)
(match res
[(lvar id)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p u e)]
[`(nt ,nt)
`(cstr ,(merge-ids/sorted (list nt) nts)
,p)]
[`(cstr ,nts2 ,new-p)
`(cstr ,(merge-ids/sorted nts nts2) ,new-p)]
[else
`(cstr ,nts ,res)])))]
[(_ `(cstr ,nts ,p))
(unify* `(cstr ,nts ,p) t e L)]
;; nts
@ -367,25 +362,28 @@
(if (hash-has-key? (compiled-lang-collapsible-nts L) p)
(unify* (hash-ref (compiled-lang-collapsible-nts L) p) u e L)
(let ([res (unify* u u e L)])
(and res
(when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" u u e))
`(cstr (,p) ,res))))]
(and/fail (not-failed? res)
(when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" u u e))
`(cstr (,p) ,res))))]
[(_ `(nt ,p))
(unify* `(nt ,p) t e L)]
;; other pat stuff
[(`(list ,ts ...) `(list ,us ...))
(and (= (length ts) (length us))
(let/ec fail
`(list ,@(for/list ([t ts] [u us])
(or (unify* t u e L) (fail #f))))))]
(and/fail (= (length ts) (length us))
(let/ec fail
`(list ,@(for/list ([t ts] [u us])
(let ([res (unify* t u e L)])
(if (not-failed? res)
res
(fail (unif-fail))))))))]
[((? number-type? t) (? number-type? u))
(cond
[(number-superset? t u) u]
[(number-superset? u t) t])]
[((? number-type? t) _)
(and ((number-pred t) u)
u)]
(and/fail ((number-pred t) u)
u)]
[(_ (? number-type? u))
(unify* u t e L)]
[(`variable-not-otherwise-mentioned `variable-not-otherwise-mentioned)
@ -395,16 +393,16 @@
[(`variable-not-otherwise-mentioned `variable)
`variable-not-otherwise-mentioned]
[(`variable-not-otherwise-mentioned (? symbol? s))
(and (not (memq s (compiled-lang-literals L)))
(not (base-type? s))
s)]
(and/fail (not (memq s (compiled-lang-literals L)))
(not (base-type? s))
s)]
[(`variable `variable)
`variable]
[(_ `variable)
(unify* u t e L)]
[(`variable (? symbol? s))
(and (not (base-type? s))
s)]
(and/fail (not (base-type? s))
s)]
;; string stuff
[(`string `string)
`string]
@ -417,7 +415,7 @@
[(`boolean `boolean)
`boolean]
[(`string `boolean)
#f]
(unif-fail)]
[(_ `boolean)
(unify* u t e L)]
[(`boolean (? boolean? b))
@ -425,12 +423,12 @@
;; other
[((? base-type? t) (? base-type? u))
(and (equal? t u)
t)]
(and/fail (equal? t u)
t)]
[((? (compose not pair?) t) (? (compose not pair?) u))
(and (equal? t u)
t)]
[(_ _) #f]))
(and/fail (equal? t u)
t)]
[(_ _) (unif-fail)]))
(define (resolve pat env)
(match pat
@ -446,18 +444,17 @@
;; unify-update* : id pat* pat* env lang -> pat* or unif-fail
(define (unify-update* id pat-1 pat-2 e L)
;(printf "unify-update ~s ~s ~s\n" id pat-1 pat-2)
(let ([u-res (unify* pat-1 pat-2 e L)])
(and (not (occurs?* id pat-1 e L))
(not (occurs?* id pat-2 e L))
(when u-res
(when (equal? u-res (bound)) (error 'update "tried to set something to bound"))
(unless (equal? u-res (hash-ref e (lvar id) (uninstantiated)))
(hash-set! e (lvar id) u-res)
(unless (or (nt-only-pat? u-res)
(ground-pat-eq? pat-1 pat-2))
(hash-set! (new-eqs) (lvar id) u-res))))
u-res)))
(and/fail (not (occurs?* id pat-1 e L))
(not (occurs?* id pat-2 e L))
(when (not-failed? u-res)
(when (equal? u-res (bound)) (error 'update "tried to set something to bound"))
(unless (equal? u-res (hash-ref e (lvar id) (uninstantiated)))
(hash-set! e (lvar id) u-res)
(unless (or (nt-only-pat? u-res)
(ground-pat-eq? pat-1 pat-2))
(hash-set! (new-eqs) (lvar id) u-res))))
u-res)))
(define (nt-only-pat? p*)
(match p*
@ -478,7 +475,6 @@
;; TODO: replace name in p*'s with lvar - this is the most obvious of many
;; functions that would be improved by this
(define (occurs?* name p e L)
;(printf "occurs: ~s ~s\n" name p)
(match p
[`(name ,name-p ,(bound))
(or (eq? name name-p)
@ -497,14 +493,13 @@
(define (instantiate* id pat e L)
;(printf "inst*: ~s ~s\n" id pat)
(define id-pat (resolve (lookup-pat id e) e))
(match id-pat
[`(name ,next-id ,(bound))
(and (instantiate* next-id pat e L)
(not (occurs?* id (lvar next-id) e L))
(hash-set! e (lvar id) (lvar next-id))
`(name ,next-id ,(bound)))]
(and/fail (not-failed? (instantiate* next-id pat e L))
(not (occurs?* id (lvar next-id) e L))
(hash-set! e (lvar id) (lvar next-id))
`(name ,next-id ,(bound)))]
[else
(match pat
[`(name ,id-2 ,(bound))
@ -513,17 +508,16 @@
pat]
[else
(define id-2-pat (resolve (lookup-pat id-2 e) e))
;(printf "id: ~s id-pat: ~s id-2: ~s id-2-pat: ~s\n" id id-pat id-2 id-2-pat)
(define res (unify-update* id id-pat id-2-pat e L))
(and res
(not (occurs?* id-2 (lvar id) e L))
(hash-set! e (lvar id-2) (lvar id))
(unless (ground-pat-eq? id-pat id-2-pat)
(hash-set! (new-eqs) (lvar id-2) (lvar id)))
`(name ,id ,(bound)))])]
(and/fail (not-failed? res)
(not (occurs?* id-2 (lvar id) e L))
(hash-set! e (lvar id-2) (lvar id))
(unless (ground-pat-eq? id-pat id-2-pat)
(hash-set! (new-eqs) (lvar id-2) (lvar id)))
`(name ,id ,(bound)))])]
[else
(and (unify-update* id id-pat pat e L)
`(name ,id ,(bound)))])]))
(and/fail (not-failed? (unify-update* id id-pat pat e L))
`(name ,id ,(bound)))])]))
;; we want to consider ground pats that are equal
;; modulo constraints as equal when disunifying (see uses)
@ -607,10 +601,7 @@
rep)
(define (lookup id env)
(define res (hash-ref env (lvar id) (λ ()
#;(hash-set! env (lvar id) 'any)
#;'any
#f)))
(define res (hash-ref env (lvar id) (λ () #f)))
(match res
[(lvar new-id)
(lookup new-id env)]

View File

@ -587,6 +587,13 @@
(J (a any_1) any_2)]
[(J #t #f)])
(test (term (a #t))
#t)
(test (term (a 42))
#t)
(test (term (a #f))
#t)
(test (with-handlers ([exn:fail? exn-message])
(generate-term L0 #:satisfying (c any) +inf.0))
#rx".*generate-term:.*relation.*")

View File

@ -277,8 +277,44 @@
(p*e `(cstr (e) string) (hash)))
(check-equal? (unify/format `number `(nt e) (hash) L0)
(p*e `(cstr (e) number) (hash)))
;; test non-terminal against all built-ins (and reverse)
;; add reversal to unify/format
;; tests specific to #f
;; (which don't work in the above format)
(check-equal? (unify #f '(name x any) (env (hash) '()) #f)
(p*e `(name x ,(bound))
(env (hash (lvar 'x) #f) '())))
(check-equal? (unify #t '(name x any) (env (hash) '()) #f)
(p*e `(name x ,(bound))
(env (hash (lvar 'x) #t) '())))
(check-equal? (unify* #t 'any (hash) #f)
#t)
(check-equal? (unify* #f 'any (hash) #f)
#f)
(check-equal? (unify* #t 'any (hash) #f)
#t)
(check-equal? (unify* #f 'number (hash) #f)
(unif-fail))
(check-equal? (unify* '(list 1) 1 (hash) #f)
(unif-fail))
(check-equal? (unify* 'boolean #t (hash) #f)
#t)
(check-equal? (unify* 'boolean #f (hash) #f)
#f)
(check-equal? (unify* 'number #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'integer #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'natural #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'real #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'string #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'variable #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'variable-not-otherwise-mentioned #f (hash) #f)
(unif-fail))
(define-syntax (unify-all/results stx)
(syntax-case stx ()
@ -452,13 +488,13 @@
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(nt v) (hash) L0)
`(cstr (e v) (list (nt e) (nt v))))
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) 5 (hash) L0)
#f)
(unif-fail))
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(list (nt e) (nt v)) (hash) L0)
`(cstr (e) (list (nt e) (nt v))))
(check-equal? (unify*/lt `(cstr (e) number) `(cstr (v) natural) (hash) L0)
`(cstr (e v) natural))
(check-equal? (unify*/lt `(cstr (e) (list number variable)) `(cstr (e) number) (hash) L0)
#f)
(unif-fail))
(check-equal? (unify*/lt `(cstr (e) (list number variable-not-otherwise-mentioned))
`(cstr (e) (list integer variable)) (hash) L0)
`(cstr (e) (list integer variable-not-otherwise-mentioned)))
@ -665,9 +701,10 @@
h))
(list `(name x ,(bound))
(make-hash (list (cons (lvar 'x) 'any)))))
(check-false (let ([h (make-hash (list (cons (lvar 'x) (lvar 'y))
(cons (lvar 'y) 'any)))])
(bind-names `(list (name x 5) (name y 6)) h L0)))
(check-equal? (let ([h (make-hash (list (cons (lvar 'x) (lvar 'y))
(cons (lvar 'y) 'any)))])
(bind-names `(list (name x 5) (name y 6)) h L0))
(unif-fail))
(define-syntax do-unify
(λ (stx)