diff --git a/collects/scheme/match/match.ss b/collects/scheme/match/match.ss index f7323367ba..08aad02f86 100644 --- a/collects/scheme/match/match.ss +++ b/collects/scheme/match/match.ss @@ -6,6 +6,7 @@ (only-in "match-expander.ss" define-match-expander) "define-forms.ss" + "struct.ss" (for-syntax "parse.ss" "gen-match.ss" (only-in "patterns.ss" match-...-nesting))) @@ -13,6 +14,7 @@ (provide (for-syntax match-...-nesting) match-equality-test define-match-expander + struct* exn:misc:match?) (define-forms parse/cert diff --git a/collects/scheme/match/struct.ss b/collects/scheme/match/struct.ss new file mode 100644 index 0000000000..9a57511a4a --- /dev/null +++ b/collects/scheme/match/struct.ss @@ -0,0 +1,73 @@ +#lang scheme/base +(require scheme/match/match-expander + (for-syntax scheme/base + scheme/struct-info + syntax/boundmap + scheme/list)) + +(define-match-expander + struct* + (lambda (stx) + (syntax-case stx () + [(_ struct-name (field+pat ...)) + (let* ([fail (lambda () + (raise-syntax-error + 'struct* "not a structure definition" + stx #'struct-name))] + [v (if (identifier? #'struct-name) + (syntax-local-value #'struct-name fail) + (fail))] + [field-acc->pattern (make-free-identifier-mapping)]) + (unless (struct-info? v) (fail)) + ; Check each pattern and capture the field-accessor name + (for-each (lambda (an) + (syntax-case an () + [(field pat) + (unless (identifier? #'field) + (raise-syntax-error + 'struct* "not an identifier for field name" + stx #'field)) + (let ([field-acc + (datum->syntax #'field + (string->symbol + (format "~a-~a" + (syntax-e #'struct-name) + (syntax-e #'field))) + #'field)]) + (when (free-identifier-mapping-get field-acc->pattern field-acc (lambda () #f)) + (raise-syntax-error 'struct* "Field name appears twice" stx #'field)) + (free-identifier-mapping-put! field-acc->pattern field-acc #'pat))] + [_ + (raise-syntax-error + 'struct* "expected a field pattern of the form ( )" + stx an)])) + (syntax->list #'(field+pat ...))) + (let* (; Get the structure info + [acc (fourth (extract-struct-info v))] + ;; the accessors come in reverse order + [acc (reverse acc)] + ;; remove the first element, if it's #f + [acc (cond [(empty? acc) acc] + [(not (first acc)) (rest acc)] + [else acc])] + ; Order the patterns in the order of the accessors + [pats-in-order + (for/list ([field-acc (in-list acc)]) + (begin0 + (free-identifier-mapping-get + field-acc->pattern field-acc + (lambda () (syntax/loc stx _))) + ; Use up pattern + (free-identifier-mapping-put! + field-acc->pattern field-acc #f)))]) + ; Check that all patterns were used + (free-identifier-mapping-for-each + field-acc->pattern + (lambda (field-acc pat) + (when pat + (raise-syntax-error 'struct* "field name not associated with given structure type" + stx field-acc)))) + (quasisyntax/loc stx + (struct struct-name #,pats-in-order))))]))) + +(provide struct*) \ No newline at end of file diff --git a/collects/scribblings/reference/match.scrbl b/collects/scribblings/reference/match.scrbl index 10cb2fca78..2e0693b95b 100644 --- a/collects/scribblings/reference/match.scrbl +++ b/collects/scribblings/reference/match.scrbl @@ -445,4 +445,23 @@ default is @scheme[equal?].} @; ---------------------------------------------------------------------- +@section{Library Extensions} + +@defform[(struct* struct-id ([field pat] ...))]{ + Matches an instance of a structure type named @scheme[struct-id], where the field @scheme[field] in the instance matches the corresponding @scheme[pat]. + + Any field of @scheme[struct-id] may be omitted and they may occur in any order. + + @defexamples[ + #:eval match-eval + (define-struct tree (val left right)) + (match (make-tree 0 (make-tree 1 #f #f) #f) + [(struct* tree ([val a] + [left (struct* tree ([right #f] [val b]))])) + (list a b)]) + ] + } + +@; ---------------------------------------------------------------------- + @close-eval[match-eval] diff --git a/collects/tests/match/plt-match-tests.ss b/collects/tests/match/plt-match-tests.ss index 06276f5684..013c2b0a82 100644 --- a/collects/tests/match/plt-match-tests.ss +++ b/collects/tests/match/plt-match-tests.ss @@ -209,6 +209,89 @@ )) +(define struct*-tests + (make-test-suite + "Tests of struct*" + (make-test-case "not an id for struct" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct tree (val left right)) + (match (make-tree 0 1 2) + [(struct* 4 ()) + #f])))))) + (make-test-case "not a struct-info for struct" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-syntax tree 1) + (match 1 + [(struct* tree ()) + #f])))))) + (make-test-case "bad form" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct tree (val left right)) + (match (make-tree 0 1 2) + [(struct* tree ([val])) + #f])))))) + (make-test-case "bad form" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct tree (val left right)) + (match (make-tree 0 1 2) + [(struct* tree (val)) + #f])))))) + (make-test-case "field appears twice" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct tree (val left right)) + (match (make-tree 0 1 2) + [(struct* tree ([val 0] [val 0])) + #f])))))) + (make-test-case "not a field" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct tree (val left right)) + (match (make-tree 0 1 2) + [(struct* tree ([feet 0])) + #f])))))) + (make-test-case "super structs don't work" + (assert-exn exn:fail:syntax? + (lambda () + (expand #'(let () + (define-struct extra (foo)) + (define-struct (tree extra) (val left right)) + (match (make-tree #f 0 1 2) + [(struct* tree ([extra #f] [val 0])) + #f])))))) + (make-test-case "super struct kinda work" + (let () + (define-struct extra (foo)) + (define-struct (tree extra) (val left right)) + (match (make-tree #f 0 1 2) + [(struct* tree ([val a])) + (assert = 0 a)]))) + (make-test-case "from documentation" + (let () + (define-struct tree (val left right)) + (match-define + (struct* + tree + ([val a] + [left + (struct* + tree + ([right #f] + [val b]))])) + (make-tree 0 (make-tree 1 #f #f) #f)) + (assert = 0 a) + (assert = 1 b))))) + (define plt-match-tests (make-test-suite "Tests for plt-match.ss" doc-tests @@ -217,6 +300,7 @@ nonlinear-tests match-expander-tests reg-tests + struct*-tests )) (define (run-tests)