diff --git a/collects/mzlib/contracts.ss b/collects/mzlib/contracts.ss index 7813e03..7d0d862 100644 --- a/collects/mzlib/contracts.ss +++ b/collects/mzlib/contracts.ss @@ -680,7 +680,7 @@ (cond [(flat-named-contract? fc) (flat-named-contract-type-name fc)] [else (or (predicate->type-name fc) - "unknown type")])) + (format "unknown contract ~s" fc))])) ; @@ -818,8 +818,10 @@ (define (class-contract/proc stx) (syntax-case stx () - [(_ (meth-name meth-contract) ...) - (andmap identifier? (syntax->list (syntax (meth-name ...)))) + [(_ (method-specifier meth-name meth-contract) ...) + (and + (andmap method-specifier? (syntax->list (syntax (method-specifier ...)))) + (andmap identifier? (syntax->list (syntax (meth-name ...))))) (match-let ([(`(,make-outer-checks ,xxx ,build-pieces) ...) (map (lambda (meth-contract-stx) (let ([/h (select/h meth-contract-stx 'class-contract stx)]) @@ -829,90 +831,146 @@ (syntax->list (syntax (meth-contract ...))))]) (let* ([outer-args (syntax (val pos-blame neg-blame src-info))] [val-meth-names (syntax->list (syntax (meth-name ...)))] - [super-meth-names (map prefix-super val-meth-names)]) + [val-publics? (map (lambda (x) (eq? 'public (syntax-e x))) + (syntax->list (syntax (method-specifier ...))))] + [super-meth-names (map prefix-super val-meth-names)] + [val-meth-contracts (syntax->list (syntax (meth-contract ...)))] + [val-meth-contract-vars (generate-temporaries val-meth-contracts)]) (with-syntax ([outer-args outer-args] [(super-meth-name ...) super-meth-names] - [(later-method ...) (map (lambda (a b c) (make-wrapper/extending-method outer-args a b c)) - val-meth-names - super-meth-names - build-pieces)] - [(first-method ...) (map (lambda (a b c) (make-wrapper-method outer-args a b c)) - val-meth-names - super-meth-names - (syntax->list (syntax meth-contract ...)))]) + [(get-meth-contract ...) (map method-name->contract-method-name val-meth-names)] + [(method ...) (map (lambda (meth-name meth-contract-var contract-stx public?) + (if public? + (make-wrapper-method outer-args meth-name meth-contract-var contract-stx) + (make-wrapper-method/impl outer-args meth-name meth-contract-var contract-stx))) + val-meth-names + val-meth-contract-vars + val-meth-contracts + val-publics?)] + [(meth-contract-var ...) val-meth-contract-vars] + [(method-contract-declarations ...) (map (lambda (meth-name meth-contract-var public?) + (if public? + (make-public-method-contract-declaration meth-name meth-contract-var) + (make-override-method-contract-declaration meth-name meth-contract-var))) + val-meth-names + val-meth-contract-vars + val-publics?)]) (foldr (lambda (f stx) (f stx)) (syntax/loc stx - (make-contract - (lambda outer-args - (unless (class? val) - (raise-contract-error src-info pos-blame neg-blame "expected a class, got: ~e" val)) - (let ([class-i (class->interface val)]) - (void) - (unless (method-in-interface? 'meth-name class-i) - (raise-contract-error src-info - pos-blame - neg-blame - "expected class to have method ~a, got: ~e" - 'meth-name - val)) - ...) - (if (implementation? val class-with-contracts<%>) - '(class val - (define/override (get-method-contracts) - (list (cons 'meth-name meth-contract) ...)) - (rename [super-meth-name meth-name] ...) - later-method ... - (super-instantiate ())) - (class* val (class-with-contracts<%>) - - (define/public (get-method-contracts) - (list (cons 'meth-name meth-contract) ...)) - - (rename [super-meth-name meth-name] ...) - first-method ... - (super-instantiate ())))) - (lambda x (error 'impl-contract "unimplemented")))) + (let ([meth-contract-var meth-contract] ...) + (make-contract + (lambda outer-args + (unless (class? val) + (raise-contract-error src-info pos-blame neg-blame "expected a class, got: ~e" val)) + (let ([class-i (class->interface val)]) + (void) + (unless (method-in-interface? 'meth-name class-i) + (raise-contract-error src-info + pos-blame + neg-blame + "expected class to have method ~a, got: ~e" + 'meth-name + val)) + ... + + (let ([override-spec? (eq? 'override 'method-specifier)] + [override? (method-in-interface? 'get-meth-contract class-i)]) + (unless (boolean=? override-spec? override?) + (if override-spec? + (error 'class-contract "method ~a is declared as an overriding method in ~e, but isn't" 'meth-name val) + (error 'class-contract "method ~a is declared as a public method in ~e, but isn't" 'meth-name val)))) + ...) + + (class val + + method-contract-declarations ... + + (rename [super-meth-name meth-name] ...) + method ... + (super-instantiate ()))) + (lambda x (error 'impl-contract "unimplemented"))))) make-outer-checks))))] - [(_ (meth-name meth-contract) ...) - (for-each (lambda (name) + [(_ (meth-specifier meth-name meth-contract) ...) + (for-each (lambda (specifier name) + (unless (method-specifier? name) + (raise-syntax-error 'class-contract "expected either public or override" stx specifier)) (unless (identifier? name) (raise-syntax-error 'class-contract "expected name" stx name))) + (syntax->list (syntax (meth-specifier ...))) (syntax->list (syntax (meth-name ...))))] [(_ clz ...) (for-each (lambda (clz) (syntax-case clz () - [(a b) (void)] + [(a b c) (void)] [else (raise-syntax-error 'class-contract "bad method/contract clause" stx clz)])) (syntax->list (syntax (clz ...))))])) + ;; method-specifier? : syntax -> boolean + ;; returns #t if x is the syntax for a valid method specifier + (define (method-specifier? x) + (or (eq? 'public (syntax-e x)) + (eq? 'override (syntax-e x)))) - ;; make-wrapper-method : syntax[identifier] syntax[identifier] (syntax -> syntax) -> syntax + ;; make-wrapper-method : syntax syntax[identifier] syntax[identifier] syntax -> syntax ;; constructs a wrapper method that checks the pre and post-condition, and ;; calls the super method inbetween. - (define (make-wrapper-method-old outer-args method-name super-method-name build-piece) - (with-syntax ([super-method-name super-method-name] - [method-name method-name] - [(val pos-blame neg-blame src-info) outer-args] - [super-call (car (generate-temporaries (list super-method-name)))]) - (with-syntax ([(args body) (build-piece (syntax (super-call pos-blame neg-blame src-info)))]) - (syntax - (define/override method-name - (let ([super-call (lambda x (super-method-name . x))]) - (lambda args - body))))))) - - (define (make-wrapper-method outer-args method-name super-method-name contract) + (define (make-wrapper-method outer-args method-name contract-var contract-stx) (with-syntax ([(val pos-blame neg-blame src-info) outer-args] - [super-method-name super-method-name] + [super-method-name (prefix-super method-name)] [method-name method-name] - [contract contract]) - (syntax + [contract-var contract-var]) + (syntax/loc contract-stx (define/override method-name - (let ([super-method (lambda x (super-method-name . x))]) - (lambda args - (apply (check-contract super-method contract pos-blame neg-blame src-info) args))))))) - + (lambda args + (let ([super-method (lambda x (super-method-name . x))]) + (apply (check-contract contract-var super-method pos-blame neg-blame src-info) args))))))) + + ;; make-wrapper-method/impl : syntax syntax[identifier] syntax[identifier] syntax -> syntax + ;; constructs a wrapper method that checks the pre and post-condition, and + ;; calls the super method inbetween. + (define (make-wrapper-method/impl outer-args method-name contract-var contract-stx) + (with-syntax ([(val pos-blame neg-blame src-info) outer-args] + [super-method-name (prefix-super method-name)] + [method-name method-name] + [get-super-contract (prefix-super (method-name->contract-method-name method-name))] + [contract-var contract-var]) + (syntax/loc contract-stx + (define/override method-name + (lambda args + (let ([super-method (lambda x (super-method-name . x))]) + (apply (check-implication contract-var + (get-super-contract) + (check-contract contract-var + super-method + pos-blame + neg-blame + src-info) + pos-blame + src-info) + args))))))) + + ;; make-public-method-contract-declaration : syntax syntax -> syntax + (define (make-public-method-contract-declaration meth-name meth-contract-var) + (with-syntax ([get-contract (method-name->contract-method-name meth-name)] + [meth-contract-var meth-contract-var] + [meth-name meth-name]) + (syntax + (define/public (get-contract) + meth-contract-var)))) + + ;; make-override-method-contract-declaration : syntax syntax -> syntax + (define (make-override-method-contract-declaration meth-name meth-contract-var) + (with-syntax ([get-contract (method-name->contract-method-name meth-name)] + [super-get-contract (prefix-super (method-name->contract-method-name meth-name))] + [meth-contract-var meth-contract-var] + [meth-name meth-name]) + (syntax + (begin + (rename [super-get-contract get-contract]) + (define/override (get-contract) + meth-contract-var))))) + ;; prefix-super : syntax[identifier] -> syntax[identifier] ;; adds super- to the front of the identifier (define (prefix-super stx) @@ -1418,9 +1476,7 @@ [else (cons (- n i) (loop (- i 1)))])))))) - (define class-with-contracts<%> - (interface () - )) + (define class-with-contracts<%> (interface ())) (define-syntax (opt-> stx) (syntax-case stx () diff --git a/collects/tests/mzscheme/contracts.ss b/collects/tests/mzscheme/contracts.ss index f1b712d..2d2ffa6 100644 --- a/collects/tests/mzscheme/contracts.ss +++ b/collects/tests/mzscheme/contracts.ss @@ -817,7 +817,7 @@ (test/spec-passed/result 'class-contract1 '(send - (make-object (contract (class-contract (m (integer? . -> . integer?))) + (make-object (contract (class-contract (public m (integer? . -> . integer?))) (class object% (define/public (m x) x) (super-instantiate ())) 'pos 'neg)) @@ -827,7 +827,7 @@ (test/spec-failed 'class-contract2 - '(contract (class-contract (m (integer? . -> . integer?))) + '(contract (class-contract (public m (integer? . -> . integer?))) object% 'pos 'neg) @@ -836,7 +836,7 @@ (test/spec-failed 'class-contract3 '(send - (make-object (contract (class-contract (m (integer? . -> . integer?))) + (make-object (contract (class-contract (public m (integer? . -> . integer?))) (class object% (define/public (m x) x) (super-instantiate ())) 'pos 'neg)) @@ -847,7 +847,7 @@ (test/spec-failed 'class-contract4 '(send - (make-object (contract (class-contract (m (integer? . -> . integer?))) + (make-object (contract (class-contract (public m (integer? . -> . integer?))) (class object% (define/public (m x) 'x) (super-instantiate ())) 'pos 'neg)) @@ -855,6 +855,32 @@ 1) "pos") + (test/spec-failed + 'class-contract=>1 + '(let* ([c% (contract (class-contract (public m ((>=/c 10) . -> . (>=/c 10)))) + (class object% (define/public (m x) x) (super-instantiate ())) + 'pos-c + 'neg-c)] + [d% (contract (class-contract (override m ((>=/c 15) . -> . (>=/c 5)))) + (class c% (define/override (m x) x) (super-instantiate ())) + 'pos-d + 'neg-d)]) + (send (make-object d%) m 12)) + "pos-d") + + (test/spec-failed + 'class-contract=>2 + '(let* ([c% (contract (class-contract (public m ((>=/c 10) . -> . (>=/c 10)))) + (class object% (define/public (m x) x) (super-instantiate ())) + 'pos-c + 'neg-c)] + [d% (contract (class-contract (override m ((>=/c 15) . -> . (>=/c 5)))) + (class c% (define/override (m x) 8) (super-instantiate ())) + 'pos-d + 'neg-d)]) + (send (make-object d%) m 100)) + "pos-d") + )) (report-errs) \ No newline at end of file