diff --git a/pkgs/redex-pkgs/redex-doc/redex/scribblings/ref.scrbl b/pkgs/redex-pkgs/redex-doc/redex/scribblings/ref.scrbl index be6275c981..6390ec306b 100644 --- a/pkgs/redex-pkgs/redex-doc/redex/scribblings/ref.scrbl +++ b/pkgs/redex-pkgs/redex-doc/redex/scribblings/ref.scrbl @@ -1259,10 +1259,13 @@ and @racket[#f] otherwise. (define-judgment-form language mode-spec contract-spec + invariant-spec rule rule ...) ([mode-spec (code:line #:mode (form-id pos-use ...))] [contract-spec (code:line) (code:line #:contract (form-id @#,ttpattern-sequence ...))] + [invariant-spec (code:line #:inv @#,tttterm) + (code:line)] [pos-use I O] [rule [premise @@ -1299,7 +1302,10 @@ and input positions in premises must be @|tttterm|s; input positions in conclusi output positions in premises must be @|ttpattern|s. When the optional @racket[contract-spec] declaration is present, Redex dynamically checks that the terms flowing through these positions match the provided patterns, raising an exception recognized by -@racket[exn:fail:redex] if not. +@racket[exn:fail:redex] if not. The term in the optional @racket[invariant-spec] is +evaluated after the output positions have been computed and the contract has matched +successfully, with variables from the contract bound; a result of @racket[#f] is +considered to be a contract violation and an exception is raised. For example, the following defines addition on natural numbers: @interaction[ diff --git a/pkgs/redex-pkgs/redex-lib/redex/private/judgment-form.rkt b/pkgs/redex-pkgs/redex-lib/redex/private/judgment-form.rkt index 8267794661..b69f327347 100644 --- a/pkgs/redex-pkgs/redex-lib/redex/private/judgment-form.rkt +++ b/pkgs/redex-pkgs/redex-lib/redex/private/judgment-form.rkt @@ -79,6 +79,33 @@ (define-syntax (fresh stx) (raise-syntax-error 'fresh "used outside of reduction-relation")) (define-syntax (with stx) (raise-syntax-error 'with "used outside of reduction-relation")) +(module mode-utils racket/base + + (require racket/list) + + (provide split-by-mode + assemble) + + (define (split-by-mode xs mode) + (for/fold ([ins '()] [outs '()]) + ([x (reverse xs)] + [m (reverse mode)]) + (case m + [(I) (values (cons x ins) outs)] + [(O) (values ins (cons x outs))] + [else (error 'split-by-mode "ack ~s" m)]))) + + (define (assemble mode inputs outputs) + (let loop ([ms mode] [is inputs] [os outputs]) + (if (null? ms) + '() + (case (car ms) + [(I) (cons (car is) (loop (cdr ms) (cdr is) os))] + [(O) (cons (car os) (loop (cdr ms) is (cdr os)))]))))) + +(require 'mode-utils + (for-syntax 'mode-utils)) + (define-for-syntax (generate-binding-constraints names names/ellipses bindings syn-err-name) (define (id/depth stx) (syntax-case stx () @@ -313,14 +340,6 @@ (reverse subs))) this-output))) -(define (assemble mode inputs outputs) - (let loop ([ms mode] [is inputs] [os outputs]) - (if (null? ms) - '() - (case (car ms) - [(I) (cons (car is) (loop (cdr ms) (cdr is) os))] - [(O) (cons (car os) (loop (cdr ms) is (cdr os)))])))) - (define (verify-name-ok orig-name the-name) (unless (symbol? the-name) (error orig-name "expected a single name, got ~s" the-name))) @@ -550,7 +569,7 @@ (define-for-syntax (do-extended-judgment-form lang syn-err-name body orig stx is-relation?) (define nts (definition-nts lang stx syn-err-name)) - (define-values (judgment-form-name dup-form-names mode position-contracts clauses rule-names) + (define-values (judgment-form-name dup-form-names mode position-contracts invariant clauses rule-names) (parse-judgment-form-body body syn-err-name stx (identifier? orig) is-relation?)) (define definitions (with-syntax ([judgment-form-runtime-proc @@ -563,7 +582,7 @@ #'mk-judgment-form-proc #'#,lang #'jf-lws '#,rule-names #'judgment-runtime-gen-clauses #'mk-judgment-gen-clauses #'jf-term-proc #,is-relation?)) (define-values (mk-judgment-form-proc mk-judgment-gen-clauses) - (compile-judgment-form #,judgment-form-name #,mode #,lang #,clauses #,rule-names #,position-contracts + (compile-judgment-form #,judgment-form-name #,mode #,lang #,clauses #,rule-names #,position-contracts #,invariant #,orig #,stx #,syn-err-name judgment-runtime-gen-clauses)) (define judgment-form-runtime-proc (mk-judgment-form-proc #,lang)) (define jf-lws (compiled-judgment-form-lws #,clauses #,judgment-form-name #,stx)) @@ -625,32 +644,41 @@ (cons #f names))]))) (values (reverse backward-rules) (reverse backward-names))) - (define-values (name/mode mode-stx name/contract contract rules rule-names) + (define-values (name/mode mode-stx name/contract contract invariant rules rule-names) (syntax-parse body #:context full-stx [((~or (~seq #:mode ~! mode:mode-spec) - (~seq #:contract ~! contract:contract-spec)) + (~seq #:contract ~! contract:contract-spec) + (~seq #:inv ~! inv:expr)) ... rule:expr ...) (let-values ([(name/mode mode) (syntax-parse #'(mode ...) - [((name the-mode ...)) (values #'name (car (syntax->list #'(mode ...))))] - [_ - (raise-syntax-error - #f - (if (null? (syntax->list #'(mode ...))) - "expected definition to include a mode specification" - "expected definition to include only one mode specification") - full-stx)])] + [((name the-mode ...)) (values #'name (car (syntax->list #'(mode ...))))] + [_ + (raise-syntax-error + #f + (if (null? (syntax->list #'(mode ...))) + "expected definition to include a mode specification" + "expected definition to include only one mode specification") + full-stx)])] [(name/ctc ctc) (syntax-parse #'(contract ...) - [() (values #f #f)] - [((name . contract)) (values #'name (syntax->list #'contract))] - [(_ . dups) - (raise-syntax-error - syn-err-name "expected at most one contract specification" - #f #f (syntax->list #'dups))])]) + [() (values #f #f)] + [((name . contract)) (values #'name (syntax->list #'contract))] + [(_ . dups) + (raise-syntax-error + syn-err-name "expected at most one contract specification" + #f #f (syntax->list #'dups))])] + [(invt) + (syntax-parse #'(inv ...) + [() #f] + [(invar) #'invar] + [(_ . dups) + (raise-syntax-error + syn-err-name "expected at most one invariant specification" + #f #f (syntax->list #'dups))])]) (define-values (parsed-rules rule-names) (parse-rules (syntax->list #'(rule ...)))) - (values name/mode mode name/ctc ctc parsed-rules rule-names))])) + (values name/mode mode name/ctc ctc invt parsed-rules rule-names))])) (check-clauses full-stx syn-err-name rules #t) (check-dup-rule-names full-stx syn-err-name rule-names) (check-arity-consistency mode-stx contract full-stx) @@ -667,7 +695,7 @@ [(symbol? (syntax-e name)) (symbol->string (syntax-e name))] [else (syntax-e name)]))) - (values form-name dup-names mode-stx contract rules string-rule-names)) + (values form-name dup-names mode-stx contract invariant rules string-rule-names)) ;; names : (listof (or/c #f syntax[string])) (define-for-syntax (check-dup-rule-names full-stx syn-err-name names) @@ -795,11 +823,13 @@ [(_ jf-expr) #'(#%expression (judgment-holds/derivation build-derivations #t jf-expr any))])) -(define-for-syntax (do-compile-judgment-form-proc name mode-stx clauses rule-names contracts nts orig lang stx syn-error-name) +(define-for-syntax (do-compile-judgment-form-proc name mode-stx clauses rule-names contracts orig-ctcs nts orig lang stx syn-error-name) (with-syntax ([(init-jf-derivation-id) (generate-temporaries '(init-jf-derivation-id))]) (define mode (cdr (syntax->datum mode-stx))) (define-values (input-contracts output-contracts) - (if contracts + (values (first contracts) + (second contracts)) + #;(if contracts (let-values ([(ins outs) (split-by-mode contracts mode)]) (values ins outs)) (values #f #f))) @@ -833,17 +863,17 @@ ;; pieces of a 'let' expression to be combined: first some bindings ([compiled-lhs (compile-pattern lang `lhs #t)] #,@(if input-contracts - (list #`[compiled-input-ctcs #,(contracts-compilation input-contracts)]) + (list #`[compiled-input-ctcs (compile-pattern lang `#,input-contracts #f)]) (list)) #,@(if output-contracts - (list #`[compiled-output-ctcs #,(contracts-compilation output-contracts)]) + (list #`[compiled-output-ctcs (compile-pattern lang `#,output-contracts #f)]) (list))) ;; and then the body of the let, but expected to be behind a (λ (input) ...). (let ([jf-derivation-id init-jf-derivation-id]) (begin lhs-syncheck-exp #,@(if input-contracts - (list #`(check-judgment-form-contract '#,name input compiled-input-ctcs 'I '#,mode)) + (list #`(check-judgment-form-contract '#,name input #f compiled-input-ctcs '#,orig-ctcs 'I '#,mode)) (list)) (combine-judgment-rhses compiled-lhs @@ -855,7 +885,7 @@ body)) #,(if output-contracts #`(λ (output) - (check-judgment-form-contract '#,name output compiled-output-ctcs 'O '#,mode)) + (check-judgment-form-contract '#,name input output compiled-output-ctcs '#,orig-ctcs 'O '#,mode)) #`void))))))))])) (when (identifier? orig) @@ -942,14 +972,28 @@ (list (reverse rhses) (reverse sc/wheres)))) -(define (check-judgment-form-contract form-name term+trees contracts mode modes) - (define terms (if (eq? mode 'O) - (derivation-with-output-only-output term+trees) - term+trees)) +(define (check-judgment-form-contract form-name input-term output-term+trees contracts orig-ctcs mode modes) + (define o-term (and (eq? mode 'O) + (derivation-with-output-only-output output-term+trees))) (define description (case mode [(I) "input"] [(O) "output"])) + (when contracts + (case mode + [(I) + (unless (match-pattern contracts input-term) + (redex-error form-name (string-append "judgment input values do not match its contract;\n" + " (unknown output values indicated by _)\n contract: ~a\n values: ~a") + (cons form-name orig-ctcs) + (cons form-name (assemble modes input-term (build-list (length modes) + (λ (_) '_))))))] + [(O) + (define io-term (assemble modes input-term o-term)) + (unless (match-pattern contracts io-term) + (redex-error form-name "judgment values do not match its contract;\n contract: ~a\n values: ~a" + (cons form-name orig-ctcs) (cons form-name io-term)))])) + #; (when contracts (let loop ([rest-modes modes] [rest-terms terms] [rest-ctcs contracts] [pos 1]) (unless (null? rest-modes) @@ -1125,7 +1169,7 @@ (define-syntax (compile-judgment-form stx) (syntax-case stx () - [(_ judgment-form-name mode-arg lang raw-clauses rule-names ctcs orig full-def syn-err-name judgment-form-runtime-gen-clauses) + [(_ judgment-form-name mode-arg lang raw-clauses rule-names ctcs invt orig full-def syn-err-name judgment-form-runtime-gen-clauses) (let ([nts (definition-nts #'lang #'full-def (syntax-e #'syn-err-name))] [rule-names (syntax->datum #'rule-names)] [syn-err-name (syntax-e #'syn-err-name)] @@ -1136,6 +1180,35 @@ [mode (cdr (syntax->datum #'mode-arg))]) (unless (jf-is-relation? #'judgment-form-name) (mode-check (cdr (syntax->datum #'mode-arg)) clauses nts syn-err-name stx)) + (define maybe-wrap-contract (if (syntax-e #'invt) + (λ (ctc-stx) + #`(side-condition #,ctc-stx (term invt))) + values)) + (define-values (i-ctc-syncheck-expr i-ctc) + (syntax-case #'ctcs () + [#f (values #'(void) #f)] + [(p ...) + (let-values ([(i-ctcs o-ctcs) (split-by-mode (syntax->list #'(p ...)) mode)]) + (with-syntax* ([(i-ctcs ...) i-ctcs] + [(syncheck i-ctc-pat (names ...) (names/ellipses ...)) + (rewrite-side-conditions/check-errs #'lang #'syn-error-name #f #'(i-ctcs ...))]) + (values #'syncheck #'i-ctc-pat)))])) + (define-values (ctc-syncheck-expr ctc) + (cond + [(not (or (syntax-e #'ctcs) + (syntax-e #'invt))) + (values #'(void) #f)] + [else + (define ctc-stx ((if (syntax-e #'invt) + (λ (ctc-stx) + #`(side-condition #,ctc-stx (term invt))) + values) + (if (syntax-e #'ctcs) + #'ctcs + #'any))) + (with-syntax ([(syncheck ctc-pat (names ...) (names/ellipses ...)) + (rewrite-side-conditions/check-errs #'lang #'syn-error-name #f ctc-stx)]) + (values #'syncheck #'ctc-pat))])) (define-values (syncheck-exprs contracts) (syntax-case #'ctcs () [#f (values '() #f)] @@ -1156,7 +1229,8 @@ #'mode-arg clauses rule-names - contracts + (list i-ctc ctc) + #'ctcs nts #'orig #'lang @@ -1179,7 +1253,8 @@ (λ () #,(check-pats #'(list comp-clauses ...))))))) - #`(begin #,@syncheck-exprs (values #,proc-stx #,gen-stx)))])) + #`(begin #,i-ctc-syncheck-expr #,ctc-syncheck-expr + (values #,proc-stx #,gen-stx)))])) (define-for-syntax (rewrite-relation-prems clauses) (map (λ (c) @@ -1334,40 +1409,6 @@ (syntax->list #'(x ...))) (raise-syntax-error syn-error-name "error checking failed.2" stx))])) -(define-for-syntax (split-by-mode xs mode) - (for/fold ([ins '()] [outs '()]) - ([x (reverse xs)] - [m (reverse mode)]) - (case m - [(I) (values (cons x ins) outs)] - [(O) (values ins (cons x outs))] - [else (error 'split-by-mode "ack ~s" m)]))) - -(define-for-syntax (fuse-by-mode ins outs mode) - (let loop ([is (reverse ins)] - [os (reverse outs)] - [ms (reverse mode)] - [res '()]) - (define err (λ () (error 'fuse-by-mode "mismatched mode and split: ~s ~s ~s" ins outs mode))) - (cond - [(and (empty? ms) - (empty? is) - (empty? os)) - res] - [(empty? ms) - (err)] - [else - (case (car ms) - [(I) (if (empty? is) - (err) - (loop (cdr is) os (cdr ms) - (cons (car is) res)))] - [(O) (if (empty? os) - (err) - (loop is (cdr os) (cdr ms) - (cons (car os) res)))] - [else (error 'fuse-by-mode "ack ~s" (car ms))])]))) - (define-for-syntax (ellipsis? stx) (and (identifier? stx) (free-identifier=? stx (quote-syntax ...)))) @@ -1405,7 +1446,7 @@ [(syncheck-exps conc/+rw names) (rewrite-pats conc/+ lang)] [(ps-rw eqs p-names) (rewrite-prems #t (syntax->list #'(prems ...)) names lang 'define-judgment-form)] [(conc/-rw conc-mfs) (rewrite-terms conc/- p-names)] - [(conc) (fuse-by-mode conc/+rw conc/-rw mode)]) + [(conc) (assemble mode conc/+rw conc/-rw)]) (with-syntax ([(c ...) conc] [(c-mf ...) conc-mfs] [(eq ...) eqs] @@ -1422,7 +1463,7 @@ (define-values (p/-s p/+s) (split-by-mode (syntax->list prem-body) p-mode)) (define-values (p/-rws mf-apps) (rewrite-terms p/-s ns in-judgment-form?)) (define-values (syncheck-exps p/+rws new-names) (rewrite-pats p/+s lang)) - (define p-rw (fuse-by-mode p/-rws p/+rws p-mode)) + (define p-rw (assemble p-mode p/-rws p/+rws)) (with-syntax ([(p ...) p-rw]) (values (cons #`(begin #,@syncheck-exps diff --git a/pkgs/redex-pkgs/redex-test/redex/tests/err-loc-test.rkt b/pkgs/redex-pkgs/redex-test/redex/tests/err-loc-test.rkt index ea0743312b..ada8763603 100644 --- a/pkgs/redex-pkgs/redex-test/redex/tests/err-loc-test.rkt +++ b/pkgs/redex-pkgs/redex-test/redex/tests/err-loc-test.rkt @@ -153,6 +153,11 @@ (ctc-fail q s)] [(ctc-fail c s) (ctc-fail a s)])) + (eval '(define-judgment-form L + #:mode (inv-fail I O) + #:contract (inv-fail s_1 s_2) + #:inv ,(not (eq? (term s_1) (term s_2))) + [(inv-fail s a)])) (exec-runtime-error-tests "run-err-tests/judgment-form-contracts.rktd") (exec-runtime-error-tests "run-err-tests/judgment-form-undefined.rktd") (exec-runtime-error-tests "run-err-tests/judgment-form-ellipses.rktd")) diff --git a/pkgs/redex-pkgs/redex-test/redex/tests/run-err-tests/judgment-form-contracts.rktd b/pkgs/redex-pkgs/redex-test/redex/tests/run-err-tests/judgment-form-contracts.rktd index 84cfa15b4d..fc97e8e8bc 100644 --- a/pkgs/redex-pkgs/redex-test/redex/tests/run-err-tests/judgment-form-contracts.rktd +++ b/pkgs/redex-pkgs/redex-test/redex/tests/run-err-tests/judgment-form-contracts.rktd @@ -1,12 +1,15 @@ -(#rx"input q at position 1" +(#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail q _\\)" ([judgment (ctc-fail q s)]) (judgment-holds judgment)) -(#rx"output q at position 2" +(#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail a q\\)" ([judgment (ctc-fail a s)]) (judgment-holds judgment)) -(#rx"input q at position 1" +(#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail q _\\)" ([judgment (ctc-fail b s)]) (judgment-holds judgment)) -(#rx"output q at position 2" +(#rx"contract: \\(ctc-fail s s\\).*values: \\(ctc-fail a q\\)" ([judgment (ctc-fail c s)]) (judgment-holds judgment)) +(#rx"contract: \\(inv-fail s_1 s_2\\).*values: \\(inv-fail a a\\)" + ([judgment (inv-fail a s)]) + (judgment-holds judgment))