diff --git a/collects/mzlib/contract.ss b/collects/mzlib/contract.ss index 97f1544..41834be 100644 --- a/collects/mzlib/contract.ss +++ b/collects/mzlib/contract.ss @@ -167,7 +167,7 @@ add struct contracts for immutable structs? [(rename . _) (raise-syntax-error 'provide/contract "malformed rename clause" provide-stx clause)] [(struct struct-name ((field-name contract) ...)) - (and (identifier? (syntax struct-name)) + (and (well-formed-struct-name? (syntax struct-name)) (andmap identifier? (syntax->list (syntax (field-name ...))))) (let ([sc (build-struct-code provide-stx (syntax struct-name) @@ -180,6 +180,11 @@ add struct contracts for immutable structs? "missing fields" provide-stx clause)] + [(struct name . rest) + (not (well-formed-struct-name? (syntax name))) + (raise-syntax-error 'provide/contract "name must be an identifier or two identifiers with parens around them" + provide-stx + (syntax name))] [(struct name (fields ...)) (for-each (lambda (field) (syntax-case field () @@ -220,11 +225,36 @@ add struct contracts for immutable structs? provide-stx (syntax unk))]))])) + ;; well-formed-struct-name? : syntax -> bool + (define (well-formed-struct-name? stx) + (or (identifier? stx) + (syntax-case stx () + [(name super) + (and (identifier? (syntax name)) + (identifier? (syntax super))) + #t] + [else #f]))) + ;; build-struct-code : syntax syntax (listof syntax) (listof syntax) -> syntax ;; constructs the code for a struct clause ;; first arg is the original syntax object, for source locations - (define (build-struct-code stx struct-name field-names field-contracts) - (let* ([field-contract-ids (map (lambda (field-name) + (define (build-struct-code stx struct-name-position field-names field-contracts) + (let* ([struct-name (syntax-case struct-name-position () + [(a b) (syntax a)] + [else struct-name-position])] + [parent-struct-count (let ([parent-info (extract-parent-struct-info struct-name-position)]) + (and parent-info + (let ([fields (cadddr parent-info)]) + (cond + [(null? fields) 0] + [(not (car (last-pair fields))) + (raise-syntax-error + 'provide/contract + "cannot determine the number of fields in super struct" + provide-stx + struct-name)] + [else (length fields)]))))] + [field-contract-ids (map (lambda (field-name) (a:mangle-id provide-stx "provide/contract-field-contract" field-name @@ -239,25 +269,35 @@ add struct contracts for immutable structs? [predicate-id (build-predicate-id struct-name)] [constructor-id (build-constructor-id struct-name)]) (with-syntax ([(selector-codes ...) - (map (lambda (selector-id field-contract-id) - (code-for-one-id stx - selector-id - (build-selector-contract struct-name - predicate-id - field-contract-id) - #f)) - selector-ids - field-contract-ids)] + (filter + (lambda (x) x) + (map/count (lambda (selector-id field-contract-id index) + (if (or (not parent-struct-count) + (parent-struct-count . <= . index)) + (code-for-one-id stx + selector-id + (build-selector-contract struct-name + predicate-id + field-contract-id) + #f) + #f)) + selector-ids + field-contract-ids))] [(mutator-codes ...) - (map (lambda (mutator-id field-contract-id) - (code-for-one-id stx - mutator-id - (build-mutator-contract struct-name - predicate-id - field-contract-id) - #f)) - mutator-ids - field-contract-ids)] + (filter + (lambda (x) x) + (map/count (lambda (mutator-id field-contract-id index) + (if (or (not parent-struct-count) + (parent-struct-count . <= . index)) + (code-for-one-id stx + mutator-id + (build-mutator-contract struct-name + predicate-id + field-contract-id) + #f) + #f)) + mutator-ids + field-contract-ids))] [predicate-code (code-for-one-id stx predicate-id (syntax (-> any? boolean?)) #f)] [constructor-code (code-for-one-id stx @@ -283,6 +323,32 @@ add struct contracts for immutable structs? predicate-code constructor-code (provide struct-name struct:struct-name)))))) + + ;; map/count : (X Y int -> Z) (listof X) (listof Y) -> (listof Z) + (define (map/count f l1 l2) + (let loop ([l1 l1] + [l2 l2] + [i 0]) + (cond + [(and (null? l1) (null? l2)) '()] + [(or (null? l1) (null? l2)) (error 'map/count "mismatched lists")] + [else (cons (f (car l1) (car l2) i) + (loop (cdr l1) + (cdr l2) + (+ i 1)))]))) + + ;; extract-struct-info : syntax -> (union #f (list syntax syntax (listof syntax) ...)) + (define (extract-parent-struct-info stx) + (syntax-case stx () + [(a b) + (syntax-local-value + (syntax b) + (lambda () + (raise-syntax-error 'provide/contract + "expected a struct name" + provide-stx + (syntax a))))] + [a #f])) ;; build-constructor-contract : syntax (listof syntax) syntax -> syntax (define (build-constructor-contract stx field-contract-ids predicate-id) diff --git a/collects/tests/mzscheme/contract-test.ss b/collects/tests/mzscheme/contract-test.ss index 34ea0e9..dd9fab6 100644 --- a/collects/tests/mzscheme/contract-test.ss +++ b/collects/tests/mzscheme/contract-test.ss @@ -1093,11 +1093,29 @@ 'provide/contract7 '(let () (eval '(module contract-test-suite7 mzscheme + (require (lib "contract.ss")) + (define-struct s (a b)) + (define-struct (t s) (c d)) + (provide/contract + (struct s ((a any?) (b any?))) + (struct (t s) ((a any?) (b any?) (c any?) (d any?)))))) + (eval '(require contract-test-suite7)) + (eval '(let ([x (make-t 1 2 3 4)]) + (s-a x) + (s-b x) + (t-c x) + (t-d x) + (void))))) + + (test/spec-passed + 'provide/contract8 + '(let () + (eval '(module contract-test-suite8 mzscheme (require (lib "contract.ss")) (provide/contract (rename the-internal-name the-external-name integer?)) (define the-internal-name 1) (+ the-internal-name 1))) - (eval '(require contract-test-suite7)) + (eval '(require contract-test-suite8)) (eval '(+ the-external-name 1)))) @@ -1164,6 +1182,27 @@ 'pos 'neg)) + (test/spec-passed/result + 'object-contract/field6 + '(send (contract (object-contract [m (integer? . -> . integer?)]) + (new (class object% (define x 1) (define/public (m y) x) (super-new))) + 'pos + 'neg) + m + 2) + 1) + + #; + (test/spec-passed/result + 'object-contract/field7 + '(send (contract (object-contract) + (new (class object% (define x 1) (define/public (m y) x) (super-new))) + 'pos + 'neg) + m + 2) + 1) + (test/spec-passed/result 'object-contract->1 '(send @@ -2378,6 +2417,21 @@ (test/well-formed #'(case-> (->d* (any? any?) (lambda x any?)) (-> integer? integer?))) (test/well-formed #'(case-> (->d* (any? any?) any? (lambda x any?)) (-> integer? integer?))) + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + ;; ;; + ;; Inferred Name Tests ;; + ;; ;; + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + + (eval + '(module contract-test-suite-inferred-name1 mzscheme + (require (lib "contract.ss")) + (define contract-inferred-name-test-contract (-> integer? any)) + (define (contract-inferred-name-test x) #t) + (provide/contract (contract-inferred-name-test contract-inferred-name-test-contract)))) + (eval '(require contract-test-suite-inferred-name1)) + (eval '(test 'contract-inferred-name-test object-name contract-inferred-name-test)) + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;; ;; Contract Name Tests ;;