diff --git a/collects/mzlib/contracts.ss b/collects/mzlib/contracts.ss index 1f3c2cb..4bce46b 100644 --- a/collects/mzlib/contracts.ss +++ b/collects/mzlib/contracts.ss @@ -8,6 +8,7 @@ case-> opt-> opt->* + class-contract (rename -contract? contract?) provide/contract define/contract) @@ -15,9 +16,7 @@ (require-for-syntax mzscheme (lib "list.ss") (lib "match.ss") - (lib "pretty.ss") - (lib "name.ss" "syntax") - (lib "stx.ss" "syntax")) + (lib "name.ss" "syntax")) (require (lib "class.ss") (lib "etc.ss")) @@ -71,7 +70,7 @@ (make-set!-transformer (lambda (stx) - ;; build-src-loc-string : syntax -> string + ;; build-src-loc-string/unk : syntax -> (union #f string) (define (build-src-loc-string/unk stx) (let ([source (syntax-source stx)] [line (syntax-line stx)] @@ -86,9 +85,10 @@ (format "~a: ~a" source pos)] [pos (format "~a" pos)] - [else "<>"]))) + [else #f]))) - (with-syntax ([neg-blame-str (build-src-loc-string/unk stx)]) + (with-syntax ([neg-blame-str (or (build-src-loc-string/unk stx) + "")]) (syntax-case stx (set!) [(set! _ arg) (raise-syntax-error 'define/contract @@ -460,10 +460,46 @@ pos-blame a-contract name)) - (check-contract a-contract name pos-blame neg-blame src-info #f)))))]))) + (check-contract a-contract name pos-blame neg-blame src-info)))))]))) - ;; check-contract : contract any symbol symbol syntax (union false? string?) - (define (check-contract contract val pos neg src-info extra-message) + ;; check-contract : contract any symbol symbol syntax -> ... + (define (check-contract contract val pos neg src-info) + (cond + [(contract? contract) + ((contract-f contract) + val + (lambda (rev-contract) (check-contract rev-contract val neg pos src-info)) + (lambda (same-contract) (check-contract same-contract val pos neg src-info)) + (lambda () (raise-contract-error + src-info + pos + neg + "expected type <~a>, given: ~e" + (flat-named-contract-type-name contract) + val)) + (lambda (v) v))] + [(flat-named-contract? contract) + (if ((flat-named-contract-predicate contract) val) + val + (raise-contract-error + src-info + pos + neg + "expected type <~a>, given: ~e" + (flat-named-contract-type-name contract) + val))] + [else + (if (contract val) + val + (raise-contract-error + src-info + pos + neg + "~agiven: ~e" + (predicate->expected-msg contract) + val))])) + + (define (check-implication contract1 contract2 val tbb src-info) (cond [(contract? contract) ((contract-f contract) val pos neg src-info)] @@ -484,18 +520,18 @@ src-info pos neg - "~agiven: ~e~a" + "~agiven: ~e" (predicate->expected-msg contract) - val - (if extra-message - extra-message - "")))])) + val))])) ;; raise-contract-error : (union syntax #f) symbol symbol string args ... -> alpha ;; doesn't return (define (raise-contract-error src-info to-blame other-party fmt . args) (let ([blame-src (if (syntax? src-info) - (string-append (build-src-loc-string src-info) ": ") + (let ([src-loc-str (build-src-loc-string src-info)]) + (if src-loc-str + (string-append src-loc-str ": ") + "")) "")] [specific-blame (let ([datum (syntax-object->datum src-info)]) @@ -513,12 +549,18 @@ (apply format fmt args))) (current-continuation-marks))))) - ;; contract = (make-contract (alpha sym sym (union syntax #f) -> alpha)) + ;; contract = (make-contract (alpha + ;; sym + ;; sym + ;; (union syntax #f) + ;; (contract beta sym sym (union syntax #f) -> beta) + ;; -> + ;; beta)) ;; generic contract container; ;; the first argument to f is the value to test the contract. ;; the second to f is a symbol representing the name of the positive blame ;; the third to f is the symbol representing the name of the negative blame - ;; the final argument is the src-info. + ;; the fourth argument is the src-info. (define-struct contract (f)) ;; flat-named-contract = (make-flat-named-contract string (any -> boolean)) @@ -643,32 +685,35 @@ (let-values ([(make-outer-check xxx build-pieces) (/h meth-contract-stx)]) (list make-outer-check xxx build-pieces)))) (syntax->list (syntax (meth-contract ...))))]) - (let ([outer-args (syntax (val pos neg src-info))]) + (let* ([outer-args (syntax (val pos-blame neg-blame src-info))] + [meth-names (syntax->list (syntax (meth-name ...)))] + [super-meth-names (map prefix-super meth-names)]) (with-syntax ([outer-args outer-args] - [(super-meth-name ...) (map prefix-super (syntax->list (syntax (meth-name ...))))]) + [(super-meth-name ...) super-meth-names] + [(method ...) (map (lambda (a b c) (make-wrapper-method outer-args a b c)) + meth-names + super-meth-names + build-pieces)]) (foldr (lambda (f stx) (f stx)) (syntax (make-contract (lambda outer-args (unless (class? val) - (raise-contract-error src-info pos neg "expected a class, got: ~e" val)) + (raise-contract-error src-info pos-blame neg-blame "expected a class, got: ~e" val)) (let ([class-i (class->interface val)]) (void) (unless (method-in-interface? 'meth-name class-i) (raise-contract-error src-info - pos neg + pos-blame + neg -blame "expected class to have method ~a, got: ~e" 'meth-name val)) ...) (class val (rename [super-meth-name meth-name] ...) - - (define/override meth-name - (lambda x (super-meth-name . x))) - ... - + method ... (super-instantiate ()))))) make-outer-checks))))] [(_ (meth-name meth-contract) ...) @@ -684,6 +729,21 @@ (syntax->list (syntax (clz ...))))])) + ;; make-wrapper-method : syntax[identifier] syntax[identifier] (syntax -> syntax) -> syntax + ;; constructs a wrapper method that checks the pre and post-condition, and + ;; calls the super method inbetween. + (define (make-wrapper-method outer-args method-name super-method-name build-piece) + (with-syntax ([super-method-name super-method-name] + [method-name method-name] + [(val pos-blame neg-blame src-info) outer-args] + [super-call (car (generate-temporaries (list super-method-name)))]) + (with-syntax ([(args body) (build-piece (syntax (super-call pos-blame neg-blame src-info)))]) + (syntax + (define/override method-name + (let ([super-call (lambda x (super-method-name . x))]) + (lambda args + body))))))) + ;; prefix-super : syntax[identifier] -> syntax[identifier] ;; adds super- to the front of the identifier (define (prefix-super stx) @@ -703,9 +763,15 @@ ;; (is a function of the right arity?) ;; - a piece of syntax that has the arguments to the wrapper ;; and the body of the wrapper. - ;; the first functions accepts `body' and it wraps + ;; the first function accepts a body expression and wraps + ;; the body expression with checks. In addition, it + ;; adds a let that binds the contract exprssions to names + ;; the results of the other functions mention these names. ;; the second and third function's input syntax should be four ;; names: val, pos-blame, neg-blame, src-info. + ;; the third function returns a syntax list with two elements, + ;; the argument list (to be used as the first arg to lambda, + ;; or as a case-lambda clause) and the body of the function. ;; They are combined into a lambda for the -> ->* ->d ->d* macros, ;; and combined into a case-lambda for the case-> macro. @@ -748,7 +814,7 @@ (syntax ((arg-x ...) (val - (check-contract dom-x arg-x neg-blame pos-blame src-info #f) + (check-contract dom-x arg-x neg-blame pos-blame src-info) ...))))) (lambda (stx) (->*make-body stx)))))))))])) @@ -774,7 +840,7 @@ (error '->* "expected contract as argument, given: ~e" rng-x)) ... body)))) (lambda (stx) - (with-syntax ([(val pos-blame neg-blame src-info) stx]) + (with-syntax ([(val check-rev-contract check-same-contract) stx]) (syntax (unless (and (procedure? val) (procedure-arity-includes? val arity)) @@ -791,15 +857,11 @@ ((arg-x ...) (let-values ([(res-x ...) (val - (check-contract dom-x arg-x neg-blame pos-blame src-info #f) + (check-rev-contract dom-x arg-x) ...)]) - (values (check-contract - rng-x - res-x - pos-blame - neg-blame - src-info - #f) + (values (check-same-contract + rng-x + res-x) ...))))))))] [(_ (dom ...) rest (rng ...)) (with-syntax ([(dom-x ...) (generate-temporaries (syntax (dom ...)))] @@ -822,7 +884,7 @@ (error '->* "expected contract for range position, given: ~e" rng-x)) ... body)))) (lambda (stx) - (with-syntax ([(val pos-blame neg-blame src-info) stx]) + (with-syntax ([(val check-rev-contract check-same-contract failure) stx]) (syntax (unless (procedure? val) (raise-contract-error @@ -839,16 +901,15 @@ (let-values ([(res-x ...) (apply val - (check-contract dom-x arg-x neg-blame pos-blame src-info #f) + (check-contract dom-x arg-x neg-blame pos-blame src-info) ... - (check-contract dom-rest-x rest-arg-x neg-blame pos-blame src-info #f))]) + (check-contract dom-rest-x rest-arg-x neg-blame pos-blame src-info))]) (values (check-contract rng-x res-x pos-blame neg-blame - src-info - #f) + src-info) ...))))))))])) ;; ->d/h : stx -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) @@ -897,11 +958,10 @@ rng-contract)) (check-contract rng-contract - (val (check-contract dom-x arg-x neg-blame pos-blame src-info #f) ...) + (val (check-contract dom-x arg-x neg-blame pos-blame src-info) ...) pos-blame neg-blame - src-info - #f)))))))))])) + src-info)))))))))])) ;; ->d*/h : stx -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) (define (->d*/h stx) @@ -946,7 +1006,7 @@ (call-with-values (lambda () (val - (check-contract dom-x arg-x neg-blame pos-blame src-info #f) + (check-contract dom-x arg-x neg-blame pos-blame src-info) ...)) (lambda results (unless (= (length results) (length rng-contracts)) @@ -961,8 +1021,7 @@ result pos-blame neg-blame - src-info - #f)) + src-info)) rng-contracts results))))))))))))] [(_ (dom ...) rest rng-mk) @@ -1007,9 +1066,9 @@ (lambda () (apply val - (check-contract dom-x arg-x neg-blame pos-blame src-info #f) + (check-contract dom-x arg-x neg-blame pos-blame src-info) ... - (check-contract dom-rest-x rest-arg-x neg-blame pos-blame src-info #f))) + (check-contract dom-rest-x rest-arg-x neg-blame pos-blame src-info))) (lambda results (unless (= (length results) (length rng-contracts)) (error '->d* @@ -1023,8 +1082,7 @@ result pos-blame neg-blame - src-info - #f)) + src-info )) rng-contracts results))))))))))))]))