diff --git a/collects/mzlib/contracts.ss b/collects/mzlib/contracts.ss index 10aab74..90eef82 100644 --- a/collects/mzlib/contracts.ss +++ b/collects/mzlib/contracts.ss @@ -10,17 +10,18 @@ opt-> opt->* class-contract + class-contract/prim (rename -contract? contract?) provide/contract define/contract) (require-for-syntax mzscheme - (lib "list.ss") - (lib "match.ss") + "list.ss" + "match.ss" (lib "name.ss" "syntax")) - (require (lib "class.ss") - (lib "etc.ss")) + (require "private/class-sneaky.ss" + "etc.ss") (require (lib "contract-helpers.scm" "mzlib" "private")) (require-for-syntax (prefix a: (lib "contract-helpers.scm" "mzlib" "private"))) @@ -701,7 +702,7 @@ ; - (define-syntax-set (-> ->* ->d ->d* case-> class-contract) + (define-syntax-set (-> ->* ->d ->d* case-> class-contract class-contract/prim) ;; ->/proc : syntax -> syntax ;; the transformer for the -> macro @@ -757,11 +758,12 @@ (case->/h stx (syntax->list (syntax (cases ...))))]) (let ([outer-args (syntax (val pos-blame neg-blame src-info))] [impl-args (syntax (ant conq val tbb src-info))]) + (ensure-cases-disjoint stx (extract-argument-lists impl-builder-cases)) (with-syntax ([outer-args outer-args] [(inner-check ...) (make-inner-check outer-args)] [(body ...) (make-bodies outer-args)] [(impl-builder-case ...) impl-builder-cases] - [(impl-info ...) impl-infos]) + [(impl-info ...) impl-infos]) (with-syntax ([inner-lambda (set-inferred-name-from stx @@ -782,7 +784,56 @@ (lambda impl-args impl-lambda-body) (lambda (x y z) (or (impl-info x y z) ...))))))))))])) - ;; case->/h : syntax (listof syntax) -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) + ;; exactract-argument-lists : syntax -> (listof syntax) + (define (extract-argument-lists stx) + (map (lambda (x) + (syntax-case x () + [(arg-list body) (syntax arg-list)])) + (syntax->list stx))) + + ;; ensure-cases-disjoint : syntax syntax[list] -> void + (define (ensure-cases-disjoint stx cases) + (let ([individual-cases null] + [dot-min #f]) + (for-each (lambda (case) + (let ([this-case (get-case case)]) + (cond + [(number? this-case) + (cond + [(member this-case individual-cases) + (raise-syntax-error 'case-> (format "found multiple cases with ~a arguments" this-case) stx)] + [(and dot-min (dot-min . <= . this-case)) + (raise-syntax-error 'case-> + (format "found overlapping cases (~a+ followed by ~a)" dot-min this-case) + stx)] + [else (set! individual-cases (cons this-case individual-cases))])] + [(pair? this-case) + (let ([new-dot-min (car this-case)]) + (cond + [dot-min + (if (dot-min . <= . new-dot-min) + (raise-syntax-error 'case-> + (format "found overlapping cases (~a+ followed by ~a+)" dot-min new-dot-min) + stx) + (set! dot-min new-dot-min))] + [else + (set! dot-min new-dot-min)]))]))) + cases))) + + ;; get-case : syntax -> (union number (cons number 'more)) + (define (get-case stx) + (let ([ilist (syntax-object->datum stx)]) + (if (list? ilist) + (length ilist) + (cons + (let loop ([i ilist]) + (cond + [(pair? i) (+ 1 (loop (cdr i)))] + [else 0])) + 'more)))) + + + ;; case->/h : syntax (listof syntax) -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax) (syntax -> syntax) syntax syntax) ;; like the other /h functions, but composes the wrapper functions ;; together and combines the cases of the case-lambda into a single list. (define (case->/h orig-stx cases) @@ -816,7 +867,10 @@ [impl-infos impl-infos]) (syntax (impl-info . impl-infos))))))]))) - (define (class-contract/proc stx) + (define (class-contract/proc stx) (class-contract-mo? stx #f)) + (define (class-contract/prim/proc stx) (class-contract-mo? stx #t)) + + (define (class-contract-mo? stx use-make-object?) (syntax-case stx () [(form (method-specifier meth-name meth-contract) ...) (and @@ -829,7 +883,7 @@ [super-meth-names (map prefix-super val-meth-names)] [val-meth-contracts (syntax->list (syntax (meth-contract ...)))] [val-meth-contract-vars (generate-temporaries val-meth-contracts)]) - + (let ([ht (make-hash-table)]) (for-each (lambda (name) (let ([key (syntax-e name)]) @@ -850,50 +904,62 @@ val-meth-contracts val-publics?)] [(meth-contract-var ...) val-meth-contract-vars] - [(method-contract-declarations ...) (map (lambda (meth-name meth-contract-var public?) + [(method-contract-declarations ...) (map (lambda (src-stx meth-name meth-contract-var public?) (if public? - (make-public-method-contract-declaration meth-name meth-contract-var) - (make-override-method-contract-declaration meth-name meth-contract-var))) + (make-public-method-contract-declaration src-stx + meth-name + meth-contract-var) + (make-override-method-contract-declaration src-stx + meth-name + meth-contract-var))) + val-meth-contracts val-meth-names val-meth-contract-vars val-publics?)] [this (datum->syntax-object (syntax form) 'this stx)] [super-init (datum->syntax-object (syntax form) 'super-instantiate stx)] [super-make (datum->syntax-object (syntax form) 'super-make-object stx)]) - - (syntax/loc stx - (let ([meth-contract-var meth-contract] ...) - (make-contract - (lambda outer-args - (unless (class? val) - (raise-contract-error src-info pos-blame neg-blame "expected a class, got: ~e" val)) - (let ([class-i (class->interface val)]) - (void) - (unless (method-in-interface? 'meth-name class-i) - (raise-contract-error src-info - pos-blame - neg-blame - "expected class to have method ~a, got: ~e" - 'meth-name - val)) - ... + (with-syntax ([call-super-initializer + (if use-make-object? + (syntax/loc stx + (begin (init-rest args) + (apply super-make args))) + (syntax/loc stx + (super-init ())))]) + (syntax/loc stx + (let ([meth-contract-var meth-contract] ...) + (make-contract + (lambda outer-args + (unless (class? val) + (raise-contract-error src-info pos-blame neg-blame "expected a class, got: ~e" val)) + (let ([class-i (class->interface val)]) + (void) + (unless (method-in-interface? 'meth-name class-i) + (raise-contract-error src-info + pos-blame + neg-blame + "expected class to have method ~a, got: ~e" + 'meth-name + val)) + ... + + (let ([override-spec? (eq? 'override 'method-specifier)] + [override? (method-in-interface? 'get-meth-contract class-i)]) + (unless (boolean=? override-spec? override?) + (if override-spec? + (error 'class-contract "method ~a is declared as an overriding method in ~e, but isn't" 'meth-name val) + (error 'class-contract "method ~a is declared as a public method in ~e, but isn't" 'meth-name val)))) + ...) - (let ([override-spec? (eq? 'override 'method-specifier)] - [override? (method-in-interface? 'get-meth-contract class-i)]) - (unless (boolean=? override-spec? override?) - (if override-spec? - (error 'class-contract "method ~a is declared as an overriding method in ~e, but isn't" 'meth-name val) - (error 'class-contract "method ~a is declared as a public method in ~e, but isn't" 'meth-name val)))) - ...) - - (class*/names-sneaky (this super-init super-make) val () - - method-contract-declarations ... - - (rename [super-meth-name meth-name] ...) - method ... - (super-init ()))) - (lambda x (error 'impl-contract "unimplemented")))))))] + (class*/names-sneaky + (this super-init super-make) val () + + method-contract-declarations ... + + (rename [super-meth-name meth-name] ...) + method ... + call-super-initializer)) + (lambda x (error 'impl-contract "unimplemented"))))))))] [(_ (meth-specifier meth-name meth-contract) ...) (for-each (lambda (specifier name) (unless (method-specifier? name) @@ -922,12 +988,28 @@ (with-syntax ([(val pos-blame neg-blame src-info) outer-args] [super-method-name (prefix-super method-name)] [method-name method-name] + [method-name-string (symbol->string (syntax-e method-name))] [contract-var contract-var]) (syntax/loc contract-stx (define/override method-name (lambda args - (let ([super-method (lambda x (super-method-name . x))]) - (apply (check-contract contract-var super-method pos-blame neg-blame src-info) args))))))) + (let ([super-method (lambda x (super-method-name . x))] + [method-specific-src-info + (if (identifier? src-info) + (datum->syntax-object + src-info + (string->symbol + (string-append + (symbol->string (syntax-e src-info)) + " method " + method-name-string))) + src-info)]) + (apply (check-contract contract-var + super-method + pos-blame + neg-blame + method-specific-src-info) + args))))))) ;; make-wrapper-method/impl : syntax syntax[identifier] syntax[identifier] syntax -> syntax ;; constructs a wrapper method that checks the pre and post-condition, and @@ -954,25 +1036,25 @@ args))))))) ;; make-public-method-contract-declaration : syntax syntax -> syntax - (define (make-public-method-contract-declaration meth-name meth-contract-var) + (define (make-public-method-contract-declaration src-stx meth-name meth-contract-var) (with-syntax ([get-contract (method-name->contract-method-name meth-name)] [meth-contract-var meth-contract-var] [meth-name meth-name]) - (syntax - (define/public (get-contract) - meth-contract-var)))) + (syntax/loc src-stx + (define/public (get-contract) + meth-contract-var)))) ;; make-override-method-contract-declaration : syntax syntax -> syntax - (define (make-override-method-contract-declaration meth-name meth-contract-var) + (define (make-override-method-contract-declaration src-stx meth-name meth-contract-var) (with-syntax ([get-contract (method-name->contract-method-name meth-name)] [super-get-contract (prefix-super (method-name->contract-method-name meth-name))] [meth-contract-var meth-contract-var] [meth-name meth-name]) - (syntax - (begin - (rename [super-get-contract get-contract]) - (define/override (get-contract) - meth-contract-var))))) + (syntax/loc src-stx + (begin + (rename [super-get-contract get-contract]) + (define/override (get-contract) + meth-contract-var))))) ;; prefix-super : syntax[identifier] -> syntax[identifier] ;; adds super- to the front of the identifier @@ -1026,26 +1108,33 @@ [ignore-range-checking? (syntax-case rng-normal (any) [any #t] - [_ #f])]) - (with-syntax ([(dom ...) (all-but-last (syntax->list (syntax (ct ...))))] - [rng (if ignore-range-checking? - (syntax any?) ;; hack to simplify life... - rng-normal)]) + [_ #f])] + [range-values + (syntax-case rng-normal (any) + [any (syntax (any?))] ;; range-values isn't actually used in this case + [(values x ...) + (eq? (syntax-e (syntax values)) 'values) + (syntax (x ...))] + [_ (with-syntax ([rng-normal rng-normal]) + (syntax (rng-normal)))])]) + (with-syntax ([(dom ...) (all-but-last (syntax->list (syntax (ct ...))))]) (with-syntax ([(dom-x ...) (generate-temporaries (syntax (dom ...)))] [(arg-x ...) (generate-temporaries (syntax (dom ...)))] + [(rng-x ...) (generate-temporaries range-values)] + [(rng ...) range-values] [arity (length (syntax->list (syntax (dom ...))))]) (let ([->add-outer-check (lambda (body) (with-syntax ([body body]) (syntax/loc stx (let ([dom-x dom] ... - [rng-x rng]) + [rng-x rng] ...) (unless (-contract? dom-x) (error '-> "expected contract as argument, given: ~e" dom-x)) ... (unless (-contract? rng-x) - (error '-> "expected contract as argument, given: ~e" rng-x)) + (error '-> "expected contract as argument, given: ~e" rng-x)) ... body))))] - [->body (syntax (->* (dom-x ...) (rng-x)))]) + [->body (syntax (->* (dom-x ...) (rng-x ...)))]) (let-values ([(->*add-outer-check ->*make-inner-check ->*make-body diff --git a/collects/tests/mzscheme/contracts.ss b/collects/tests/mzscheme/contracts.ss index 1276e92..ab3b21b 100644 --- a/collects/tests/mzscheme/contracts.ss +++ b/collects/tests/mzscheme/contracts.ss @@ -164,7 +164,43 @@ 'neg) 1 2 'bad) "neg") - + + (test/spec-passed + 'contract-arrow-values1 + '(let-values ([(a b) ((contract (-> integer? (values integer? integer?)) + (lambda (x) (values x x)) + 'pos + 'neg) + 2)]) + 1)) + + (test/spec-failed + 'contract-arrow-values2 + '((contract (-> integer? (values integer? integer?)) + (lambda (x) (values x x)) + 'pos + 'neg) + #f) + "neg") + + (test/spec-failed + 'contract-arrow-values3 + '((contract (-> integer? (values integer? integer?)) + (lambda (x) (values 1 #t)) + 'pos + 'neg) + 1) + "pos") + + (test/spec-failed + 'contract-arrow-values4 + '((contract (-> (integer?) (values integer? integer?)) + (lambda (x) (values #t 1)) + 'pos + 'neg) + 1) + "pos") + (test/spec-failed 'contract-d1 '(contract (integer? . ->d . (lambda (x) (lambda (y) (= x y)))) @@ -855,6 +891,17 @@ 1) "pos") + (test/spec-passed + 'class-contract/prim + '(make-object + (class (contract (class-contract/prim) + (class object% (init x) (init y) (init z) (super-make-object)) + 'pos-c + 'neg-c) + (init-rest x) + (apply super-make-object x)) + 1 2 3)) + (test/spec-failed 'class-contract=>1 '(let* ([c% (contract (class-contract (public m ((>=/c 10) . -> . (>=/c 10)))) @@ -879,8 +926,7 @@ 'pos-d 'neg-d)]) (send (make-object d%) m 100)) - "pos-d") - + "pos-d") (test/spec-passed/result 'class-contract=>2 @@ -894,8 +940,7 @@ (is-a? (make-object wd%) (class->interface wc%)) (is-a? (instantiate wd% ()) wc%) (is-a? (instantiate wd% ()) (class->interface wc%)))) - (list #t #t #t #t #t #t)) - + (list #t #t #t #t #t #t)) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;