From e9e1d4b5b77b8e686d7d82a84c6a758d75dbe262 Mon Sep 17 00:00:00 2001 From: Jon Rafkind Date: Sat, 5 Sep 2009 16:47:59 +0000 Subject: [PATCH] define-struct/contract can handle sub-typing now svn: r15887 --- collects/scheme/private/contract.ss | 146 ++++++++++++++++++---------- 1 file changed, 94 insertions(+), 52 deletions(-) diff --git a/collects/scheme/private/contract.ss b/collects/scheme/private/contract.ss index 3bb1ef893f..bb6ffefff7 100644 --- a/collects/scheme/private/contract.ss +++ b/collects/scheme/private/contract.ss @@ -133,40 +133,43 @@ improve method arity mismatch contract violation error messages? (define-struct s-info (auto-value-stx transparent? def-stxs? def-vals?)) (define (build-struct-names name field-infos) - (let ([name-str (symbol->string (syntax-e name))]) - (list* name - (datum->syntax - name - (string->symbol - (string-append "struct:" name-str))) - (datum->syntax - name - (string->symbol - (string-append "make-" name-str))) - (datum->syntax - name - (string->symbol - (string-append name-str "?"))) - (apply append - (for/list ([finfo field-infos]) - (let ([field-str (symbol->string (syntax-e (field-info-stx finfo)))]) - (cons (datum->syntax - name - (string->symbol - (string-append name-str "-" field-str))) - (if (field-info-mutable? finfo) - (list (datum->syntax - name - (string->symbol - (string-append "set-" name-str "-" field-str "!")))) - null)))))))) + (let ([name-str (symbol->string (syntax-case name () + [id (identifier? #'id) + (syntax-e #'id)] + [(sub super) + (syntax-e #'sub)]))]) + (list* + (syntax-case name () + [id (identifier? #'id) #'id] + [(sub super) #'sub]) + (datum->syntax + name + (string->symbol + (string-append "struct:" name-str))) + (datum->syntax + name + (string->symbol + (string-append "make-" name-str))) + (datum->syntax + name + (string->symbol + (string-append name-str "?"))) + (apply append + (for/list ([finfo field-infos]) + (let ([field-str (symbol->string (syntax-e (field-info-stx finfo)))]) + (cons (datum->syntax + name + (string->symbol + (string-append name-str "-" field-str))) + (if (field-info-mutable? finfo) + (list (datum->syntax + name + (string->symbol + (string-append "set-" name-str "-" field-str "!")))) + null)))))))) (define (build-contracts stx pred field-infos) - (list* (quasisyntax/loc stx - (-> #,@(map field-info-ctc - (filter (λ (f) - (not (field-info-auto? f))) - field-infos)) any/c)) + (list* (syntax/loc stx any/c) (syntax/loc stx any/c) (apply append (for/list ([finfo field-infos]) @@ -178,7 +181,7 @@ improve method arity mismatch contract violation error messages? (quasisyntax/loc stx (-> #,pred #,field-ctc void?))) null))))))) - + (define (check-field f ctc) (let ([p-list (syntax->list f)]) (if p-list @@ -202,14 +205,14 @@ improve method arity mismatch contract violation error messages? [(eq? elem '#:mutable) (begin (when mutable? (raise-syntax-error 'define-struct/contract - "redundant #:mutable" - (car rest))) + "redundant #:mutable" + (car rest))) (loop (cdr rest) #t auto?))] [(eq? elem '#:auto) (begin (when auto? (raise-syntax-error 'define-struct/contract - "redundant #:mutable" - (car rest))) + "redundant #:mutable" + (car rest))) (loop (cdr rest) mutable? #t))] [else (raise-syntax-error 'define-struct/contract "expected #:mutable or #:auto" @@ -250,7 +253,7 @@ improve method arity mismatch contract violation error messages? "redundant #:mutable" (car kwds))) (for ([finfo field-infos]) - (set-field-info-mutable?! finfo #t)) + (set-field-info-mutable?! finfo #t)) (loop (cdr kwd-list) auto-value-stx transparent? #t def-stxs? def-vals?)] [(eq? kwd '#:transparent) @@ -280,15 +283,27 @@ improve method arity mismatch contract violation error messages? (syntax-case stx () [(_ name ([field ctc] ...) kwds ...) (let ([fields (syntax->list #'(field ...))]) - (unless (identifier? #'name) + (unless (or (identifier? #'name) + (syntax-case #'name () + [(x y) (and (identifier? #'x) + (identifier? #'y))] + [_ #f])) (raise-syntax-error 'define-struct/contract - "expected identifier for struct name" + "expected identifier for struct name or a sub-type relationship (subtype supertype)" #'name)) (let* ([field-infos (map check-field fields (syntax->list #'(ctc ...)))] [sinfo (check-kwds (syntax->list #'(kwds ...)) field-infos)] [names (build-struct-names #'name field-infos)] [pred (cadddr names)] - [ctcs (build-contracts stx pred field-infos)]) + [ctcs (build-contracts stx pred field-infos)] + [super-fields (syntax-case #'name () + [(child parent) + (let ([v (syntax-local-value #'parent (lambda () #f))]) + (unless (struct-info? v) + (raise-syntax-error #f "identifier is not bound to a structure type" stx #'parent)) + (let ([v (extract-struct-info v)]) + (cadddr v)))] + [else '()])]) (let-values ([(non-auto-fields auto-fields) (let loop ([fields field-infos] [nautos null] @@ -309,19 +324,28 @@ improve method arity mismatch contract violation error messages? (field-info-stx (car fields)))))))]) (with-syntax ([ctc-bindings (let ([val-bindings (if (s-info-def-vals? sinfo) - (cons (cadr names) (map list (cddr names) ctcs)) + (cons (cadr names) + (map list (cddr names) + ctcs)) null)]) (if (s-info-def-stxs? sinfo) (cons (car names) val-bindings) val-bindings))] [orig stx] + [struct-name (syntax-case #'name () + [id (identifier? #'id) #'id] + [(id1 super) #'id1])] [(auto-check ...) (let* ([av-stx (if (s-info-auto-value-stx sinfo) (s-info-auto-value-stx sinfo) #'#f)] [av-id (datum->syntax av-stx (string->symbol - (string-append (symbol->string (syntax-e #'name)) + (string-append (syntax-case #'name () + [id (identifier? #'id) + (symbol->string (syntax-e #'id))] + [(id1 super) + (symbol->string (syntax-e #'id1))]) ":auto-value")) av-stx)]) (for/list ([finfo auto-fields]) @@ -331,20 +355,38 @@ improve method arity mismatch contract violation error messages? '(struct name) 'cant-happen #,(id->contract-src-info av-id)))))] + ;; a list of variables, one for each super field + [(super-fields ...) (generate-temporaries super-fields)] + ;; the contract for a super field is any/c becuase the + ;; super constructor will have its own contract + [(super-contracts ...) (for/list ([i (in-list super-fields)]) + (datum->syntax stx 'any/c))] + [(non-auto-contracts ...) + (map field-info-ctc + (filter (lambda (f) + (not (field-info-auto? f))) + field-infos))] + ;; the make-foo function. this is used to make the contract + ;; print the right name in the blame + [maker (caddr names)] [(non-auto-name ...) (map field-info-stx non-auto-fields)]) (syntax/loc stx (begin (define-values () (begin auto-check ... (values))) - (with-contract #:type struct name - ctc-bindings - (define-struct/derived orig name (field ...) - kwds ... - #:guard (λ (non-auto-name ... struct-name) - (unless (eq? 'name struct-name) - (error (format "Cannot create subtype ~a of contracted struct ~a" - struct-name 'name))) - (values non-auto-name ...))))))))))] + (define (guard super-fields ... non-auto-name ... struct-name) + (values super-fields ... non-auto-name ...)) + (define blame-id + (current-contract-region)) + (with-contract #:type struct struct-name + ctc-bindings + (define-struct/derived orig name (field ...) + kwds ... + #:guard (-contract (-> super-contracts ... non-auto-contracts ... symbol? any) + guard + + blame-id blame-id + #'maker)))))))))] [(_ name . bad-fields) (identifier? #'name) (raise-syntax-error 'define-struct/contract