Fix pattern unification to deal with #f correctly

This commit is contained in:
Burke Fetscher 2013-03-05 21:52:56 -06:00
parent 18ad15c659
commit edcb13e206
3 changed files with 166 additions and 131 deletions

View File

@ -4,10 +4,11 @@
racket/contract racket/contract
racket/set racket/set
racket/match racket/match
(for-syntax "rewrite-side-conditions.rkt")
"match-a-pattern.rkt" "match-a-pattern.rkt"
"matcher.rkt" "matcher.rkt"
"extract-conditions.rkt") "extract-conditions.rkt"
(for-syntax "rewrite-side-conditions.rkt"
racket/base))
(provide unify (provide unify
unify* unify*
@ -21,8 +22,8 @@
empty-env empty-env
pat*-clause-p?s pat*-clause-p?s
bind-names bind-names
remove-empty-dqs) remove-empty-dqs
unif-fail)
;; ;;
;; atom := `any | `number | `string | `integer | `boolean | `real | `variable | `variable-not-otherwise-mentioned ;; atom := `any | `number | `string | `integer | `boolean | `real | `variable | `variable-not-otherwise-mentioned
;; var := symbol? ;; var := symbol?
@ -50,6 +51,18 @@
(struct env (eqs dqs) #:transparent) (struct env (eqs dqs) #:transparent)
(define empty-env (env (hash) '())) (define empty-env (env (hash) '()))
;; unif-fail : distinguished "unification failed" value. It has no fields and
;; is #:transparent, so separately constructed instances compare equal? --
;; which is what allows tests of the form (check-equal? ... (unif-fail)).
;; A dedicated struct is used instead of #f so that the literal pattern #f
;; can itself be a successful unification result.
(struct unif-fail () #:transparent)
;; not-failed? : any -> boolean
;; Returns #t for every value except a unif-fail instance; in particular,
;; #f counts as "not failed".
(define (not-failed? maybe-failed)
(not (unif-fail? maybe-failed)))
;; (and/fail cond ... res)
;; Like `and`, but signals failure with (unif-fail) instead of #f:
;; the conds are combined with `and`; if all of them are true, the value of
;; res is returned unchanged (even when that value is #f), otherwise the
;; whole form evaluates to (unif-fail). res is only evaluated when every
;; cond holds. With zero conds, (and) is #t, so res is always returned.
(define-syntax (and/fail stx)
(syntax-case stx ()
[(_ conds ... res)
#'(if (and conds ...)
res
(unif-fail))]))
(define predef-pats (set 'any 'number 'string 'integer 'boolean 'real 'variable 'natural 'variable-not-otherwise-mentioned)) (define predef-pats (set 'any 'number 'string 'integer 'boolean 'real 'variable 'natural 'variable-not-otherwise-mentioned))
(define (predef-pat? a) (define (predef-pat? a)
(set-member? predef-pats a)) (set-member? predef-pats a))
@ -134,13 +147,14 @@
;; pat pat env -> (or/c p*e #f) ;; pat pat env -> (or/c p*e #f)
(define (unify t u e L) (define (unify t u e L)
;(-> pat? pat? env/c compiled-lang? (or/c p*e/c #f)) ;(-> pat? pat? env/c compiled-lang? (or/c p*e/c #f))
;(printf "u: ~s ~s ~s\n\n" t u e)
(parameterize ([dqs-found (make-hash)]) (parameterize ([dqs-found (make-hash)])
(define eqs (hash-copy (env-eqs e))) (define eqs (hash-copy (env-eqs e)))
(define t* (bind-names t eqs L)) (define t* (bind-names t eqs L))
(define u* (bind-names u eqs L)) (define u* (bind-names u eqs L))
(define res (and t* u* (unify* t* u* eqs L))) (define res (and/fail (not-failed? t*)
(and res (not-failed? u*)
(unify* t* u* eqs L)))
(and (not-failed? res)
(let* ([static-eqs (hash/mut->imm eqs)] (let* ([static-eqs (hash/mut->imm eqs)]
[found-pre-dqs [found-pre-dqs
(apply set-union (set) (apply set-union (set)
@ -169,13 +183,12 @@
;; pat pat env lang -> (or/c env #f) ;; pat pat env lang -> (or/c env #f)
(define (disunify t u e L) (define (disunify t u e L)
;(-> pat? pat? env/c any/c (or/c env/c #f)) ;(-> pat? pat? env/c any/c (or/c env/c #f))
;(printf "du: ~s ~s\n\n" t u)
(parameterize ([new-eqs (make-hash)]) (parameterize ([new-eqs (make-hash)])
(define eqs (hash-copy (env-eqs e))) (define eqs (hash-copy (env-eqs e)))
(define t* (bind-names t eqs L)) (define t* (bind-names t eqs L))
(define u* (bind-names u eqs L)) (define u* (bind-names u eqs L))
(cond (cond
[(or (not t*) (not u*)) [(or (unif-fail? t*) (unif-fail? u*))
e] e]
[else [else
(define bn-eqs (hash-copy eqs)) (define bn-eqs (hash-copy eqs))
@ -216,7 +229,6 @@
;; eqs dqs -> dqs or #f ;; eqs dqs -> dqs or #f
;; simplified - first element in lhs of all inequations is a var not occurring in lhs of eqns ;; simplified - first element in lhs of all inequations is a var not occurring in lhs of eqns
(define (check-and-resimplify eqs dqs L) (define (check-and-resimplify eqs dqs L)
;(printf "c-a-r: ~s\n~s\n" dqs eqs)
(define-values (dqs-notok dqs-ok) (define-values (dqs-notok dqs-ok)
(partition (λ (dq) (partition (λ (dq)
(hash-has-key? (hash-has-key?
@ -242,30 +254,14 @@
;; disunify* pat* pat* eqs lang -> dq or boolean (dq is a pat*) ;; disunify* pat* pat* eqs lang -> dq or boolean (dq is a pat*)
(define (disunify* u* t* eqs L) (define (disunify* u* t* eqs L)
;(printf "du*: ~s ~s ~s\n" t* u* eqs)
(parameterize ([new-eqs (make-hash)]) (parameterize ([new-eqs (make-hash)])
(let ([res (unify* u* t* eqs L)]) (let ([res (unify* u* t* eqs L)])
(cond (cond
[(not res) #t] [(unif-fail? res) #t]
[(empty? (hash-keys (new-eqs))) #f] [(empty? (hash-keys (new-eqs))) #f]
[else [else
(extend-dq (new-eqs) base-dq)])))) (extend-dq (new-eqs) base-dq)]))))
;; update-env : env hash dqs -> env
;; Builds a new env whose eqs are (env-eqs e) extended with every key/value
;; pair of new-eqs (a pair in new-eqs replaces an existing entry with the
;; same key), and whose dqs field is the-dqs. The dqs already stored in e
;; are not consulted.
;; NOTE(review): the functional hash-set presumably requires (env-eqs e) to
;; be an immutable hash -- confirm against callers.
(define (update-env e new-eqs the-dqs)
(env (for/fold ([eqs (env-eqs e)])
([(k v) (in-hash new-eqs)])
(hash-set eqs k v))
the-dqs))
;; update-ineqs : env hash -> env
;; Returns a copy of e with one new disequation consed onto the front of
;; (env-dqs e); the eqs field is carried over untouched.
;; The disequation is built by folding over new-es starting from the seed
;; '((list) (list)): for each key l and value r, (name l <bound>) is appended
;; to the first list and r to the second, keeping the two lists the same
;; length.
(define (update-ineqs e new-es)
(struct-copy env e
[dqs (cons (for/fold ([dq '((list)(list))])
([(l r) (in-hash new-es)])
(match dq
[`((,vars ...) (,rhss ...))
`((,vars ... (name ,l ,(bound))) (,rhss ... ,r))]))
(env-dqs e))]))
;; the "root" pats will be pats without names, ;; the "root" pats will be pats without names,
;; which match both pat and pat*... ;; which match both pat and pat*...
@ -278,14 +274,14 @@
(error 'bind-names "pat*, not a pat: ~s\n" pat)] (error 'bind-names "pat*, not a pat: ~s\n" pat)]
[`(name ,name ,pat) [`(name ,name ,pat)
(define b-pat (bind-names pat e L)) (define b-pat (bind-names pat e L))
(and b-pat (and/fail (not-failed? b-pat)
(let recur ([id name]) (let recur ([id name])
(define res (hash-ref e (lvar id) (uninstantiated))) (define res (hash-ref e (lvar id) (uninstantiated)))
(match res (match res
[(uninstantiated) [(uninstantiated)
(when (equal? b-pat (bound)) (when (equal? b-pat (bound))
(error 'bind-names "tried to set something to bound")) (error 'bind-names "tried to set something to bound"))
(and (not (occurs?* id b-pat e L)) (and/fail (not (occurs?* id b-pat e L))
(hash-set! e (lvar id) b-pat) (hash-set! e (lvar id) b-pat)
;; here we only bind to things in the LOCAL pattern ;; here we only bind to things in the LOCAL pattern
;; so don't update new-eqs ;; so don't update new-eqs
@ -300,15 +296,18 @@
[_ (void)]) [_ (void)])
next] next]
[else ;; some pat* (res is already bound) [else ;; some pat* (res is already bound)
(and (unify-update* id b-pat res e L) (and/fail (not-failed? (unify-update* id b-pat res e L))
`(name ,id ,(bound)))])))] `(name ,id ,(bound)))])))]
[`(list ,pats ...) [`(list ,pats ...)
(let/ec fail (let/ec fail
`(list ,@(for/list ([p pats]) `(list ,@(for/list ([p pats])
(or (bind-names p e L) (fail #f)))))] (define res (bind-names p e L))
(if (not-failed? res)
res
(fail (unif-fail))))))]
[`(mismatch-name ,name ,p) [`(mismatch-name ,name ,p)
(define b-pat (bind-names p e L)) (define b-pat (bind-names p e L))
(and b-pat (and/fail (not-failed? b-pat)
`(mismatch-name ,name ,(bind-names p e L)))] `(mismatch-name ,name ,(bind-names p e L)))]
[_ pat])) [_ pat]))
@ -317,7 +316,6 @@
(define (unify* t0 u0 e L) (define (unify* t0 u0 e L)
(define t (resolve t0 e)) (define t (resolve t0 e))
(define u (resolve u0 e)) (define u (resolve u0 e))
;(printf "unify*: ~s ~s\n" t u)
(match* (t u) (match* (t u)
;; mismatch patterns ;; mismatch patterns
[(`(mismatch-name ,name ,t-pat) u) [(`(mismatch-name ,name ,t-pat) u)
@ -334,18 +332,15 @@
[(_ `(name ,name ,(bound))) [(_ `(name ,name ,(bound)))
(unify* u t e L)] (unify* u t e L)]
;; cstrs ;; cstrs
#;[(`(nt ,n) `(cstr (,nts ...) ,p)) ;; remove ?? put back?
`(cstr ,(sort (remove-duplicates (cons n nts))
symbol<?) ,p)]
[(`(cstr (,nts1 ...) ,p1) `(cstr (,nts2 ...) ,p2)) [(`(cstr (,nts1 ...) ,p1) `(cstr (,nts2 ...) ,p2))
(let ([res (unify* p1 p2 e L)]) (let ([res (unify* p1 p2 e L)])
(and res (and/fail (not-failed? res)
(when (lvar? res) (when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p1 p2 e)) (error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p1 p2 e))
`(cstr ,(merge-ids/sorted nts1 nts2) ,res)))] `(cstr ,(merge-ids/sorted nts1 nts2) ,res)))]
[(`(cstr ,nts ,p) _) [(`(cstr ,nts ,p) _)
(let ([res (unify* p u e L)]) (let ([res (unify* p u e L)])
(and res (and/fail (not-failed? res)
(match res (match res
[(lvar id) [(lvar id)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p u e)] (error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" p u e)]
@ -367,7 +362,7 @@
(if (hash-has-key? (compiled-lang-collapsible-nts L) p) (if (hash-has-key? (compiled-lang-collapsible-nts L) p)
(unify* (hash-ref (compiled-lang-collapsible-nts L) p) u e L) (unify* (hash-ref (compiled-lang-collapsible-nts L) p) u e L)
(let ([res (unify* u u e L)]) (let ([res (unify* u u e L)])
(and res (and/fail (not-failed? res)
(when (lvar? res) (when (lvar? res)
(error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" u u e)) (error 'unify* "unify* returned lvar as result: ~s\n~s\n~s\n" u u e))
`(cstr (,p) ,res))))] `(cstr (,p) ,res))))]
@ -375,16 +370,19 @@
(unify* `(nt ,p) t e L)] (unify* `(nt ,p) t e L)]
;; other pat stuff ;; other pat stuff
[(`(list ,ts ...) `(list ,us ...)) [(`(list ,ts ...) `(list ,us ...))
(and (= (length ts) (length us)) (and/fail (= (length ts) (length us))
(let/ec fail (let/ec fail
`(list ,@(for/list ([t ts] [u us]) `(list ,@(for/list ([t ts] [u us])
(or (unify* t u e L) (fail #f))))))] (let ([res (unify* t u e L)])
(if (not-failed? res)
res
(fail (unif-fail))))))))]
[((? number-type? t) (? number-type? u)) [((? number-type? t) (? number-type? u))
(cond (cond
[(number-superset? t u) u] [(number-superset? t u) u]
[(number-superset? u t) t])] [(number-superset? u t) t])]
[((? number-type? t) _) [((? number-type? t) _)
(and ((number-pred t) u) (and/fail ((number-pred t) u)
u)] u)]
[(_ (? number-type? u)) [(_ (? number-type? u))
(unify* u t e L)] (unify* u t e L)]
@ -395,7 +393,7 @@
[(`variable-not-otherwise-mentioned `variable) [(`variable-not-otherwise-mentioned `variable)
`variable-not-otherwise-mentioned] `variable-not-otherwise-mentioned]
[(`variable-not-otherwise-mentioned (? symbol? s)) [(`variable-not-otherwise-mentioned (? symbol? s))
(and (not (memq s (compiled-lang-literals L))) (and/fail (not (memq s (compiled-lang-literals L)))
(not (base-type? s)) (not (base-type? s))
s)] s)]
[(`variable `variable) [(`variable `variable)
@ -403,7 +401,7 @@
[(_ `variable) [(_ `variable)
(unify* u t e L)] (unify* u t e L)]
[(`variable (? symbol? s)) [(`variable (? symbol? s))
(and (not (base-type? s)) (and/fail (not (base-type? s))
s)] s)]
;; string stuff ;; string stuff
[(`string `string) [(`string `string)
@ -417,7 +415,7 @@
[(`boolean `boolean) [(`boolean `boolean)
`boolean] `boolean]
[(`string `boolean) [(`string `boolean)
#f] (unif-fail)]
[(_ `boolean) [(_ `boolean)
(unify* u t e L)] (unify* u t e L)]
[(`boolean (? boolean? b)) [(`boolean (? boolean? b))
@ -425,12 +423,12 @@
;; other ;; other
[((? base-type? t) (? base-type? u)) [((? base-type? t) (? base-type? u))
(and (equal? t u) (and/fail (equal? t u)
t)] t)]
[((? (compose not pair?) t) (? (compose not pair?) u)) [((? (compose not pair?) t) (? (compose not pair?) u))
(and (equal? t u) (and/fail (equal? t u)
t)] t)]
[(_ _) #f])) [(_ _) (unif-fail)]))
(define (resolve pat env) (define (resolve pat env)
(match pat (match pat
@ -446,11 +444,10 @@
;; unify-update* : id pat* pat* env lang -> pat* or #f ;; unify-update* : id pat* pat* env lang -> pat* or #f
(define (unify-update* id pat-1 pat-2 e L) (define (unify-update* id pat-1 pat-2 e L)
;(printf "unify-update ~s ~s ~s\n" id pat-1 pat-2)
(let ([u-res (unify* pat-1 pat-2 e L)]) (let ([u-res (unify* pat-1 pat-2 e L)])
(and (not (occurs?* id pat-1 e L)) (and/fail (not (occurs?* id pat-1 e L))
(not (occurs?* id pat-2 e L)) (not (occurs?* id pat-2 e L))
(when u-res (when (not-failed? u-res)
(when (equal? u-res (bound)) (error 'update "tried to set something to bound")) (when (equal? u-res (bound)) (error 'update "tried to set something to bound"))
(unless (equal? u-res (hash-ref e (lvar id) (uninstantiated))) (unless (equal? u-res (hash-ref e (lvar id) (uninstantiated)))
(hash-set! e (lvar id) u-res) (hash-set! e (lvar id) u-res)
@ -478,7 +475,6 @@
;; TODO: replace name in p*'s with lvar - this is the most obvious of many ;; TODO: replace name in p*'s with lvar - this is the most obvious of many
;; functions that would be improved by this ;; functions that would be improved by this
(define (occurs?* name p e L) (define (occurs?* name p e L)
;(printf "occurs: ~s ~s\n" name p)
(match p (match p
[`(name ,name-p ,(bound)) [`(name ,name-p ,(bound))
(or (eq? name name-p) (or (eq? name name-p)
@ -497,11 +493,10 @@
(define (instantiate* id pat e L) (define (instantiate* id pat e L)
;(printf "inst*: ~s ~s\n" id pat)
(define id-pat (resolve (lookup-pat id e) e)) (define id-pat (resolve (lookup-pat id e) e))
(match id-pat (match id-pat
[`(name ,next-id ,(bound)) [`(name ,next-id ,(bound))
(and (instantiate* next-id pat e L) (and/fail (not-failed? (instantiate* next-id pat e L))
(not (occurs?* id (lvar next-id) e L)) (not (occurs?* id (lvar next-id) e L))
(hash-set! e (lvar id) (lvar next-id)) (hash-set! e (lvar id) (lvar next-id))
`(name ,next-id ,(bound)))] `(name ,next-id ,(bound)))]
@ -513,16 +508,15 @@
pat] pat]
[else [else
(define id-2-pat (resolve (lookup-pat id-2 e) e)) (define id-2-pat (resolve (lookup-pat id-2 e) e))
;(printf "id: ~s id-pat: ~s id-2: ~s id-2-pat: ~s\n" id id-pat id-2 id-2-pat)
(define res (unify-update* id id-pat id-2-pat e L)) (define res (unify-update* id id-pat id-2-pat e L))
(and res (and/fail (not-failed? res)
(not (occurs?* id-2 (lvar id) e L)) (not (occurs?* id-2 (lvar id) e L))
(hash-set! e (lvar id-2) (lvar id)) (hash-set! e (lvar id-2) (lvar id))
(unless (ground-pat-eq? id-pat id-2-pat) (unless (ground-pat-eq? id-pat id-2-pat)
(hash-set! (new-eqs) (lvar id-2) (lvar id))) (hash-set! (new-eqs) (lvar id-2) (lvar id)))
`(name ,id ,(bound)))])] `(name ,id ,(bound)))])]
[else [else
(and (unify-update* id id-pat pat e L) (and/fail (not-failed? (unify-update* id id-pat pat e L))
`(name ,id ,(bound)))])])) `(name ,id ,(bound)))])]))
;; we want to consider ground pats that are equal ;; we want to consider ground pats that are equal
@ -607,10 +601,7 @@
rep) rep)
(define (lookup id env) (define (lookup id env)
(define res (hash-ref env (lvar id) (λ () (define res (hash-ref env (lvar id) (λ () #f)))
#;(hash-set! env (lvar id) 'any)
#;'any
#f)))
(match res (match res
[(lvar new-id) [(lvar new-id)
(lookup new-id env)] (lookup new-id env)]

View File

@ -587,6 +587,13 @@
(J (a any_1) any_2)] (J (a any_1) any_2)]
[(J #t #f)]) [(J #t #f)])
(test (term (a #t))
#t)
(test (term (a 42))
#t)
(test (term (a #f))
#t)
(test (with-handlers ([exn:fail? exn-message]) (test (with-handlers ([exn:fail? exn-message])
(generate-term L0 #:satisfying (c any) +inf.0)) (generate-term L0 #:satisfying (c any) +inf.0))
#rx".*generate-term:.*relation.*") #rx".*generate-term:.*relation.*")

View File

@ -277,8 +277,44 @@
(p*e `(cstr (e) string) (hash))) (p*e `(cstr (e) string) (hash)))
(check-equal? (unify/format `number `(nt e) (hash) L0) (check-equal? (unify/format `number `(nt e) (hash) L0)
(p*e `(cstr (e) number) (hash))) (p*e `(cstr (e) number) (hash)))
;; test non-terminal against all built-ins (and reverse)
;; add reversal to unify/format
;; tests specific to #f
;; (which don't work in the above format)
(check-equal? (unify #f '(name x any) (env (hash) '()) #f)
(p*e `(name x ,(bound))
(env (hash (lvar 'x) #f) '())))
(check-equal? (unify #t '(name x any) (env (hash) '()) #f)
(p*e `(name x ,(bound))
(env (hash (lvar 'x) #t) '())))
(check-equal? (unify* #t 'any (hash) #f)
#t)
(check-equal? (unify* #f 'any (hash) #f)
#f)
(check-equal? (unify* #t 'any (hash) #f)
#t)
(check-equal? (unify* #f 'number (hash) #f)
(unif-fail))
(check-equal? (unify* '(list 1) 1 (hash) #f)
(unif-fail))
(check-equal? (unify* 'boolean #t (hash) #f)
#t)
(check-equal? (unify* 'boolean #f (hash) #f)
#f)
(check-equal? (unify* 'number #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'integer #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'natural #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'real #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'string #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'variable #f (hash) #f)
(unif-fail))
(check-equal? (unify* 'variable-not-otherwise-mentioned #f (hash) #f)
(unif-fail))
(define-syntax (unify-all/results stx) (define-syntax (unify-all/results stx)
(syntax-case stx () (syntax-case stx ()
@ -452,13 +488,13 @@
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(nt v) (hash) L0) (check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(nt v) (hash) L0)
`(cstr (e v) (list (nt e) (nt v)))) `(cstr (e v) (list (nt e) (nt v))))
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) 5 (hash) L0) (check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) 5 (hash) L0)
#f) (unif-fail))
(check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(list (nt e) (nt v)) (hash) L0) (check-equal? (unify*/lt `(cstr (e) (list (nt e) (nt v))) `(list (nt e) (nt v)) (hash) L0)
`(cstr (e) (list (nt e) (nt v)))) `(cstr (e) (list (nt e) (nt v))))
(check-equal? (unify*/lt `(cstr (e) number) `(cstr (v) natural) (hash) L0) (check-equal? (unify*/lt `(cstr (e) number) `(cstr (v) natural) (hash) L0)
`(cstr (e v) natural)) `(cstr (e v) natural))
(check-equal? (unify*/lt `(cstr (e) (list number variable)) `(cstr (e) number) (hash) L0) (check-equal? (unify*/lt `(cstr (e) (list number variable)) `(cstr (e) number) (hash) L0)
#f) (unif-fail))
(check-equal? (unify*/lt `(cstr (e) (list number variable-not-otherwise-mentioned)) (check-equal? (unify*/lt `(cstr (e) (list number variable-not-otherwise-mentioned))
`(cstr (e) (list integer variable)) (hash) L0) `(cstr (e) (list integer variable)) (hash) L0)
`(cstr (e) (list integer variable-not-otherwise-mentioned))) `(cstr (e) (list integer variable-not-otherwise-mentioned)))
@ -665,9 +701,10 @@
h)) h))
(list `(name x ,(bound)) (list `(name x ,(bound))
(make-hash (list (cons (lvar 'x) 'any))))) (make-hash (list (cons (lvar 'x) 'any)))))
(check-false (let ([h (make-hash (list (cons (lvar 'x) (lvar 'y)) (check-equal? (let ([h (make-hash (list (cons (lvar 'x) (lvar 'y))
(cons (lvar 'y) 'any)))]) (cons (lvar 'y) 'any)))])
(bind-names `(list (name x 5) (name y 6)) h L0))) (bind-names `(list (name x 5) (name y 6)) h L0))
(unif-fail))
(define-syntax do-unify (define-syntax do-unify
(λ (stx) (λ (stx)