From 7e8816ce0f212c87e5a32433b869d0dc5ce14002 Mon Sep 17 00:00:00 2001 From: Stevie Strickland Date: Wed, 18 Feb 2009 02:25:45 +0000 Subject: [PATCH] * Add initial version of define-struct/contract * Allow uncontracted exports of syntax from a with-contract form. svn: r13717 --- collects/mzlib/contract.ss | 3 +- collects/scheme/private/contract.ss | 127 +++++++++++++++--- .../scribblings/reference/contracts.scrbl | 4 + collects/tests/mzscheme/contract-test.ss | 51 +++++++ 4 files changed, 163 insertions(+), 22 deletions(-) diff --git a/collects/mzlib/contract.ss b/collects/mzlib/contract.ss index a07bee76a0..ac8a793567 100644 --- a/collects/mzlib/contract.ss +++ b/collects/mzlib/contract.ss @@ -31,7 +31,8 @@ (require (except-in scheme/private/contract define/contract - with-contract) + with-contract + define-struct/contract) scheme/private/contract-guts scheme/private/contract-ds scheme/private/contract-opt diff --git a/collects/scheme/private/contract.ss b/collects/scheme/private/contract.ss index 00e7b49274..7fb1f48aa1 100644 --- a/collects/scheme/private/contract.ss +++ b/collects/scheme/private/contract.ss @@ -12,6 +12,7 @@ improve method arity mismatch contract violation error messages? (provide (rename-out [-contract contract]) recursive-contract provide/contract + define-struct/contract define/contract with-contract current-contract-region) @@ -125,7 +126,70 @@ improve method arity mismatch contract violation error messages? (syntax/loc define-stx (define/contract name+arg-list contract #:freevars () body0 body ...))])) +(define-for-syntax (ds/c-build-struct-names name fields) + (let ([name-str (symbol->string (syntax-e name))]) + (list* (datum->syntax + name + (string->symbol + (string-append "struct:" name-str))) + (datum->syntax + name + (string->symbol + (string-append "make-" name-str))) + (datum->syntax + name + (string->symbol + (string-append name-str "?"))) + (for/list ([field-str (map (compose symbol->string syntax-e) fields)]) + (datum->syntax + name + (string->symbol + (string-append name-str "-" field-str))))))) +(define-syntax (define-struct/contract stx) + (syntax-case stx () + [(_ name ([field ctc] ...)) + (let ([fields (syntax->list #'(field ...))]) + (unless (identifier? #'name) + (raise-syntax-error 'define-struct/contract + "expected identifier for struct name" + #'name)) + (for-each (λ (f) + (unless (identifier? f) + (raise-syntax-error 'define-struct/contract + "expected identifier for field name" + f))) + fields) + (let* ([names (ds/c-build-struct-names #'name fields)] + [pred (caddr names)] + [ctcs (list* (syntax/loc stx + (-> ctc ... any/c)) + (syntax/loc stx any/c) + (let ([field-ctc (quasisyntax/loc stx + (-> #,pred any/c))]) + (build-list + (length fields) + (λ (_) field-ctc))))]) + (with-syntax ([struct:name (car names)] + [(id/ctc ...) (map list (cdr names) ctcs)]) + (syntax/loc stx + (with-contract #:type struct name + (name struct:name id/ctc ...) + (define-struct name (field ...) + #:guard (λ (field ... struct-name) + (unless (eq? 'name struct-name) + (error (format "Cannot create subtype ~a of contracted struct ~a" + struct-name 'name))) + (values field ...))))))))] + [(_ name . bad-fields) + (identifier? #'name) + (raise-syntax-error 'define-struct/contract + "expected a list of field name/contract pairs" + #'bad-fields)] + [(_ . body) + (raise-syntax-error 'define-struct/contract + "expected a structure name" + #'body)])) ; ; @@ -180,35 +244,55 @@ improve method arity mismatch contract violation error messages? (define-syntax (with-contract-helper stx) (syntax-case stx () - [(_ blame-stx ()) + [(_ blame-stx () ()) (begin #'(define-values () (values)))] - [(_ blame-stx (i0 i ...)) + [(_ blame-stx (p0 p ...) (u ...)) (raise-syntax-error 'with-contract "no definition found for identifier" - #'i0)] - [(_ blame-stx (i ...) body0 body ...) + #'p0)] + [(_ blame-stx () (u0 u ...)) + (raise-syntax-error 'with-contract + "no definition found for identifier" + #'u0)] + [(_ blame-stx (p ...) (u ...) body0 body ...) (let ([expanded-body0 (local-expand #'body0 (syntax-local-context) (kernel-form-identifier-list))]) - (syntax-case expanded-body0 (begin define-values) + (define (filter-ids to-filter to-remove) + (filter (λ (i1) + (not (memf (λ (i2) + (bound-identifier=? i1 i2)) + to-remove))) + to-filter)) + (syntax-case expanded-body0 (begin define-values define-syntaxes) [(begin sub ...) (syntax/loc stx - (with-contract-helper blame-stx (i ...) sub ... body ...))] + (with-contract-helper blame-stx (p ...) (u ...) sub ... body ...))] + [(define-syntaxes (id ...) expr) + (let ([ids (syntax->list #'(id ...))]) + (for ([i1 (syntax->list #'(p ...))]) + (when (ormap (λ (i2) + (bound-identifier=? i1 i2)) + ids) + (raise-syntax-error 'with-contract + "cannot export syntax with a contract" + i1))) + (with-syntax ([def expanded-body0] + [unused-us (filter-ids (syntax->list #'(u ...)) ids)]) + (with-syntax () + (syntax/loc stx + (begin def (with-contract-helper blame-stx (p ...) unused-us body ...))))))] [(define-values (id ...) expr) - (with-syntax ([def expanded-body0] - [unused-is (let ([ids (syntax->list #'(id ...))]) - (filter (λ (i1) - (not (ormap (λ (i2) - (bound-identifier=? i1 i2)) - ids))) - (syntax->list #'(i ...))))]) - (with-syntax () + (let ([ids (syntax->list #'(id ...))]) + (with-syntax ([def expanded-body0] + [unused-ps (filter-ids (syntax->list #'(p ...)) ids)] + [unused-us (filter-ids (syntax->list #'(u ...)) ids)]) (syntax/loc stx - (begin def (with-contract-helper blame-stx unused-is body ...)))))] + (begin def (with-contract-helper blame-stx unused-ps unused-us body ...)))))] [else (quasisyntax/loc stx (begin #,expanded-body0 - (with-contract-helper blame-stx (i ...) body ...)))]))])) + (with-contract-helper blame-stx (p ...) (u ...) body ...)))]))])) (define-for-syntax (check-and-split-with-contracts single-allowed? args) (let loop ([args args] @@ -321,7 +405,7 @@ improve method arity mismatch contract violation error messages? [(ctc-id ...) (map (λ (i) (marker (a:mangle-id stx "with-contract-contract-id" i))) protected)] - [(ctc ...) protections] + [(ctc ...) (map marker protections)] [(p ...) protected] [(marked-p ...) (map marker protected)] [(src-info ...) (map id->contract-src-info protected)] @@ -329,9 +413,8 @@ improve method arity mismatch contract violation error messages? [(marked-u ...) (map marker unprotected)]) (quasisyntax/loc stx (begin - (define-values (free-ctc-id ... ctc-id ...) - (values (verify-contract 'with-contract free-ctc) ... - (verify-contract 'with-contract ctc) ...)) + (define-values (free-ctc-id ...) + (values (verify-contract 'with-contract free-ctc) ...)) (define blame-id (current-contract-region)) (define-values () @@ -349,7 +432,9 @@ improve method arity mismatch contract violation error messages? (quote-syntax blame-id) (quote-syntax blame-stx)) ...)) (splicing-syntax-parameterize ([current-contract-region (λ (stx) #'blame-stx)]) - (with-contract-helper blame-stx (marked-p ... marked-u ...) . #,(marker #'body))) + (with-contract-helper blame-stx (marked-p ...) (marked-u ...) . #,(marker #'body))) + (define-values (ctc-id ...) + (values (verify-contract 'with-contract ctc) ...)) (define-values () (begin (-contract ctc-id marked-p diff --git a/collects/scribblings/reference/contracts.scrbl b/collects/scribblings/reference/contracts.scrbl index 50cff07d7c..7b02e605c6 100644 --- a/collects/scribblings/reference/contracts.scrbl +++ b/collects/scribblings/reference/contracts.scrbl @@ -720,6 +720,10 @@ inside the @scheme[body] will be protected with contracts that blame the context of the @scheme[define/contract] form for the positive positions and the @scheme[define/contract] form for the negative ones.} +@defform*[[(define-struct/contract struct-id ([field-id contract-expr] ...))]]{ +Works like @scheme[define-struct], except that the arguments to the constructor +and accessors are protected by contracts.} + @defform*[[(contract contract-expr to-protect-expr positive-blame-expr negative-blame-expr) (contract contract-expr to-protect-expr diff --git a/collects/tests/mzscheme/contract-test.ss b/collects/tests/mzscheme/contract-test.ss index b009255492..fde5500ced 100644 --- a/collects/tests/mzscheme/contract-test.ss +++ b/collects/tests/mzscheme/contract-test.ss @@ -2478,6 +2478,57 @@ "top-level") + +; +; +; +; ; ;;;; ; +; ;; ; ; ; +; ; ; ; ; ; ; ; +; ; ; ; ; ; ; ; +; ;; ; ;;; ;;;; ; ; ;; ;;; ;;; ;;;; ; ;;;; ;; ;;; ;;;; ; ;;; ;; ; ;; ;;;; ; ;; ;;; ;;; ;;;; +; ; ;; ; ; ; ;; ;;; ; ; ; ; ; ; ;;; ; ; ; ; ; ; ; ; ; ; ;;; ; ; ;;; ; ; ; ; ; +; ; ; ;;;;; ; ; ; ; ;;;;; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ;; ; ; +; ; ; ; ; ; ; ; ; ;;;; ;; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ;; ; ; ; +; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ; +; ; ;;; ; ; ; ; ; ; ; ; ; ; ; ; ; ; ;;; ; ; ; ;; ; ; ; ; ; ; ; ; ; ; ;; ; ; ; ; +; ;; ; ;;; ;;;;;;;;;; ;;; ;;; ;;; ;; ;;; ;; ; ;;; ;; ; ;;; ;; ;;; ;;; ;; ;;; ;; ;; ;;; ;; +; +; +; + + (test/spec-passed + 'define-struct/contract1 + '(let () + (define-struct/contract foo ([x number?] [y number?])) + 1)) + + (test/spec-passed + 'define-struct/contract2 + '(let () + (define-struct/contract foo ([x number?] [y number?])) + (make-foo 1 2))) + + (test/spec-failed + 'define-struct/contract3 + '(let () + (define-struct/contract foo ([x number?] [y number?])) + (make-foo 1 #t)) + "top-level") + + (test/spec-passed + 'define-struct/contract4 + '(let () + (define-struct/contract foo ([x number?] [y number?])) + (foo-y (make-foo 2 3)))) + + (test/spec-failed + 'define-struct/contract5 + '(let () + (define-struct/contract foo ([x number?] [y number?])) + (foo-y 1)) + "top-level") + ; ; ;