From 172ea135520d24639ed3ae8ddd0853a2f9aae6f1 Mon Sep 17 00:00:00 2001 From: Burke Fetscher Date: Tue, 7 May 2013 15:20:37 -0500 Subject: [PATCH] redex: fix and clean up metafunction expansion - defer disequation expansion so that generated code is linear w/r/t to the number of clauses - fix variable renaming for disequations --- .../redex/private/reduction-semantics.rkt | 126 +++++++++++------- collects/redex/private/term-fn.rkt | 4 +- 2 files changed, 77 insertions(+), 53 deletions(-) diff --git a/collects/redex/private/reduction-semantics.rkt b/collects/redex/private/reduction-semantics.rkt index 987c4b20ea..d78c448692 100644 --- a/collects/redex/private/reduction-semantics.rkt +++ b/collects/redex/private/reduction-semantics.rkt @@ -17,7 +17,8 @@ racket/list racket/set data/union-find - mzlib/etc) + mzlib/etc + (rename-in racket/match (match match:))) (require (for-syntax syntax/name "loc-wrapper-ct.rkt" @@ -1317,7 +1318,7 @@ (syntax->list #'(lhs-names ...)) (syntax->list #'(lhs-namess/ellipses ...)) (syntax->list (syntax (rhs/wheres ...))))] - [((gen-clause lhs-pat lhs-ps/pat*) ...) + [(gen-clause ...) (make-mf-clauses (syntax->list #'(lhs ...)) (syntax->list #'(rhs ...)) (syntax->list #'((stuff ...) ...)) @@ -1339,9 +1340,7 @@ [parent-cases #,(if prev-metafunction #`(metafunc-proc-cases #,(term-fn-get-id (syntax-local-value prev-metafunction))) - #'null)] - [new-lhs-pats '(lhs-pat ...)] - [new-lhs-ps/pats '(lhs-ps/pat* ...)]) + #'null)]) (build-metafunction lang @@ -1367,16 +1366,11 @@ [prev-metafunction #`(extend-mf-clauses #,(term-fn-get-id (syntax-local-value prev-metafunction)) (λ () - #,(check-pats #'(list gen-clause ...))) - new-lhs-ps/pats)] + (add-mf-dqs #,(check-pats #'(list gen-clause ...)))))] [else #`(memoize0 (λ () - #,(check-pats #'(list gen-clause ...))))]) - #,(if prev-metafunction - #`(extend-lhs-pats #,(term-fn-get-id (syntax-local-value prev-metafunction)) - new-lhs-pats) - #`new-lhs-pats))) + (add-mf-dqs #,(check-pats #'(list gen-clause ...)))))]))) #,(if dom-ctcs #'dsc #f) `(codom-side-conditions-rewritten ...) 'name)))) @@ -1384,27 +1378,23 @@ (map syntax-local-introduce (syntax->list #'(original-names ...)))))))))))])) -(define (extend-lhs-pats old-m new-pats) - (append new-pats (metafunc-proc-lhs-pats old-m))) - -(define (extend-mf-clauses old-mf new-clauses new-lhs-ps/pats) +(define (extend-mf-clauses old-mf new-clauses) (memoize0 (λ () - (define old-clauses - (for/list ([old-clauses (in-list ((metafunc-proc-gen-clauses old-mf)))] - [old-lhs-pat (in-list (metafunc-proc-lhs-pats old-mf))]) - (define new-dqs (for/list ([new-lhs-ps/pat (in-list new-lhs-ps/pats)]) - (dqn (first new-lhs-ps/pat) - old-lhs-pat - (second new-lhs-ps/pat)))) - (struct-copy clause old-clauses - [eq/dqs - (append - new-dqs - (clause-eq/dqs old-clauses))]))) - (append - (new-clauses) - old-clauses)))) + (define new-cs (new-clauses)) + (define new-lhss + (for/list ([c new-cs]) + (match: c + [(clause `(list ,c-lhs ,c-rhs) c-eq/dqs c-prems c-lang c-name) + c-lhs]))) + (define new-old-cs + (for/list ([old-c (in-list ((metafunc-proc-gen-clauses old-mf)))]) + (match: old-c + [(clause `(list ,c-lhs ,c-rhs) c-eq/dqs c-prems c-lang c-name) + (define new-dqs (make-clause-dqs c-lhs new-lhss (length new-lhss))) + (struct-copy clause old-c + [eq/dqs (append new-dqs c-eq/dqs)])]))) + (append new-cs new-old-cs)))) (define uniq (gensym)) (define (memoize0 t) @@ -1415,34 +1405,71 @@ ans))) (define-for-syntax (make-mf-clauses lhss rhss extrass nts err-name name lang) - (define-values (rev-clauses _1 _2) - (for/fold ([clauses '()] [prev-lhs-pats '()] [prev-lhs-pats* '()]) + (define rev-clauses + (for/fold ([clauses '()]) ([lhs (in-list lhss)] [rhs (in-list rhss)] [extras (in-list extrass)]) - (with-syntax* ([(lhs-syncheck-expr lhs-pat (names ...) (names/ellipses ...)) (rewrite-side-conditions/check-errs lang err-name #t lhs)] - [((lhs-pat-ps* ...) lhs-pat*) (fix-and-extract-dq-vars #'lhs-pat)]) + (with-syntax ([(lhs-syncheck-expr lhs-pat (names ...) (names/ellipses ...)) (rewrite-side-conditions/check-errs lang err-name #t lhs)]) (define-values (ps-rw extra-eqdqs p-names) (rewrite-prems #f (syntax->list extras) (syntax->datum #'(names ...)) lang 'define-metafunction)) (define-values (rhs-pats mf-clausess) (rewrite-terms (list rhs) p-names)) (define clause-stx (with-syntax ([(prem-rw ...) ps-rw] [(eqs ...) extra-eqdqs] - [(((prev-lhs-pat-ps ...) prev-lhs-pat) ...) prev-lhs-pats*] [(mf-clauses ...) mf-clausess] [(rhs-pat) rhs-pats]) - #`((begin - lhs-syncheck-expr - (clause '(list lhs-pat rhs-pat) - (list eqs ... (dqn '(prev-lhs-pat-ps ...) 'prev-lhs-pat 'lhs-pat) ...) - (list prem-rw ... mf-clauses ...) - #,lang - '#,name)) - lhs-pat - ((lhs-pat-ps* ...) lhs-pat*)))) - (values (cons clause-stx clauses) - (cons #'lhs-pat prev-lhs-pats) - (cons #'((lhs-pat-ps* ...) lhs-pat*) prev-lhs-pats*))))) + #`(begin + lhs-syncheck-expr + (clause '(list lhs-pat rhs-pat) + (list eqs ...) + (list prem-rw ... mf-clauses ...) + #,lang + '#,name)))) + (cons clause-stx clauses)))) (reverse rev-clauses)) +(define (add-mf-dqs clauses) + (define-values (new-clauses _) + (for/fold ([new-cs '()] [prev-lhss '()]) + ([c clauses]) + (match: c + [(clause `(list ,c-lhs ,c-rhs) c-eq/dqs c-prems c-lang c-name) + (define new-dqs (make-clause-dqs c-lhs prev-lhss)) + (define new-c (struct-copy clause c + [eq/dqs (append new-dqs c-eq/dqs)])) + (values (cons new-c new-cs) + (cons c-lhs prev-lhss))]))) + (reverse new-clauses)) + +(define (make-clause-dqs rhs-pat prev-lhs-pats [n 0]) + (define rhs-vs (pat-vars rhs-pat)) + (define fixed-lhss + (for/list ([lhs (in-list prev-lhs-pats)]) + (begin0 + (let recur ([p lhs]) + (match: p + [`(name ,v ,p) + (define new-v (string->symbol (format "~s_lhs~s" v n))) + (let loop ([new-v new-v]) + (if (set-member? rhs-vs new-v) + (loop (string->symbol (format "~s:" new-v))) + `(name ,new-v ,(recur p))))] + [`(list ,ps ...) + `(list ,@(map recur ps))] + [_ p])) + (set! n (add1 n))))) + (for/list ([plhs (in-list fixed-lhss)]) + (define lhs-vs (pat-vars plhs)) + (dqn (set->list lhs-vs) plhs rhs-pat))) + +(define (pat-vars p) + (match: p + [`(name ,v ,p) + (set-add (pat-vars p) v)] + [`(list ,ps ...) + (apply set-union (set) (map pat-vars ps))] + [_ + (set)])) + (define-for-syntax (fix-and-extract-dq-vars pat) (define new-ids (hash)) (let recur ([pat pat]) @@ -1453,8 +1480,7 @@ (let ([vn (syntax-e #'vname)]) (hash-ref new-ids vn (λ () - (define new - (syntax-e (generate-temporary (format "~s_" vn)))) + (define new (string->symbol (format "~s_p" vn))) (set! new-ids (hash-set new-ids vn new)) new))) #'vname)]) diff --git a/collects/redex/private/term-fn.rkt b/collects/redex/private/term-fn.rkt index e1a1217da1..b3680fb1ca 100644 --- a/collects/redex/private/term-fn.rkt +++ b/collects/redex/private/term-fn.rkt @@ -22,7 +22,6 @@ metafunc-proc-dom-pat metafunc-proc-cases metafunc-proc-gen-clauses - metafunc-proc-lhs-pats metafunc-proc? make-metafunc-proc @@ -78,7 +77,7 @@ variable-not-otherwise-mentioned hole symbol)) (define-values (struct:metafunc-proc make-metafunc-proc metafunc-proc? metafunc-proc-ref metafunc-proc-set!) - (make-struct-type 'metafunc-proc #f 11 0 #f null (current-inspector) 0)) + (make-struct-type 'metafunc-proc #f 10 0 #f null (current-inspector) 0)) (define metafunc-proc-clause-names (make-struct-field-accessor metafunc-proc-ref 1)) (define metafunc-proc-pict-info (make-struct-field-accessor metafunc-proc-ref 2)) (define metafunc-proc-lang (make-struct-field-accessor metafunc-proc-ref 3)) @@ -88,7 +87,6 @@ (define metafunc-proc-dom-pat (make-struct-field-accessor metafunc-proc-ref 7)) (define metafunc-proc-cases (make-struct-field-accessor metafunc-proc-ref 8)) (define metafunc-proc-gen-clauses (make-struct-field-accessor metafunc-proc-ref 9)) -(define metafunc-proc-lhs-pats (make-struct-field-accessor metafunc-proc-ref 10)) (define (build-disappeared-use id-stx-table nt id-stx)