(module subst mzscheme
  ;; Macros and runtime support for defining capture-avoiding substitution
  ;; functions over s-expression terms.
  ;;
  ;; A `subst' (or `plt-subst') form lists pattern clauses; each clause says
  ;; whether a matching term is a constant, a variable, or a compound term
  ;; with binders (`all-vars'), a rebuilding function (`build') and
  ;; sub-terms (`subterm' / `subterms').  The form expands to a procedure
  ;; (lambda (var val exp) ...) that replaces free occurrences of `var' in
  ;; `exp' with `val'.
  (require (lib "match.ss")
           (prefix plt: (lib "plt-match.ss"))
           (lib "list.ss"))

  (provide plt-subst subst
           all-vars variable subterm subterms constant build
           subst/proc alpha-rename free-vars/memoize)

  ;; The clause keywords below are only meaningful inside a `subst' /
  ;; `plt-subst' form; using them anywhere else is a syntax error.
  (define-syntax (all-vars stx)
    (raise-syntax-error 'subst "all-vars out of context" stx))
  (define-syntax (variable stx)
    (raise-syntax-error 'subst "variable out of context" stx))
  (define-syntax (subterm stx)
    (raise-syntax-error 'subst "subterm out of context" stx))
  (define-syntax (subterms stx)
    (raise-syntax-error 'subst "subterms out of context" stx))
  (define-syntax (constant stx)
    (raise-syntax-error 'subst "constant out of context" stx))
  (define-syntax (build stx)
    (raise-syntax-error 'subst "build out of context" stx))

  ;; (make-subst subst-name match-name)
  ;; Defines the macro `subst-name', whose clauses are matched with the
  ;; pattern matcher `match-name'.  A use of the generated macro expands to
  ;;   (lambda (var val exp) (subst/proc var val exp separate))
  ;; where `separate' is a procedure
  ;;   term constant-handler variable-handler combine-handler subpiece-handler
  ;; that matches the term against the user's patterns and calls the
  ;; handler selected by the matching clause.
  (define-syntax (make-subst stx)
    (syntax-case stx ()
      [(_ subst match)
       (syntax
        (define-syntax (subst stx)
          (syntax-case stx ()
            [(_ (pat rhs (... ...)) (... ...))
             ;; Fixed identifiers used for the parameters of `separate', so
             ;; that every expanded right-hand side refers to the same names.
             (with-syntax ([term/arg #'term/arg]
                           [constant/arg #'constant/arg]
                           [variable/arg #'variable/arg]
                           [combine/arg #'combine/arg]
                           [sub-piece/arg #'subpiece/arg])
               ;; handle-rhs : syntax -> syntax
               ;; Translates the right-hand side of one clause into a call
               ;; to the appropriate handler parameter of `separate'.
               (define (handle-rhs rhs-stx)
                 (syntax-case rhs-stx (all-vars build subterm subterms variable constant)
                   [((all-vars all-vars-exp) (build build-exp) sub-pieces (... ...))
                    ;; Each sub-clause becomes an expression producing a
                    ;; list of sub-pieces; the lists are appended and passed
                    ;; to the combine handler after the builder and binders.
                    (with-syntax ([(sub-pieces (... ...))
                                   (map (lambda (subterm-stx)
                                          (syntax-case subterm-stx (subterm subterms)
                                            [(subterm vars body)
                                             (syntax (list (sub-piece/arg vars body)))]
                                            [(subterms vars terms)
                                             (syntax
                                              (let ([terms-var terms])
                                                (unless (list? terms-var)
                                                  (error 'subst
                                                         "expected a list of terms for `subterms' subclause, got: ~e"
                                                         terms-var))
                                                (map (lambda (x) (sub-piece/arg vars x))
                                                     terms-var)))]
                                            [else
                                             (raise-syntax-error
                                              'subst
                                              "unknown all-vars subterm"
                                              stx
                                              subterm-stx)]))
                                        (syntax->list (syntax (sub-pieces (... ...)))))])
                      (syntax
                       (apply combine/arg
                              build-exp
                              all-vars-exp
                              (append sub-pieces (... ...)))))]
                   [((all-vars) sub-pieces (... ...))
                    (raise-syntax-error 'subst
                                        "expected all-vars must have an argument"
                                        stx
                                        rhs-stx)]
                   [((all-vars all-vars-exp) not-build-clause anything (... ...))
                    (raise-syntax-error 'subst
                                        "expected build clause"
                                        (syntax not-build-clause))]
                   [((all-vars all-vars-exp))
                    (raise-syntax-error 'subst
                                        "missing build clause"
                                        (syntax (all-vars all-vars-exp)))]
                   [((constant))
                    (syntax (constant/arg term/arg))]
                   [((variable))
                    ;; The rebuild function for a variable is the identity.
                    (syntax (variable/arg (lambda (x) x) term/arg))]
                   [(unk unk-more (... ...))
                    (raise-syntax-error 'subst "unknown clause" (syntax unk))]))
               (with-syntax ([(expanded-rhs (... ...))
                              (map handle-rhs
                                   (syntax->list (syntax ((rhs (... ...)) (... ...)))))])
                 (syntax
                  (let ([separate
                         (lambda (term/arg constant/arg variable/arg combine/arg sub-piece/arg)
                           (match term/arg
                             [pat expanded-rhs] (... ...)
                             [else (error 'subst "no matching clauses for ~s\n" term/arg)]))])
                    (lambda (var val exp)
                      (subst/proc var val exp separate))))))])))]))

  (make-subst subst match)
  (make-subst plt-subst plt:match)

  ;; subst/proc : symbol term term separate -> term
  ;; Replaces the free occurrences of `var' in `exp' with `val'.  When a
  ;; binder of a compound term is also free in `val', that binder is
  ;; alpha-renamed first so that `val' is not captured.
  (define (subst/proc var val exp separate)
    (let* ([free-vars-cache (make-hash-table 'equal)]
           [fv-val (free-vars/memoize free-vars-cache val separate)])
      (let loop ([exp exp])
        ;; let* (not let): handle-complex refers to fv-exp.
        (let* ([fv-exp (free-vars/memoize free-vars-cache exp separate)]
               [handle-constant (lambda (x) x)]
               [handle-variable
                (lambda (rebuild var-name)
                  (if (equal? var-name var)
                      val
                      (rebuild var-name)))]
               [handle-complex
                (lambda (maker vars . subpieces)
                  (cond
                    ;; memq returns the tail of fv-val that starts with the
                    ;; offending binder, so its car is the binder to rename.
                    [(ormap (lambda (var) (memq var fv-val)) vars)
                     =>
                     (lambda (to-be-renamed-l)
                       (let ([to-be-renamed (car to-be-renamed-l)])
                         (loop
                          (alpha-rename
                           to-be-renamed
                           ;; The fresh name must avoid the free variables
                           ;; of `val' (or `val' would still be captured),
                           ;; the free variables of `exp' (which include
                           ;; `var'; otherwise they would become bound),
                           ;; and the other binders of this form.
                           ;; NOTE(review): a binder nested deeper inside
                           ;; the sub-terms that already has the fresh name
                           ;; is not detected here.
                           (pick-new-name to-be-renamed
                                          (append vars fv-val fv-exp))
                           exp
                           separate))))]
                    [else
                     (apply maker
                            vars
                            (map (lambda (subpiece)
                                   (let ([sub-term-binders (subpiece-binders subpiece)]
                                         [sub-term (subpiece-term subpiece)])
                                     ;; `var' is shadowed inside this
                                     ;; sub-term: leave it untouched.
                                     (if (memq var sub-term-binders)
                                         sub-term
                                         (loop sub-term))))
                                 subpieces))]))])
          ;; Terms in which `var' is not free are returned unchanged.
          (if (member var fv-exp)
              (separate exp handle-constant handle-variable handle-complex make-subpiece)
              exp)))))

  ;; A sub-term paired with the binders that are in scope in it.
  (define-struct subpiece (binders term) (make-inspector))

  ;; alpha-rename : symbol symbol term separate -> term
  ;; renames the occurrences of to-be-renamed that are
  ;; bound in the "first level" of exp.
  (define (alpha-rename to-be-renamed new-name exp separate)
    ;; Handlers for the outermost form: rename the binder itself and
    ;; descend only into sub-terms that it binds.
    (define (first exp)
      (separate exp
                first-handle-constant
                first-handle-variable
                first-handle-complex
                first-handle-complex-subpiece))
    (define (first-handle-constant x) x)
    (define (first-handle-variable rebuild var) (rebuild var))
    (define (first-handle-complex maker vars . subpieces)
      (let ([replaced-vars
             (map (lambda (x) (if (eq? x to-be-renamed) new-name x))
                  vars)])
        (apply maker replaced-vars subpieces)))
    (define (first-handle-complex-subpiece binders subterm)
      (if (memq to-be-renamed binders)
          (beyond-first subterm)
          subterm))

    ;; Handlers below the outermost form: rename variable references and
    ;; stop at any form that rebinds to-be-renamed.
    (define (beyond-first exp)
      (define (handle-constant x) x)
      (define (handle-variable rebuild var)
        (if (eq? var to-be-renamed)
            (rebuild new-name)
            (rebuild var)))
      (define (handle-complex maker vars . subpieces)
        (apply maker vars subpieces))
      (define (handle-complex-subpiece binders subterm)
        (if (memq to-be-renamed binders)
            subterm
            (beyond-first subterm)))
      (separate exp
                handle-constant
                handle-variable
                handle-complex
                handle-complex-subpiece))

    (first exp))

  ;; free-vars/memoize : hash-table[sexp -o> (listof symbols)] sexp separate -> (listof symbols)
  ;; doesn't cache against separate -- if it changes, a new hash-table must be passed in,
  ;; or the caching will be wrong
  (define (free-vars/memoize cache exp separate)
    (hash-table-get
     cache
     exp
     (lambda ()
       (let ([res (free-vars/compute cache exp separate)])
         (hash-table-put! cache exp res)
         res))))

  ;; free-vars/compute : hash-table[sexp -o> (listof symbols)] sexp separate -> (listof symbols)
  ;; computes the free variables of exp, using free-vars/memoize (and thus
  ;; the cache) for the sub-terms.
  (define (free-vars/compute cache exp separate)
    (let ([handle-constant (lambda (x) '())]
          [handle-variable (lambda (rebuild var) (list var))]
          ;; each sub-piece has already been turned into a list of free
          ;; variables by handle-complex-subpiece
          [handle-complex
           (lambda (maker vars . subpieces)
             (apply append subpieces))]
          [handle-complex-subpiece
           (lambda (binders subterm)
             (foldl remove-all
                    (free-vars/memoize cache subterm separate)
                    binders))])
      (separate exp
                handle-constant
                handle-variable
                handle-complex
                handle-complex-subpiece)))

  ;; remove-all : symbol (listof symbol) -> (listof symbol)
  ;; removes every occurrence (by eq?) of var from lst; the remaining
  ;; elements come back in reverse order.
  (define (remove-all var lst)
    (let loop ([lst lst]
               [ans '()])
      (cond
        [(null? lst) ans]
        [else (if (eq? (car lst) var)
                  (loop (cdr lst) ans)
                  (loop (cdr lst) (cons (car lst) ans)))])))

  ;; lc-direct-subst : symbol sexp sexp -> sexp
  ;; hand-written substitution for a small language of numbers, symbols,
  ;; `lambda', single-binding `let' and applications; it does not go
  ;; through `separate'.
  (define (lc-direct-subst var val exp)
    (let ([fv-exp (lc-direct-free-vars exp)])
      (if (memq var fv-exp)
          (match exp
            [`(lambda ,vars ,body)
             (if (memq var vars)
                 exp
                 (let* ([fv-val (lc-direct-free-vars val)]
                        [vars1 (map (lambda (var) (pick-new-name var fv-val)) vars)])
                   `(lambda ,vars1
                      ,(lc-direct-subst var val (lc-direct-subst/l vars vars1 body)))))]
            [`(let (,l-var ,exp) ,body)
             (if (eq? l-var var)
                 `(let (,l-var ,(lc-direct-subst var val exp)) ,body)
                 (let* ([fv-val (lc-direct-free-vars val)]
                        [l-var1 (pick-new-name l-var fv-val)])
                   `(let (,l-var1 ,(lc-direct-subst var val exp))
                      ,(lc-direct-subst var val (lc-direct-subst l-var l-var1 body)))))]
            [(? number?) exp]
            [(and var1 (? symbol?))
             (if (eq? var1 var)
                 val
                 var1)]
            [`(,@(args ...))
             `(,@(map (lambda (arg) (lc-direct-subst var val arg)) args))])
          exp)))

  ;; lc-direct-subst/l : (listof symbol) (listof symbol) sexp -> sexp
  ;; substitutes each of vars with vals in exp
  ;; [assume vals don't contain any vars]
  (define (lc-direct-subst/l vars vals exp)
    (foldr (lambda (var val exp) (lc-direct-subst var val exp))
           exp
           vars
           vals))

  ;; lc-direct-free-vars : sexp -> (listof symbol)
  ;; returns the free variables in exp
  (define (lc-direct-free-vars exp)
    ;; the table's keys are the free variables found so far
    (let ([ht (make-hash-table)])
      (let loop ([exp exp]
                 [binding-vars null])
        (match exp
          [(? symbol?)
           (unless (memq exp binding-vars)
             (hash-table-put! ht exp #t))]
          [(? number?) (void)]
          [`(lambda ,vars ,body)
           (loop body (append vars binding-vars))]
          [`(let (,var ,exp) ,body)
           (loop exp binding-vars)
           (loop body (cons var binding-vars))]
          [`(,@(args ...))
           (for-each (lambda (arg) (loop arg binding-vars)) args)]))
      (hash-table-map ht (lambda (x y) x))))

  ;; pick-new-name : symbol (listof symbol) -> symbol
  ;; returns a primed version of `var' that does
  ;; not occur in vars (possibly with no primes)
  (define (pick-new-name var vars)
    (if (member var vars)
        (pick-new-name (prime var) vars)
        var))

  ;; prime : symbol -> symbol
  ;; adds an @ at the end of the symbol
  (define (prime var)
    (string->symbol
     (string-append (symbol->string var) "@"))))