diff --git a/collects/datalog/stx.rkt b/collects/datalog/stx.rkt index ddc0d354bb..8885a676c2 100644 --- a/collects/datalog/stx.rkt +++ b/collects/datalog/stx.rkt @@ -1,5 +1,8 @@ #lang racket -(require (for-syntax syntax/parse) +(require (for-syntax syntax/parse + racket/local + racket/function + racket/list) datalog/ast datalog/eval) @@ -78,6 +81,40 @@ stx #:literals (:-) [(_ (~and tstx (:- head body ...))) + (local [(define (datalog-literal-variables stx) + (syntax-parse + stx + #:literals (:-) + [sym:id + empty] + [(~and tstx (sym:id arg ... :- ans ...)) + (append-map datalog-term-variables + (syntax->list #'(arg ... ans ...)))] + [(~and tstx (sym:id e ...)) + (append-map datalog-term-variables + (syntax->list #'(e ...)))])) + (define (datalog-term-variables stx) + (syntax-parse + stx + [sym:id + (list #'sym)] + [sym:expr + empty])) + (define head-vars (datalog-literal-variables #'head)) + (define body-vars + (append-map datalog-literal-variables (syntax->list #'(body ...)))) + (define body-vars-in-head + (filter + (λ (bv) + (findf (curry bound-identifier=? bv) + head-vars)) + body-vars)) + (define fake-lam + (quasisyntax/loc #'tstx + (lambda #,head-vars + (void #,@body-vars-in-head))))] + (syntax-local-lift-expression + fake-lam)) (quasisyntax/loc #'tstx (clause #'#,#'tstx (datalog-literal head) (list (datalog-literal body) ...)))]