diff --git a/collects/mzlib/contracts.ss b/collects/mzlib/contracts.ss index 9ecb2a4..e613723 100644 --- a/collects/mzlib/contracts.ss +++ b/collects/mzlib/contracts.ss @@ -1,6 +1,7 @@ (module contracts mzscheme (provide (rename -contract contract) + contract-=> -> ->d ->* @@ -488,16 +489,87 @@ "~agiven: ~e" (predicate->expected-msg contract) val))])) + + (define-syntax (contract-=> stx) + (syntax-case stx () + [(_ c1-e c2-e val-e tbb-e) + (with-syntax ([src-loc (datum->syntax-object stx 'here)]) + (syntax/loc stx + (contract-=> c1-e c2-e val-e tbb-e (quote-syntax src-loc))))] + [(_ c1-e c2-e val-e tbb-e src-loc-e) + (syntax/loc stx + (let ([c1 c1-e] + [c2 c2-e] + [val val-e] + [tbb tbb-e] + [src-loc src-loc-e]) + (unless (-contract? c1) + (error 'contract-=> "expected a contract as first argument, given: ~e, other args ~e ~e ~e ~e" + c1 + c2 + val + tbb + src-loc)) + (unless (-contract? c2) + (error 'contract-=> "expected a contract as second argument, given: ~e, other args ~e ~e ~e ~e" + c2 + c1 + val + tbb + src-loc)) + (unless (symbol? tbb) + (error 'contract-=> "expected symbol as names for assigning blame, given: ~e, other args ~e ~e ~e ~e" + tbb + c1 + c2 + val + src-loc)) + (unless (syntax? src-info) + (error 'contract "expected syntax as last argument, given: ~e, other args ~e ~e ~e ~e" + src-info + neg-blame + pos-blame + a-contract + name)) + (check-implication c1 c2 val tbb src-info)))])) + + ;; check-implication : contract contract any symbol (union syntax #f) -> any + (define (check-implication c1 c2 val tbb src-info) + (cond + [(and (contract? c1) (contract? c2)) + (error 'check-implication "not implemented")] + [(or (contract? c1) (contract? c2)) + (raise-contract-implication-error c1 c2 val tbb src-info)] + [else + (let ([test-contract + (lambda (c) + (cond + [(flat-named-contract? c) ((flat-named-contract-predicate c) val)] + [else (c val)]))]) + (if (or (not (test-contract c1)) + (test-contract c2)) + val + (raise-contract-implication-error c1 c2 val tbb src-info)))])) + + ;; raise-contract-implication-error : contract contract any symbol (union syntax #f) -> alpha + ;; escapes + (define (raise-contract-implication-error c1 c2 val tbb src-info) + (let ([blame-src (src-info-as-string src-info)]) + (raise + (make-exn + (string->immutable-string + (format "~a~a does not imply ~a for ~e" + blame-src + (contract->type-name c1) + (contract->type-name c2) + val)) + (current-continuation-marks))))) + ;; raise-contract-error : (union syntax #f) symbol symbol string args ... -> alpha ;; doesn't return (define (raise-contract-error src-info to-blame other-party fmt . args) - (let ([blame-src (if (syntax? src-info) - (let ([src-loc-str (build-src-loc-string src-info)]) - (if src-loc-str - (string-append src-loc-str ": ") - "")) - "")] + (let ([blame-src (src-info-as-string src-info)] [specific-blame (let ([datum (syntax-object->datum src-info)]) (if (symbol? datum) @@ -514,6 +586,15 @@ (apply format fmt args))) (current-continuation-marks))))) + ;; src-info-as-string : (union syntax #f) -> string + (define (src-info-as-string src-info) + (if (syntax? src-info) + (let ([src-loc-str (build-src-loc-string src-info)]) + (if src-loc-str + (string-append src-loc-str ": ") + "")) + "")) + ;; contract = (make-contract (alpha ;; sym ;; sym @@ -527,6 +608,8 @@ ;; the fourth argument is the src-info. (define-struct contract (f)) + (define-struct (->*contract contract) (doms rngs implication-maker)) + ;; flat-named-contract = (make-flat-named-contract string (any -> boolean)) ;; this holds flat contracts that have names for error reporting (define-struct flat-named-contract (type-name predicate)) @@ -573,6 +656,12 @@ (and m (cadr m)))))) + ;; contract->type-name : contract -> string + (define (contract->type-name c) + (cond + [(contract? c) "arrow contract"] + [else (flat-contract->type-name c)])) + ;; flat-contract->type-name : flat-contract -> string (define (flat-contract->type-name fc) (cond