redex: add #:pre to define-judgment-form

also, update contract error messages for the same
This commit is contained in:
Burke Fetscher 2014-07-22 14:27:08 -05:00
parent 45af9f8fe2
commit 54a6d3179d
4 changed files with 137 additions and 82 deletions

View File

@ -1259,10 +1259,13 @@ and @racket[#f] otherwise.
(define-judgment-form language (define-judgment-form language
mode-spec mode-spec
contract-spec contract-spec
invariant-spec
rule rule ...) rule rule ...)
([mode-spec (code:line #:mode (form-id pos-use ...))] ([mode-spec (code:line #:mode (form-id pos-use ...))]
[contract-spec (code:line) [contract-spec (code:line)
(code:line #:contract (form-id @#,ttpattern-sequence ...))] (code:line #:contract (form-id @#,ttpattern-sequence ...))]
[invariant-spec (code:line #:inv @#,tttterm)
(code:line)]
[pos-use I [pos-use I
O] O]
[rule [premise [rule [premise
@ -1299,7 +1302,10 @@ and input positions in premises must be @|tttterm|s; input positions in conclusi
output positions in premises must be @|ttpattern|s. When the optional @racket[contract-spec] output positions in premises must be @|ttpattern|s. When the optional @racket[contract-spec]
declaration is present, Redex dynamically checks that the terms flowing through declaration is present, Redex dynamically checks that the terms flowing through
these positions match the provided patterns, raising an exception recognized by these positions match the provided patterns, raising an exception recognized by
@racket[exn:fail:redex] if not. @racket[exn:fail:redex] if not. The term in the optional @racket[invariant-spec] is
evaluated after the output positions have been computed and the contract has matched
successfully, with variables from the contract bound; a result of @racket[#f] is
considered to be a contract violation and an exception is raised.
For example, the following defines addition on natural numbers: For example, the following defines addition on natural numbers:
@interaction[ @interaction[

View File

@ -79,6 +79,33 @@
(define-syntax (fresh stx) (raise-syntax-error 'fresh "used outside of reduction-relation")) (define-syntax (fresh stx) (raise-syntax-error 'fresh "used outside of reduction-relation"))
(define-syntax (with stx) (raise-syntax-error 'with "used outside of reduction-relation")) (define-syntax (with stx) (raise-syntax-error 'with "used outside of reduction-relation"))
(module mode-utils racket/base
(require racket/list)
(provide split-by-mode
assemble)
(define (split-by-mode xs mode)
(for/fold ([ins '()] [outs '()])
([x (reverse xs)]
[m (reverse mode)])
(case m
[(I) (values (cons x ins) outs)]
[(O) (values ins (cons x outs))]
[else (error 'split-by-mode "ack ~s" m)])))
(define (assemble mode inputs outputs)
(let loop ([ms mode] [is inputs] [os outputs])
(if (null? ms)
'()
(case (car ms)
[(I) (cons (car is) (loop (cdr ms) (cdr is) os))]
[(O) (cons (car os) (loop (cdr ms) is (cdr os)))])))))
(require 'mode-utils
(for-syntax 'mode-utils))
(define-for-syntax (generate-binding-constraints names names/ellipses bindings syn-err-name) (define-for-syntax (generate-binding-constraints names names/ellipses bindings syn-err-name)
(define (id/depth stx) (define (id/depth stx)
(syntax-case stx () (syntax-case stx ()
@ -313,14 +340,6 @@
(reverse subs))) (reverse subs)))
this-output))) this-output)))
(define (assemble mode inputs outputs)
(let loop ([ms mode] [is inputs] [os outputs])
(if (null? ms)
'()
(case (car ms)
[(I) (cons (car is) (loop (cdr ms) (cdr is) os))]
[(O) (cons (car os) (loop (cdr ms) is (cdr os)))]))))
(define (verify-name-ok orig-name the-name) (define (verify-name-ok orig-name the-name)
(unless (symbol? the-name) (unless (symbol? the-name)
(error orig-name "expected a single name, got ~s" the-name))) (error orig-name "expected a single name, got ~s" the-name)))
@ -550,7 +569,7 @@
(define-for-syntax (do-extended-judgment-form lang syn-err-name body orig stx is-relation?) (define-for-syntax (do-extended-judgment-form lang syn-err-name body orig stx is-relation?)
(define nts (definition-nts lang stx syn-err-name)) (define nts (definition-nts lang stx syn-err-name))
(define-values (judgment-form-name dup-form-names mode position-contracts clauses rule-names) (define-values (judgment-form-name dup-form-names mode position-contracts invariant clauses rule-names)
(parse-judgment-form-body body syn-err-name stx (identifier? orig) is-relation?)) (parse-judgment-form-body body syn-err-name stx (identifier? orig) is-relation?))
(define definitions (define definitions
(with-syntax ([judgment-form-runtime-proc (with-syntax ([judgment-form-runtime-proc
@ -563,7 +582,7 @@
#'mk-judgment-form-proc #'#,lang #'jf-lws #'mk-judgment-form-proc #'#,lang #'jf-lws
'#,rule-names #'judgment-runtime-gen-clauses #'mk-judgment-gen-clauses #'jf-term-proc #,is-relation?)) '#,rule-names #'judgment-runtime-gen-clauses #'mk-judgment-gen-clauses #'jf-term-proc #,is-relation?))
(define-values (mk-judgment-form-proc mk-judgment-gen-clauses) (define-values (mk-judgment-form-proc mk-judgment-gen-clauses)
(compile-judgment-form #,judgment-form-name #,mode #,lang #,clauses #,rule-names #,position-contracts (compile-judgment-form #,judgment-form-name #,mode #,lang #,clauses #,rule-names #,position-contracts #,invariant
#,orig #,stx #,syn-err-name judgment-runtime-gen-clauses)) #,orig #,stx #,syn-err-name judgment-runtime-gen-clauses))
(define judgment-form-runtime-proc (mk-judgment-form-proc #,lang)) (define judgment-form-runtime-proc (mk-judgment-form-proc #,lang))
(define jf-lws (compiled-judgment-form-lws #,clauses #,judgment-form-name #,stx)) (define jf-lws (compiled-judgment-form-lws #,clauses #,judgment-form-name #,stx))
@ -625,10 +644,11 @@
(cons #f names))]))) (cons #f names))])))
(values (reverse backward-rules) (values (reverse backward-rules)
(reverse backward-names))) (reverse backward-names)))
(define-values (name/mode mode-stx name/contract contract rules rule-names) (define-values (name/mode mode-stx name/contract contract invariant rules rule-names)
(syntax-parse body #:context full-stx (syntax-parse body #:context full-stx
[((~or (~seq #:mode ~! mode:mode-spec) [((~or (~seq #:mode ~! mode:mode-spec)
(~seq #:contract ~! contract:contract-spec)) (~seq #:contract ~! contract:contract-spec)
(~seq #:inv ~! inv:expr))
... ...
rule:expr ...) rule:expr ...)
(let-values ([(name/mode mode) (let-values ([(name/mode mode)
@ -648,9 +668,17 @@
[(_ . dups) [(_ . dups)
(raise-syntax-error (raise-syntax-error
syn-err-name "expected at most one contract specification" syn-err-name "expected at most one contract specification"
#f #f (syntax->list #'dups))])]
[(invt)
(syntax-parse #'(inv ...)
[() #f]
[(invar) #'invar]
[(_ . dups)
(raise-syntax-error
syn-err-name "expected at most one invariant specification"
#f #f (syntax->list #'dups))])]) #f #f (syntax->list #'dups))])])
(define-values (parsed-rules rule-names) (parse-rules (syntax->list #'(rule ...)))) (define-values (parsed-rules rule-names) (parse-rules (syntax->list #'(rule ...))))
(values name/mode mode name/ctc ctc parsed-rules rule-names))])) (values name/mode mode name/ctc ctc invt parsed-rules rule-names))]))
(check-clauses full-stx syn-err-name rules #t) (check-clauses full-stx syn-err-name rules #t)
(check-dup-rule-names full-stx syn-err-name rule-names) (check-dup-rule-names full-stx syn-err-name rule-names)
(check-arity-consistency mode-stx contract full-stx) (check-arity-consistency mode-stx contract full-stx)
@ -667,7 +695,7 @@
[(symbol? (syntax-e name)) [(symbol? (syntax-e name))
(symbol->string (syntax-e name))] (symbol->string (syntax-e name))]
[else (syntax-e name)]))) [else (syntax-e name)])))
(values form-name dup-names mode-stx contract rules string-rule-names)) (values form-name dup-names mode-stx contract invariant rules string-rule-names))
;; names : (listof (or/c #f syntax[string])) ;; names : (listof (or/c #f syntax[string]))
(define-for-syntax (check-dup-rule-names full-stx syn-err-name names) (define-for-syntax (check-dup-rule-names full-stx syn-err-name names)
@ -795,11 +823,13 @@
[(_ jf-expr) [(_ jf-expr)
#'(#%expression (judgment-holds/derivation build-derivations #t jf-expr any))])) #'(#%expression (judgment-holds/derivation build-derivations #t jf-expr any))]))
(define-for-syntax (do-compile-judgment-form-proc name mode-stx clauses rule-names contracts nts orig lang stx syn-error-name) (define-for-syntax (do-compile-judgment-form-proc name mode-stx clauses rule-names contracts orig-ctcs nts orig lang stx syn-error-name)
(with-syntax ([(init-jf-derivation-id) (generate-temporaries '(init-jf-derivation-id))]) (with-syntax ([(init-jf-derivation-id) (generate-temporaries '(init-jf-derivation-id))])
(define mode (cdr (syntax->datum mode-stx))) (define mode (cdr (syntax->datum mode-stx)))
(define-values (input-contracts output-contracts) (define-values (input-contracts output-contracts)
(if contracts (values (first contracts)
(second contracts))
#;(if contracts
(let-values ([(ins outs) (split-by-mode contracts mode)]) (let-values ([(ins outs) (split-by-mode contracts mode)])
(values ins outs)) (values ins outs))
(values #f #f))) (values #f #f)))
@ -833,17 +863,17 @@
;; pieces of a 'let' expression to be combined: first some bindings ;; pieces of a 'let' expression to be combined: first some bindings
([compiled-lhs (compile-pattern lang `lhs #t)] ([compiled-lhs (compile-pattern lang `lhs #t)]
#,@(if input-contracts #,@(if input-contracts
(list #`[compiled-input-ctcs #,(contracts-compilation input-contracts)]) (list #`[compiled-input-ctcs (compile-pattern lang `#,input-contracts #f)])
(list)) (list))
#,@(if output-contracts #,@(if output-contracts
(list #`[compiled-output-ctcs #,(contracts-compilation output-contracts)]) (list #`[compiled-output-ctcs (compile-pattern lang `#,output-contracts #f)])
(list))) (list)))
;; and then the body of the let, but expected to be behind a (λ (input) ...). ;; and then the body of the let, but expected to be behind a (λ (input) ...).
(let ([jf-derivation-id init-jf-derivation-id]) (let ([jf-derivation-id init-jf-derivation-id])
(begin (begin
lhs-syncheck-exp lhs-syncheck-exp
#,@(if input-contracts #,@(if input-contracts
(list #`(check-judgment-form-contract '#,name input compiled-input-ctcs 'I '#,mode)) (list #`(check-judgment-form-contract '#,name input #f compiled-input-ctcs '#,orig-ctcs 'I '#,mode))
(list)) (list))
(combine-judgment-rhses (combine-judgment-rhses
compiled-lhs compiled-lhs
@ -855,7 +885,7 @@
body)) body))
#,(if output-contracts #,(if output-contracts
#`(λ (output) #`(λ (output)
(check-judgment-form-contract '#,name output compiled-output-ctcs 'O '#,mode)) (check-judgment-form-contract '#,name input output compiled-output-ctcs '#,orig-ctcs 'O '#,mode))
#`void))))))))])) #`void))))))))]))
(when (identifier? orig) (when (identifier? orig)
@ -942,14 +972,28 @@
(list (reverse rhses) (list (reverse rhses)
(reverse sc/wheres)))) (reverse sc/wheres))))
(define (check-judgment-form-contract form-name term+trees contracts mode modes) (define (check-judgment-form-contract form-name input-term output-term+trees contracts orig-ctcs mode modes)
(define terms (if (eq? mode 'O) (define o-term (and (eq? mode 'O)
(derivation-with-output-only-output term+trees) (derivation-with-output-only-output output-term+trees)))
term+trees))
(define description (define description
(case mode (case mode
[(I) "input"] [(I) "input"]
[(O) "output"])) [(O) "output"]))
(when contracts
(case mode
[(I)
(unless (match-pattern contracts input-term)
(redex-error form-name (string-append "judgment input values do not match its contract;\n"
" (unknown output values indicated by _)\n contract: ~a\n values: ~a")
(cons form-name orig-ctcs)
(cons form-name (assemble modes input-term (build-list (length modes)
(λ (_) '_))))))]
[(O)
(define io-term (assemble modes input-term o-term))
(unless (match-pattern contracts io-term)
(redex-error form-name "judgment values do not match its contract;\n contract: ~a\n values: ~a"
(cons form-name orig-ctcs) (cons form-name io-term)))]))
#;
(when contracts (when contracts
(let loop ([rest-modes modes] [rest-terms terms] [rest-ctcs contracts] [pos 1]) (let loop ([rest-modes modes] [rest-terms terms] [rest-ctcs contracts] [pos 1])
(unless (null? rest-modes) (unless (null? rest-modes)
@ -1125,7 +1169,7 @@
(define-syntax (compile-judgment-form stx) (define-syntax (compile-judgment-form stx)
(syntax-case stx () (syntax-case stx ()
[(_ judgment-form-name mode-arg lang raw-clauses rule-names ctcs orig full-def syn-err-name judgment-form-runtime-gen-clauses) [(_ judgment-form-name mode-arg lang raw-clauses rule-names ctcs invt orig full-def syn-err-name judgment-form-runtime-gen-clauses)
(let ([nts (definition-nts #'lang #'full-def (syntax-e #'syn-err-name))] (let ([nts (definition-nts #'lang #'full-def (syntax-e #'syn-err-name))]
[rule-names (syntax->datum #'rule-names)] [rule-names (syntax->datum #'rule-names)]
[syn-err-name (syntax-e #'syn-err-name)] [syn-err-name (syntax-e #'syn-err-name)]
@ -1136,6 +1180,35 @@
[mode (cdr (syntax->datum #'mode-arg))]) [mode (cdr (syntax->datum #'mode-arg))])
(unless (jf-is-relation? #'judgment-form-name) (unless (jf-is-relation? #'judgment-form-name)
(mode-check (cdr (syntax->datum #'mode-arg)) clauses nts syn-err-name stx)) (mode-check (cdr (syntax->datum #'mode-arg)) clauses nts syn-err-name stx))
(define maybe-wrap-contract (if (syntax-e #'invt)
(λ (ctc-stx)
#`(side-condition #,ctc-stx (term invt)))
values))
(define-values (i-ctc-syncheck-expr i-ctc)
(syntax-case #'ctcs ()
[#f (values #'(void) #f)]
[(p ...)
(let-values ([(i-ctcs o-ctcs) (split-by-mode (syntax->list #'(p ...)) mode)])
(with-syntax* ([(i-ctcs ...) i-ctcs]
[(syncheck i-ctc-pat (names ...) (names/ellipses ...))
(rewrite-side-conditions/check-errs #'lang #'syn-error-name #f #'(i-ctcs ...))])
(values #'syncheck #'i-ctc-pat)))]))
(define-values (ctc-syncheck-expr ctc)
(cond
[(not (or (syntax-e #'ctcs)
(syntax-e #'invt)))
(values #'(void) #f)]
[else
(define ctc-stx ((if (syntax-e #'invt)
(λ (ctc-stx)
#`(side-condition #,ctc-stx (term invt)))
values)
(if (syntax-e #'ctcs)
#'ctcs
#'any)))
(with-syntax ([(syncheck ctc-pat (names ...) (names/ellipses ...))
(rewrite-side-conditions/check-errs #'lang #'syn-error-name #f ctc-stx)])
(values #'syncheck #'ctc-pat))]))
(define-values (syncheck-exprs contracts) (define-values (syncheck-exprs contracts)
(syntax-case #'ctcs () (syntax-case #'ctcs ()
[#f (values '() #f)] [#f (values '() #f)]
@ -1156,7 +1229,8 @@
#'mode-arg #'mode-arg
clauses clauses
rule-names rule-names
contracts (list i-ctc ctc)
#'ctcs
nts nts
#'orig #'orig
#'lang #'lang
@ -1179,7 +1253,8 @@
(λ () (λ ()
#,(check-pats #,(check-pats
#'(list comp-clauses ...))))))) #'(list comp-clauses ...)))))))
#`(begin #,@syncheck-exprs (values #,proc-stx #,gen-stx)))])) #`(begin #,i-ctc-syncheck-expr #,ctc-syncheck-expr
(values #,proc-stx #,gen-stx)))]))
(define-for-syntax (rewrite-relation-prems clauses) (define-for-syntax (rewrite-relation-prems clauses)
(map (λ (c) (map (λ (c)
@ -1334,40 +1409,6 @@
(syntax->list #'(x ...))) (syntax->list #'(x ...)))
(raise-syntax-error syn-error-name "error checking failed.2" stx))])) (raise-syntax-error syn-error-name "error checking failed.2" stx))]))
(define-for-syntax (split-by-mode xs mode)
(for/fold ([ins '()] [outs '()])
([x (reverse xs)]
[m (reverse mode)])
(case m
[(I) (values (cons x ins) outs)]
[(O) (values ins (cons x outs))]
[else (error 'split-by-mode "ack ~s" m)])))
(define-for-syntax (fuse-by-mode ins outs mode)
(let loop ([is (reverse ins)]
[os (reverse outs)]
[ms (reverse mode)]
[res '()])
(define err (λ () (error 'fuse-by-mode "mismatched mode and split: ~s ~s ~s" ins outs mode)))
(cond
[(and (empty? ms)
(empty? is)
(empty? os))
res]
[(empty? ms)
(err)]
[else
(case (car ms)
[(I) (if (empty? is)
(err)
(loop (cdr is) os (cdr ms)
(cons (car is) res)))]
[(O) (if (empty? os)
(err)
(loop is (cdr os) (cdr ms)
(cons (car os) res)))]
[else (error 'fuse-by-mode "ack ~s" (car ms))])])))
(define-for-syntax (ellipsis? stx) (define-for-syntax (ellipsis? stx)
(and (identifier? stx) (and (identifier? stx)
(free-identifier=? stx (quote-syntax ...)))) (free-identifier=? stx (quote-syntax ...))))
@ -1405,7 +1446,7 @@
[(syncheck-exps conc/+rw names) (rewrite-pats conc/+ lang)] [(syncheck-exps conc/+rw names) (rewrite-pats conc/+ lang)]
[(ps-rw eqs p-names) (rewrite-prems #t (syntax->list #'(prems ...)) names lang 'define-judgment-form)] [(ps-rw eqs p-names) (rewrite-prems #t (syntax->list #'(prems ...)) names lang 'define-judgment-form)]
[(conc/-rw conc-mfs) (rewrite-terms conc/- p-names)] [(conc/-rw conc-mfs) (rewrite-terms conc/- p-names)]
[(conc) (fuse-by-mode conc/+rw conc/-rw mode)]) [(conc) (assemble mode conc/+rw conc/-rw)])
(with-syntax ([(c ...) conc] (with-syntax ([(c ...) conc]
[(c-mf ...) conc-mfs] [(c-mf ...) conc-mfs]
[(eq ...) eqs] [(eq ...) eqs]
@ -1422,7 +1463,7 @@
(define-values (p/-s p/+s) (split-by-mode (syntax->list prem-body) p-mode)) (define-values (p/-s p/+s) (split-by-mode (syntax->list prem-body) p-mode))
(define-values (p/-rws mf-apps) (rewrite-terms p/-s ns in-judgment-form?)) (define-values (p/-rws mf-apps) (rewrite-terms p/-s ns in-judgment-form?))
(define-values (syncheck-exps p/+rws new-names) (rewrite-pats p/+s lang)) (define-values (syncheck-exps p/+rws new-names) (rewrite-pats p/+s lang))
(define p-rw (fuse-by-mode p/-rws p/+rws p-mode)) (define p-rw (assemble p-mode p/-rws p/+rws))
(with-syntax ([(p ...) p-rw]) (with-syntax ([(p ...) p-rw])
(values (cons #`(begin (values (cons #`(begin
#,@syncheck-exps #,@syncheck-exps

View File

@ -153,6 +153,11 @@
(ctc-fail q s)] (ctc-fail q s)]
[(ctc-fail c s) [(ctc-fail c s)
(ctc-fail a s)])) (ctc-fail a s)]))
(eval '(define-judgment-form L
#:mode (inv-fail I O)
#:contract (inv-fail s_1 s_2)
#:inv ,(not (eq? (term s_1) (term s_2)))
[(inv-fail s a)]))
(exec-runtime-error-tests "run-err-tests/judgment-form-contracts.rktd") (exec-runtime-error-tests "run-err-tests/judgment-form-contracts.rktd")
(exec-runtime-error-tests "run-err-tests/judgment-form-undefined.rktd") (exec-runtime-error-tests "run-err-tests/judgment-form-undefined.rktd")
(exec-runtime-error-tests "run-err-tests/judgment-form-ellipses.rktd")) (exec-runtime-error-tests "run-err-tests/judgment-form-ellipses.rktd"))

View File

@ -1,12 +1,15 @@
(#rx"input q at position 1" (#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail q _\\)"
([judgment (ctc-fail q s)]) ([judgment (ctc-fail q s)])
(judgment-holds judgment)) (judgment-holds judgment))
(#rx"output q at position 2" (#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail a q\\)"
([judgment (ctc-fail a s)]) ([judgment (ctc-fail a s)])
(judgment-holds judgment)) (judgment-holds judgment))
(#rx"input q at position 1" (#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail q _\\)"
([judgment (ctc-fail b s)]) ([judgment (ctc-fail b s)])
(judgment-holds judgment)) (judgment-holds judgment))
(#rx"output q at position 2" (#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail a q\\)"
([judgment (ctc-fail c s)]) ([judgment (ctc-fail c s)])
(judgment-holds judgment)) (judgment-holds judgment))
(#rx"contract: \\(inv-fail s_1 s_2\\).*values: \\(inv-fail a a\\)"
([judgment (inv-fail a s)])
(judgment-holds judgment))