diff --git a/collects/datalog/ast.rkt b/collects/datalog/ast.rkt index 99311e889d..84cafa9694 100644 --- a/collects/datalog/ast.rkt +++ b/collects/datalog/ast.rkt @@ -11,8 +11,16 @@ (or/c exact-nonnegative-integer? #f) (or/c exact-positive-integer? #f)))) -(define datum/c (or/c string? symbol?)) -(define datum-equal? equal?) +(define-struct predicate-sym (srcloc sym) #:prefab) +(define datum/c (or/c string? symbol? predicate-sym?)) +(define (datum-equal? x y) + (match* (x y) + [((predicate-sym _ x) y) + (datum-equal? x y)] + [(x (predicate-sym _ y)) + (datum-equal? x y)] + [(x y) + (equal? x y)])) (define-struct variable (srcloc sym) #:prefab) (define (variable-equal? v1 v2) @@ -21,8 +29,6 @@ (define (constant-equal? v1 v2) (equal? (constant-value v1) (constant-value v2))) -(define-struct predicate-sym (srcloc sym) #:prefab) - (define term/c (or/c variable? constant?)) (define (term-equal? t1 t2) (cond diff --git a/collects/datalog/pretty.rkt b/collects/datalog/pretty.rkt index 697684e346..255f0736a7 100644 --- a/collects/datalog/pretty.rkt +++ b/collects/datalog/pretty.rkt @@ -5,12 +5,16 @@ "private/pprint.rkt" "ast.rkt") -(define (format-datum s) - (cond - [(symbol? s) - (text (symbol->string s))] - [else - (text (format "~S" s))])) +(define format-datum + (match-lambda + [(predicate-sym _ s) + (format-datum s)] + [(? symbol? s) + (text (symbol->string s))] + [(? string? s) + (text (format "~S" s))] + [(? number? s) + (text (format "~S" s))])) (define (format-variable v) (format-datum (variable-sym v))) (define (format-constant c)