diff --git a/collects/scheme/match/compiler.ss b/collects/scheme/match/compiler.ss index 2b13087f8b..1bb7315c79 100644 --- a/collects/scheme/match/compiler.ss +++ b/collects/scheme/match/compiler.ss @@ -5,8 +5,10 @@ syntax/stx "patterns.ss" "split-rows.ss" + "reorder.ss" scheme/struct-info scheme/stxparam + scheme/nest (only-in srfi/1 delete-duplicates)) (provide compile*) @@ -77,35 +79,36 @@ ;; vectors are handled specially ;; because each arity is like a different constructor [(eq? 'vector k) - (let () - (define ht - (hash-on (lambda (r) (length (Vector-ps (Row-first-pat r)))) rows)) - (with-syntax - ([(clauses ...) - (hash-table-map - ht - (lambda (arity rows) - (define ns (build-list arity values)) - (with-syntax ([(tmps ...) (generate-temporaries ns)]) - (with-syntax ([body - (compile* - (append (syntax->list #'(tmps ...)) xs) - (map (lambda (row) - (define-values (p1 ps) - (Row-split-pats row)) - (make-Row (append (Vector-ps p1) ps) - (Row-rhs row) - (Row-unmatch row) - (Row-vars-seen row))) - rows) - esc)] - [(n ...) ns]) - #`[(#,arity) - (let ([tmps (vector-ref #,x n)] ...) - body)]))))]) - #`[(vector? #,x) - (case (vector-length #,x) - clauses ...)]))] + (nest + ([let ()] + [let ([ht (hash-on (lambda (r) + (length (Vector-ps (Row-first-pat r)))) rows)])] + [with-syntax + ([(clauses ...) + (hash-table-map + ht + (lambda (arity rows) + (define ns (build-list arity values)) + (with-syntax ([(tmps ...) (generate-temporaries ns)]) + (with-syntax ([body + (compile* + (append (syntax->list #'(tmps ...)) xs) + (map (lambda (row) + (define-values (p1 ps) + (Row-split-pats row)) + (make-Row (append (Vector-ps p1) ps) + (Row-rhs row) + (Row-unmatch row) + (Row-vars-seen row))) + rows) + esc)] + [(n ...) ns]) + #`[(#,arity) + (let ([tmps (vector-ref #,x n)] ...) + body)]))))])]) + #`[(vector? #,x) + (case (vector-length #,x) + clauses ...)])] ;; it's a structure [(box? k) ;; all the rows are structures with the same predicate @@ -192,23 +195,28 @@ (error 'compile-one "Or block with multiple rows: ~a" block)) (let* ([row (car block)] [pats (Row-pats row)] + [seen (Row-vars-seen row)] ;; all the pattern alternatives [qs (Or-ps (car pats))] ;; the variables bound by this pattern - they're the same for the ;; whole list - [vars (bound-vars (car qs))]) - (with-syntax ([vars vars]) + [vars + (for/list ([bv (bound-vars (car qs))] + #:when (for/and ([seen-var seen]) + (not (free-identifier=? bv (car seen-var))))) + bv)]) + (with-syntax ([(var ...) vars]) ;; do the or matching, and bind the results to the appropriate ;; variables #`(let/ec exit (let ([esc* (lambda () (exit (#,esc)))]) - (let-values ([vars + (let-values ([(var ...) #,(compile* (list x) (map (lambda (q) (make-Row (list q) - #'(values . vars) + #'(values var ...) #f - (Row-vars-seen row))) + seen)) qs) #'esc*)]) ;; then compile the rest of the row @@ -216,9 +224,7 @@ (list (make-Row (cdr pats) (Row-rhs row) (Row-unmatch row) - (let ([vs (syntax->list #'vars)]) - (append (map cons vs vs) - (Row-vars-seen row))))) + (append (map cons vars vars) seen))) esc))))))] ;; the App rule [(App? first) @@ -295,99 +301,101 @@ #`(cond [(pred? #,x) body] [else (#,esc)]))] ;; Generalized sequences... slightly tested [(GSeq? first) - (let* ([headss (GSeq-headss first)] - [mins (GSeq-mins first)] - [maxs (GSeq-maxs first)] - [onces? (GSeq-onces? first)] - [tail (GSeq-tail first)] - [k (Row-rhs (car block))] - [xvar (car (generate-temporaries (list #'x)))] - [complete-heads-pattern - (lambda (ps) - (define (loop ps pat) - (if (pair? ps) - (make-Pair (car ps) (loop (cdr ps) pat)) - pat)) - (loop ps (make-Var xvar)))] - [heads - (for/list ([ps headss]) - (complete-heads-pattern ps))] - [head-idss - (for/list ([heads headss]) - (apply append (map bound-vars heads)))] - [hid-argss (map generate-temporaries head-idss)] - [head-idss* (map generate-temporaries head-idss)] - [hid-args (apply append hid-argss)] - [reps (generate-temporaries (for/list ([head heads]) 'rep))]) - (with-syntax ([x xvar] - [var0 (car vars)] - [((hid ...) ...) head-idss] - [((hid* ...) ...) head-idss*] - [((hid-arg ...) ...) hid-argss] - [(rep ...) reps] - [(maxrepconstraint ...) - ;; FIXME: move to side condition to appropriate pattern - (for/list ([repvar reps] [maxrep maxs]) - (if maxrep #`(< #,repvar #,maxrep) #`#t))] - [(minrepclause ...) - (for/list ([repvar reps] [minrep mins] #:when minrep) - #`[(< #,repvar #,minrep) (fail)])] - [((hid-rhs ...) ...) - (for/list ([hid-args hid-argss] [once? onces?]) - (for/list ([hid-arg hid-args]) - (if once? - #`(car (reverse #,hid-arg)) - #`(reverse #,hid-arg))))] - [(parse-loop failkv fail-tail) - (generate-temporaries #'(parse-loop failkv fail-tail))]) - (with-syntax ([(rhs ...) - #`[(let ([hid-arg (cons hid* hid-arg)] ...) - (if maxrepconstraint - (let ([rep (add1 rep)]) - (parse-loop x #,@hid-args #,@reps fail)) - (begin (fail)))) - ...]] - [tail-rhs - #`(cond minrepclause ... - [else - (let ([hid hid-rhs] ... ... - [fail-tail fail]) - #,(compile* - (cdr vars) - (list (make-Row rest-pats k - (Row-unmatch (car block)) - (Row-vars-seen - (car block)))) - #'fail-tail))])]) - (parameterize ([current-renaming - (for/fold ([ht (copy-mapping (current-renaming))]) - ([id (apply append head-idss)] - [id* (apply append head-idss*)]) - (free-identifier-mapping-put! ht id id*) - (free-identifier-mapping-for-each - ht - (lambda (k v) - (when (free-identifier=? v id) - (free-identifier-mapping-put! ht k id*)))) - ht)]) - #`(let parse-loop ([x var0] - [hid-arg null] ... ... - [rep 0] ... - [failkv #,esc]) - #,(compile* (list #'x) - (append - (map (lambda (pats rhs) - (make-Row pats - rhs - (Row-unmatch (car block)) - null)) - (map list heads) - (syntax->list #'(rhs ...))) - (list (make-Row (list tail) - #`tail-rhs - (Row-unmatch (car block)) - null))) - #'failkv))))))] + (nest + ([let* ([headss (GSeq-headss first)] + [mins (GSeq-mins first)] + [maxs (GSeq-maxs first)] + [onces? (GSeq-onces? first)] + [tail (GSeq-tail first)] + [k (Row-rhs (car block))] + [xvar (car (generate-temporaries (list #'x)))] + [complete-heads-pattern + (lambda (ps) + (define (loop ps pat) + (if (pair? ps) + (make-Pair (car ps) (loop (cdr ps) pat)) + pat)) + (loop ps (make-Var xvar)))] + [heads + (for/list ([ps headss]) + (complete-heads-pattern ps))] + [head-idss + (for/list ([heads headss]) + (apply append (map bound-vars heads)))] + [hid-argss (map generate-temporaries head-idss)] + [head-idss* (map generate-temporaries head-idss)] + [hid-args (apply append hid-argss)] + [reps (generate-temporaries (for/list ([head heads]) 'rep))])] + [with-syntax + ([x xvar] + [var0 (car vars)] + [((hid ...) ...) head-idss] + [((hid* ...) ...) head-idss*] + [((hid-arg ...) ...) hid-argss] + [(rep ...) reps] + [(maxrepconstraint ...) + ;; FIXME: move to side condition to appropriate pattern + (for/list ([repvar reps] [maxrep maxs]) + (if maxrep #`(< #,repvar #,maxrep) #`#t))] + [(minrepclause ...) + (for/list ([repvar reps] [minrep mins] #:when minrep) + #`[(< #,repvar #,minrep) (fail)])] + [((hid-rhs ...) ...) + (for/list ([hid-args hid-argss] [once? onces?]) + (for/list ([hid-arg hid-args]) + (if once? + #`(car (reverse #,hid-arg)) + #`(reverse #,hid-arg))))] + [(parse-loop failkv fail-tail) + (generate-temporaries #'(parse-loop failkv fail-tail))])] + [with-syntax ([(rhs ...) + #`[(let ([hid-arg (cons hid* hid-arg)] ...) + (if maxrepconstraint + (let ([rep (add1 rep)]) + (parse-loop x #,@hid-args #,@reps fail)) + (begin (fail)))) + ...]] + [tail-rhs + #`(cond minrepclause ... + [else + (let ([hid hid-rhs] ... ... + [fail-tail fail]) + #,(compile* + (cdr vars) + (list (make-Row rest-pats k + (Row-unmatch (car block)) + (Row-vars-seen + (car block)))) + #'fail-tail))])])] + [parameterize ([current-renaming + (for/fold ([ht (copy-mapping (current-renaming))]) + ([id (apply append head-idss)] + [id* (apply append head-idss*)]) + (free-identifier-mapping-put! ht id id*) + (free-identifier-mapping-for-each + ht + (lambda (k v) + (when (free-identifier=? v id) + (free-identifier-mapping-put! ht k id*)))) + ht)])]) + #`(let parse-loop ([x var0] + [hid-arg null] ... ... + [rep 0] ... + [failkv #,esc]) + #,(compile* (list #'x) + (append + (map (lambda (pats rhs) + (make-Row pats + rhs + (Row-unmatch (car block)) + null)) + (map list heads) + (syntax->list #'(rhs ...))) + (list (make-Row (list tail) + #`tail-rhs + (Row-unmatch (car block)) + null))) + #'failkv)))] [else (error 'compile "unsupported pattern: ~a~n" first)])) (define (compile* vars rows esc) @@ -424,20 +432,22 @@ ;; otherwise, we split the matrix into blocks ;; and compile each block with a reference to its continuation - (let ([fns - (let loop ([blocks (reverse (split-rows rows))] [esc esc] [acc null]) - (if (null? blocks) - ;; if we're done, return the blocks - (reverse acc) - (with-syntax (;; f is the name this block will have - [(f) (generate-temporaries #'(f))] - ;; compile the block, with jumps to the previous - ;; esc - [c (compile-one vars (car blocks) esc)]) - ;; then compile the rest, with our name as the esc - (loop (cdr blocks) #'f (cons #'[f (lambda () c)] acc)))))]) - (with-syntax ([(fns ... [_ (lambda () body)]) fns]) - (let/wrap #'(fns ...) #'body))))) + (let*-values + ([(rows vars) (reorder-columns rows vars)] + [(fns) + (let loop ([blocks (reverse (split-rows rows))] [esc esc] [acc null]) + (if (null? blocks) + ;; if we're done, return the blocks + (reverse acc) + (with-syntax (;; f is the name this block will have + [(f) (generate-temporaries #'(f))] + ;; compile the block, with jumps to the previous + ;; esc + [c (compile-one vars (car blocks) esc)]) + ;; then compile the rest, with our name as the esc + (loop (cdr blocks) #'f (cons #'[f (lambda () c)] acc)))))]) + (with-syntax ([(fns ... [_ (lambda () body)]) fns]) + (let/wrap #'(fns ...) #'body))))) ;; (require mzlib/trace) ;; (trace compile* compile-one) diff --git a/collects/scheme/match/gen-match.ss b/collects/scheme/match/gen-match.ss index c0f948ba66..e121a8a184 100644 --- a/collects/scheme/match/gen-match.ss +++ b/collects/scheme/match/gen-match.ss @@ -1,7 +1,7 @@ #lang scheme/base (require "patterns.ss" "compiler.ss" - syntax/stx + syntax/stx scheme/nest (for-template scheme/base (only-in "patterns.ss" match:error))) (provide go) @@ -9,52 +9,41 @@ ;; this parses the clauses using parse/cert, then compiles them ;; go : syntax syntax syntax certifier -> syntax (define (go parse/cert stx exprs clauses cert) - (parameterize ([orig-stx stx]) - (syntax-case clauses () - [([pats . rhs] ...) - (let ([len (length (syntax->list exprs))]) - (with-syntax ([(xs ...) (generate-temporaries exprs)] - [(exprs ...) exprs] - [(fail) (generate-temporaries #'(fail))]) - (with-syntax ([body (compile* - (syntax->list #'(xs ...)) - (map (lambda (pats rhs) - (unless (= len - (length (syntax->list pats))) - (raise-syntax-error - 'match - (format "~a, expected ~a and got ~a" - "wrong number of match clauses" - len - (length (syntax->list pats))) - pats)) - (syntax-case* rhs (=>) - (lambda (x y) - (eq? (syntax-e x) - (syntax-e y))) - [((=> unm) . rhs) - (make-Row (map (lambda (s) - (parse/cert s cert)) - (syntax->list pats)) - #`(begin . rhs) - #'unm - null)] - [_ - (make-Row (map (lambda (s) - (parse/cert s cert)) - (syntax->list pats)) - #`(begin . #,rhs) - #f - null)])) - (syntax->list #'(pats ...)) - (syntax->list #'(rhs ...))) - #'fail)] - [orig-expr (if (= 1 len) - (stx-car #'(xs ...)) - #'(list xs ...))]) - (quasisyntax/loc stx - (let ([xs exprs] - ...) - (let ([fail (lambda () - #,(syntax/loc stx (match:error orig-expr)))]) - body))))))]))) + (syntax-case clauses () + [([pats . rhs] ...) + (nest + ([parameterize ([orig-stx stx])] + [let ([len (length (syntax->list exprs))])] + [with-syntax ([(xs ...) (generate-temporaries exprs)] + [(exprs ...) exprs] + [(fail) (generate-temporaries #'(fail))])] + [with-syntax + ([body + (compile* + (syntax->list #'(xs ...)) + (for/list ([pats (syntax->list #'(pats ...))] + [rhs (syntax->list #'(rhs ...))]) + (let ([lp (length (syntax->list pats))]) + (unless (= len lp) + (raise-syntax-error + 'match + (format + "wrong number of match clauses, expected ~a and got ~a" + len lp) + pats)) + (let ([mk (lambda (unm rhs) + (make-Row (for/list ([p (syntax->list pats)]) + (parse/cert p cert)) + #`(begin . #,rhs) unm null))]) + (syntax-case* rhs (=>) + (lambda (x y) (eq? (syntax-e x) (syntax-e y))) + [((=> unm) . rhs) (mk #'unm #'rhs)] + [_ (mk #f rhs)])))) + #'fail)] + [orig-expr + (if (= 1 len) (stx-car #'(xs ...)) #'(list xs ...))])]) + (quasisyntax/loc stx + (let ([xs exprs] ...) + (let ([fail (lambda () + #,(syntax/loc stx (match:error orig-expr)))]) + body))))])) diff --git a/collects/scheme/match/parse-helper.ss b/collects/scheme/match/parse-helper.ss index 84a6827c7f..411630d9cb 100644 --- a/collects/scheme/match/parse-helper.ss +++ b/collects/scheme/match/parse-helper.ss @@ -119,7 +119,7 @@ error-msg) (let* ([expander (syntax-local-value (cert expander))] [transformer (accessor expander)]) - (unless transformer (raise-syntax-error #f error-msg #'expander)) + (unless transformer (raise-syntax-error #f error-msg expander)) (let* ([introducer (make-syntax-introducer)] [certifier (match-expander-certifier expander)] [mstx (introducer (syntax-local-introduce stx))] diff --git a/collects/scheme/match/patterns.ss b/collects/scheme/match/patterns.ss index 258ac7d9f7..30543f7bcb 100644 --- a/collects/scheme/match/patterns.ss +++ b/collects/scheme/match/patterns.ss @@ -39,8 +39,6 @@ ;; start is what index to start at (define-struct (Vector CPat) (ps) #:transparent) -(define-struct (VectorSeq Pat) (p count start) #:transparent) - (define-struct (Pair CPat) (a d) #:transparent) (define-struct (MPair CPat) (a d) #:transparent) diff --git a/collects/scheme/match/reorder.ss b/collects/scheme/match/reorder.ss new file mode 100644 index 0000000000..8ec4a4bd36 --- /dev/null +++ b/collects/scheme/match/reorder.ss @@ -0,0 +1,86 @@ +#lang scheme/base + +(require "patterns.ss" + scheme/list + (only-in srfi/1/list take-while) + (for-syntax scheme/base)) + +(provide reorder-columns) + +#| +(define p-x (make-Var #'x)) +(define p-y (make-Var #'y)) +(define p-d (make-Dummy #'_)) + +(define p-cons (make-Pair p-x p-y)) +(define p-vec (make-Vector (list p-x p-y p-d))) + +(define r1 (make-Row (list p-x) #'#f #f null)) +(define r2 (make-Row (list p-y) #'#f #f null)) +(define r3 (make-Row (list p-cons) #'#f #f null)) +(define r4 (make-Row (list p-vec p-d) #'#f #f null)) + +(define r5 (make-Row (list p-x p-y p-cons) #'1 #f null)) +(define r6 (make-Row (list p-cons p-y p-vec) #'1 #f null)) +|# + +(define-sequence-syntax in-par + (lambda () (raise-syntax-error 'in-par "bad")) + (lambda (orig-stx stx) + (syntax-case stx () + [((id) (_ lst-exprs)) + #'[(id) + (:do-in + ;;outer bindings + ([(lst) lst-exprs]) + ;; outer check + (void) ; (unless (list? lst) (in-list lst)) + ;; loop bindings + ([lst lst]) + ;; pos check + (not (ormap null? lst)) + ;; inner bindings + ([(id) (map car lst)]) + ;; pre guard + #t + ;; post guard + #t + ;; loop args + ((map cdr lst)))]] + [_ (error 'no (syntax->datum stx))]))) + +(define (or-all? ps l) + (ormap (lambda (p) (andmap p l)) ps)) + +(define (score col) + (define n (length col)) + (define c (car col)) + (define preds (list Var? Pair? Null?)) + (cond [(or-all? preds col) (add1 n)] + [(andmap CPat? col) n] + [(Var? c) (length (take-while Var? col))] + [(Pair? c) (length (take-while Pair? col))] + [(Vector? c) (length (take-while Vector? col))] + [(Box? c) (length (take-while Box? col))] + [else 0])) + +(define (reorder-by ps scores*) + (for/fold + ([pats null]) + ([score-ref scores*]) + (cons (list-ref ps score-ref) pats))) + + +(define (reorder-columns rows vars) + (define scores (for/list ([i (in-naturals)] + [column (in-par (map (compose Row-pats) rows))]) + (cons i (score column)))) + (define scores* (reverse (map car (sort scores > #:key cdr)))) + (values + (for/list ([row rows]) + (let ([ps (Row-pats row)]) + (make-Row (reorder-by ps scores*) + (Row-rhs row) + (Row-unmatch row) + (Row-vars-seen row)))) + (reorder-by vars scores*)))