diff --git a/collects/mzlib/contract.ss b/collects/mzlib/contract.ss index b351e8c..1f79335 100644 --- a/collects/mzlib/contract.ss +++ b/collects/mzlib/contract.ss @@ -611,24 +611,32 @@ improve method arity mismatch contract violation error messages? ; - (define-syntax-set (-> ->* ->d ->d* case-> object-contract) + (define-syntax-set (-> ->* ->d ->d* case-> object-contract opt-> opt->*) - (define (->/proc stx) (make-/proc #t ->/h stx)) - (define (->*/proc stx) (make-/proc #t ->*/h stx)) - (define (->d/proc stx) (make-/proc #t ->d/h stx)) - (define (->d*/proc stx) (make-/proc #t ->d*/h stx)) + (define (->/proc stx) (make-/proc #f ->/h stx)) + (define (->*/proc stx) (make-/proc #f ->*/h stx)) + (define (->d/proc stx) (make-/proc #f ->d/h stx)) + (define (->d*/proc stx) (make-/proc #f ->d*/h stx)) - (define (obj->/proc stx) (make-/proc #f ->/h stx)) - (define (obj->*/proc stx) (make-/proc #f ->*/h stx)) - (define (obj->d/proc stx) (make-/proc #f ->d/h stx)) - (define (obj->d*/proc stx) (make-/proc #f ->d*/h stx)) + (define (obj->/proc stx) (make-/proc #t ->/h stx)) + (define (obj->*/proc stx) (make-/proc #t ->*/h stx)) + (define (obj->d/proc stx) (make-/proc #t ->d/h stx)) + (define (obj->d*/proc stx) (make-/proc #t ->d*/h stx)) + (define (case->/proc stx) (make-case->/proc #f stx)) + (define (obj-case->/proc stx) (make-case->/proc #t stx)) + + (define (obj-opt->/proc stx) (make-opt->/proc #t stx)) + (define (obj-opt->*/proc stx) (make-opt->*/proc #t stx)) + (define (opt->/proc stx) (make-opt->/proc #f stx)) + (define (opt->*/proc stx) (make-opt->*/proc #f stx)) + ;; make-/proc : boolean ;; (syntax -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax) (syntax -> syntax))) ;; syntax ;; -> (syntax -> syntax) - (define (make-/proc show-first? /h stx) - (let-values ([(arguments-check build-proj check-val wrapper) (/h show-first? stx)]) + (define (make-/proc method-proc? /h stx) + (let-values ([(arguments-check build-proj check-val wrapper) (/h method-proc? stx)]) (let ([outer-args (syntax (val pos-blame neg-blame src-info orig-str name-id))]) (with-syntax ([inner-check (check-val outer-args)] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] @@ -653,17 +661,11 @@ improve method arity mismatch contract violation error messages? (lambda (pos-blame neg-blame src-info orig-str) proj-code)))))))))))) - ;; case->/proc : syntax -> syntax - ;; the transformer for the case-> macro - (define (case->/proc stx) (make-case->/proc #t stx)) - - (define (obj-case->/proc stx) (make-case->/proc #f stx)) - - (define (make-case->/proc show-first? stx) + (define (make-case->/proc method-proc? stx) (syntax-case stx () [(_ cases ...) (let-values ([(arguments-check build-projs check-val wrapper) - (case->/h show-first? stx (syntax->list (syntax (cases ...))))]) + (case->/h method-proc? stx (syntax->list (syntax (cases ...))))]) (let ([outer-args (syntax (val pos-blame neg-blame src-info orig-str name-id))]) (with-syntax ([(inner-check ...) (check-val outer-args)] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] @@ -686,6 +688,39 @@ improve method arity mismatch contract violation error messages? (lambda (pos-blame neg-blame src-info orig-str) proj-code))))))))))])) + (define (make-opt->/proc method-proc? stx) + (syntax-case stx () + [(_ (reqs ...) (opts ...) res) + (make-opt->*/proc method-proc? (syntax (opt->* (reqs ...) (opts ...) (res))))])) + + (define (make-opt->*/proc method-proc? stx) + (syntax-case stx () + [(_ (reqs ...) (opts ...) (ress ...)) + (let* ([res-vs (generate-temporaries (syntax->list (syntax (ress ...))))] + [req-vs (generate-temporaries (syntax->list (syntax (reqs ...))))] + [opt-vs (generate-temporaries (syntax->list (syntax (opts ...))))] + [cses + (reverse + (let loop ([opt-vs (reverse opt-vs)]) + (cond + [(null? opt-vs) (list req-vs)] + [else (cons (append req-vs (reverse opt-vs)) + (loop (cdr opt-vs)))])))]) + (with-syntax ([((double-res-vs ...) ...) (map (lambda (x) res-vs) cses)] + [(res-vs ...) res-vs] + [(req-vs ...) req-vs] + [(opt-vs ...) opt-vs] + [((case-doms ...) ...) cses]) + (with-syntax ([expanded-case-> + (make-case->/proc + method-proc? + (syntax (case-> (-> case-doms ... (values double-res-vs ...)) ...)))]) + (syntax/loc stx + (let ([res-vs ress] ... + [req-vs reqs] ... + [opt-vs opts] ...) + expanded-case->)))))])) + ;; exactract-argument-lists : syntax -> (listof syntax) (define (extract-argument-lists stx) (map (lambda (x) @@ -749,7 +784,7 @@ improve method arity mismatch contract violation error messages? ;; (syntax -> syntax)) ;; like the other /h functions, but composes the wrapper functions ;; together and combines the cases of the case-lambda into a single list. - (define (case->/h show-first? orig-stx cases) + (define (case->/h method-proc? orig-stx cases) (let loop ([cases cases] [name-ids '()]) (cond @@ -769,7 +804,7 @@ improve method arity mismatch contract violation error messages? (let-values ([(arguments-checks build-projs check-vals wrappers) (loop (cdr cases) (cons new-id name-ids))] [(arguments-check build-proj check-val wrapper) - (/h show-first? (car cases))]) + (/h method-proc? (car cases))]) (values (lambda (outer-args x) (with-syntax ([(val pos-blame neg-blame src-info orig-str name-id) outer-args] @@ -825,7 +860,7 @@ improve method arity mismatch contract violation error messages? ;; expand-mtd-contract : syntax -> (values syntax[expanded ctc] syntax[mtd-arg]) (define (expand-mtd-contract mtd-stx) - (syntax-case mtd-stx (case-> opt->) + (syntax-case mtd-stx (case-> opt-> opt->*) [(case-> cases ...) (let loop ([cases (syntax->list (syntax (cases ...)))] [ctc-stxs null] @@ -842,12 +877,31 @@ improve method arity mismatch contract violation error messages? (loop (cdr cases) (cons ctc-stx ctc-stxs) (cons mtd-args args-stxs)))]))] - #| - [(opt-> opts ...) ...] - |# + [(opt->* (req-contracts ...) (opt-contracts ...) (res-contracts ...)) + (values + (obj-opt->*/proc (syntax (opt->* (any? req-contracts ...) (opt-contracts ...) (res-contracts ...)))) + (generate-opt->vars (syntax (req-contracts ...)) + (syntax (opt-contracts ...))))] + [(opt-> (req-contracts ...) (opt-contracts ...) res-contract) + (values + (obj-opt->/proc (syntax (opt-> (any? req-contracts ...) (opt-contracts ...) res-contract))) + (generate-opt->vars (syntax (req-contracts ...)) + (syntax (opt-contracts ...))))] [else (let-values ([(x y z) (expand-mtd-arrow mtd-stx)]) (values (x y) z))])) + ;; generate-opt->vars : syntax[requried contracts] syntax[optional contracts] -> syntax[list of arg specs] + (define (generate-opt->vars req-stx opt-stx) + (with-syntax ([(req-vars ...) (generate-temporaries req-stx)] + [(ths) (generate-temporaries (syntax (ths)))]) + (let loop ([opt-vars (generate-temporaries opt-stx)]) + (cond + [(null? opt-vars) (list (syntax (ths req-vars ...)))] + [else (with-syntax ([(opt-vars ...) opt-vars] + [(rests ...) (loop (cdr opt-vars))]) + (syntax ((ths req-vars ... opt-vars ...) + rests ...)))])))) + ;; expand-mtd-arrow : stx -> (values (syntax[ctc] -> syntax[expanded ctc]) syntax[ctc] syntax[mtd-arg]) (define (expand-mtd-arrow mtd-stx) (syntax-case mtd-stx (-> ->* ->d ->d*) @@ -1140,7 +1194,7 @@ improve method arity mismatch contract violation error messages? ;; and combined into a case-lambda for the case-> macro. ;; ->/h : boolean stx -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) - (define (->/h show-first? stx) + (define (->/h method-proc? stx) (syntax-case stx () [(_) (raise-syntax-error '-> "expected at least one argument" stx)] [(_ arg ...) @@ -1153,11 +1207,11 @@ improve method arity mismatch contract violation error messages? [(dom-ant-x ...) (generate-temporaries (syntax (dom ...)))] [(arg-x ...) (generate-temporaries (syntax (dom ...)))]) (with-syntax ([(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax-case* (syntax rng) (any values) module-or-top-identifier=? [any (values @@ -1296,7 +1350,7 @@ improve method arity mismatch contract violation error messages? (rng-projection-x res-x))))))))]))))])) ;; ->*/h : boolean stx -> (values (syntax -> syntax) (syntax syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) - (define (->*/h show-first? stx) + (define (->*/h method-proc? stx) (syntax-case stx (any) [(_ (dom ...) (rng ...)) (with-syntax ([(dom-x ...) (generate-temporaries (syntax (dom ...)))] @@ -1317,11 +1371,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->* dom)] ... [rng-contract-x (coerce-contract ->* rng)] ...) @@ -1376,11 +1430,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->* dom)] ...) (let ([dom-x (contract-proc dom-contract-x)] ...) @@ -1439,11 +1493,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([(val pos-blame neg-blame src-info orig-str name-id) outer-args] [body body] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->* dom)] ... [dom-rest-contract-x (coerce-contract ->* rest)] @@ -1508,11 +1562,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->* dom)] ... [dom-rest-contract-x (coerce-contract ->* rest)]) @@ -1554,7 +1608,7 @@ improve method arity mismatch contract violation error messages? (dom-projection-rest-x arg-rest-x))))))))])) ;; ->d/h : boolean stx -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) - (define (->d/h show-first? stx) + (define (->d/h method-proc? stx) (syntax-case stx () [(_) (raise-syntax-error '->d "expected at least one argument" stx)] [(_ ct ...) @@ -1570,11 +1624,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->d dom)] ...) (let ([dom-x (contract-proc dom-contract-x)] ... @@ -1619,7 +1673,7 @@ improve method arity mismatch contract violation error messages? (val (dom-projection-x arg-x) ...))))))))))])) ;; ->d*/h : boolean stx -> (values (syntax -> syntax) (syntax -> syntax) (syntax -> syntax)) - (define (->d*/h show-first? stx) + (define (->d*/h method-proc? stx) (syntax-case stx () [(_ (dom ...) rng-mk) (with-syntax ([(dom-x ...) (generate-temporaries (syntax (dom ...)))] @@ -1633,11 +1687,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->d* dom)] ...) (let ([dom-x (contract-proc dom-contract-x)] ... @@ -1709,11 +1763,11 @@ improve method arity mismatch contract violation error messages? (with-syntax ([body body] [(val pos-blame neg-blame src-info orig-str name-id) outer-args] [(name-dom-contract-x ...) - (if show-first? - (syntax (dom-contract-x ...)) + (if method-proc? (cdr (syntax->list - (syntax (dom-contract-x ...)))))]) + (syntax (dom-contract-x ...)))) + (syntax (dom-contract-x ...)))]) (syntax (let ([dom-contract-x (coerce-contract ->d* dom)] ... [dom-rest-contract-x (coerce-contract ->d* rest)]) @@ -1858,37 +1912,6 @@ improve method arity mismatch contract violation error messages? (error 'name "expected contract or procedure of arity 1, got ~e" x)])))])) - - (define class-with-contracts<%> (interface ())) - - (define-syntax (opt-> stx) - (syntax-case stx () - [(_ (reqs ...) (opts ...) res) - (syntax (opt->* (reqs ...) (opts ...) (res)))])) - - (define-syntax (opt->* stx) - (syntax-case stx () - [(_ (reqs ...) (opts ...) (ress ...)) - (let* ([res-vs (generate-temporaries (syntax->list (syntax (ress ...))))] - [req-vs (generate-temporaries (syntax->list (syntax (reqs ...))))] - [opt-vs (generate-temporaries (syntax->list (syntax (opts ...))))] - [cases - (reverse - (let loop ([opt-vs (reverse opt-vs)]) - (cond - [(null? opt-vs) (list req-vs)] - [else (cons (append req-vs (reverse opt-vs)) - (loop (cdr opt-vs)))])))]) - (with-syntax ([((double-res-vs ...) ...) (map (lambda (x) res-vs) cases)] - [(res-vs ...) res-vs] - [(req-vs ...) req-vs] - [(opt-vs ...) opt-vs] - [((case-doms ...) ...) cases]) - (syntax/loc stx - (let ([res-vs ress] ... - [req-vs reqs] ... - [opt-vs opts] ...) - (case-> (->* (case-doms ...) (double-res-vs ...)) ...)))))])) ; diff --git a/collects/tests/mzscheme/contract-test.ss b/collects/tests/mzscheme/contract-test.ss index 113c091..9e50e9f 100644 --- a/collects/tests/mzscheme/contract-test.ss +++ b/collects/tests/mzscheme/contract-test.ss @@ -1,6 +1,7 @@ (load-relative "loadtest.ss") (require (lib "contract.ss") - (lib "class.ss")) + (lib "class.ss") + (lib "etc.ss")) (SECTION 'contract) @@ -948,7 +949,198 @@ 3 4) 7) - + + (test/spec-failed + 'object-contract-opt->*1 + '(contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a]) + x)) + (super-new))) + 'pos + 'neg) + "pos") + + (test/spec-failed + 'object-contract-opt->*2 + '(contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x y [z #t]) + x)) + (super-new))) + 'pos + 'neg) + "pos") + + (test/spec-passed + 'object-contract-opt->*3 + '(contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg)) + + (test/spec-passed/result + 'object-contract-opt->*4 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + 1) + 1) + + (test/spec-passed/result + 'object-contract-opt->*5 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + 2 + 'z) + 2) + + (test/spec-passed/result + 'object-contract-opt->*7 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + #f) + 3) + + (test/spec-failed + 'object-contract-opt->*8 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + #f) + "neg") + + (test/spec-failed + 'object-contract-opt->*9 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + 2 + 4) + "neg") + + (test/spec-failed + 'object-contract-opt->*10 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + x)) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + 'y) + "neg") + + (test/spec-failed + 'object-contract-opt->*11 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + 'x)) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + #f) + "pos") + + (test/spec-passed/result + 'object-contract-opt->*12 + '(let-values ([(x y) + (send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number? symbol?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + (values 1 'x))) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + #f)]) + (cons x y)) + (cons 1 'x)) + + (test/spec-failed + 'object-contract-opt->*13 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number? symbol?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + (values 'x 'x))) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + #f) + "pos") + + (test/spec-failed + 'object-contract-opt->*14 + '(send (contract (object-contract (m (opt->* (integer?) (symbol? boolean?) (number? symbol?)))) + (new (class object% + (define/public m + (opt-lambda (x [y 'a] [z #t]) + (values 1 1))) + (super-new))) + 'pos + 'neg) + m + 3 + 'z + #f) + "pos") + ; ; @@ -1481,6 +1673,8 @@ (object-contract (m (case-> (-> integer? integer? integer?) (-> integer? (values integer? integer?)))))) + (test-name "(object-contract (m (case-> (-> integer? (values symbol?)) (-> integer? boolean? (values symbol?)) (-> integer? boolean? number? (values symbol?)))))" + (object-contract (m (opt->* (integer?) (boolean? number?) (symbol?))))) )) (report-errs)