diff --git a/collects/mzlib/private/unit-contract.ss b/collects/mzlib/private/unit-contract.ss new file mode 100644 index 0000000000..632f613ead --- /dev/null +++ b/collects/mzlib/private/unit-contract.ss @@ -0,0 +1,324 @@ +#lang scheme/base + +(require (for-syntax scheme/base + stxclass + syntax/boundmap + "unit-compiletime.ss") + scheme/contract + "unit-keywords.ss" + "unit-utils.ss" + "unit-runtime.ss") + +(provide unit/c) + +(define-for-syntax (contract-imports/exports import?) + (λ (table-stx import-tagged-infos import-sigs ctc-table-stx pos neg src-info name) + (define def-table (make-bound-identifier-mapping)) + (define ctc-table (make-bound-identifier-mapping)) + (define (convert-reference vref ctc sig-ctc) + (let ([wrap-with-proj + (λ (stx) + #`((((proj-get ctc) ctc) #,(if import? neg pos) + #,(if import? pos neg) + #,src-info + #,name) + #,stx))]) + #`(let ([ctc #,ctc]) + (if ctc + (cons (λ () + (let* ([old-v #,(if sig-ctc + #`(let ([old-v/c ((car #,vref))]) + (cons #,(wrap-with-proj #'(car old-v/c)) + (cdr old-v/c))) + (wrap-with-proj #`((car #,vref))))]) + old-v)) + (λ (v) + (let* ([new-v #,(if sig-ctc + #`(cons #,(wrap-with-proj #'(car v)) + (cdr v)) + (wrap-with-proj #'v))]) + ((cdr #,vref) new-v)))) + #,vref)))) + (for-each + (lambda (tagged-info sig) + (define v + #`(hash-ref #,table-stx #,(car (tagged-info->keys tagged-info)))) + (define c + #`(hash-ref #,ctc-table-stx #,(car (tagged-info->keys tagged-info)))) + (for-each + (lambda (int/ext-name index ctc) + (bound-identifier-mapping-put! def-table + (car int/ext-name) + #`(vector-ref #,v #,index)) + (bound-identifier-mapping-put! ctc-table + (car int/ext-name) + #`(vector-ref #,c #,index))) + (car sig) + (build-list (length (car sig)) values) + (cadddr sig))) + import-tagged-infos + import-sigs) + (with-syntax ((((eloc ...) ...) + (map + (lambda (target-sig) + (map + (lambda (target-int/ext-name sig-ctc) + (let* ([vref + (bound-identifier-mapping-get + def-table + (car target-int/ext-name))] + [ctc + (bound-identifier-mapping-get + ctc-table + (car target-int/ext-name))]) + (convert-reference vref ctc sig-ctc))) + (car target-sig) + (cadddr target-sig))) + import-sigs)) + (((export-keys ...) ...) + (map tagged-info->keys import-tagged-infos))) + #'(unit-export ((export-keys ...) + (vector-immutable eloc ...)) ...)))) + +(define-for-syntax contract-imports (contract-imports/exports #t)) +(define-for-syntax contract-exports (contract-imports/exports #f)) + +(define-for-syntax (build-contract-table import? import-tagged-infos import-sigs id-stx ctc-stx) + (with-syntax ([((ectc ...) ...) + (map (λ (sig ids ctcs) + (let ([alist (map cons (syntax->list ids) (syntax->list ctcs))]) + (map (λ (int/ext-name) + (cond + [(assf (λ (i) + (bound-identifier=? i (car int/ext-name))) + alist) + => + (λ (p) (cdr p))] + [else #'#f])) + (car sig)))) + import-sigs + (syntax->list id-stx) + (syntax->list ctc-stx))] + [((export-keys ...) ...) + (map tagged-info->keys import-tagged-infos)]) + #'(unit-export ((export-keys ...) (vector-immutable ectc ...)) ...))) + +(define-for-syntax (check-ids name sig alist) + (let ([ctc-sig/ids (assf (λ (i) + (bound-identifier=? name i)) + alist)]) + (when ctc-sig/ids + (let ([ids (map car (car sig))]) + (for-each (λ (id) + (unless (memf (λ (i) (bound-identifier=? id i)) ids) + (raise-syntax-error 'unit/c + (format "identifier not member of signature ~a" + (syntax-e name)) + id))) + (cdr ctc-sig/ids)))))) + +(define-syntax/err-param (unit/c stx) + (begin + (define-syntax-class sig-id + (pattern x + #:declare x (static-of 'signature + (λ (x) + (signature? (set!-trans-extract x)))))) + (define-syntax-class unit/c-clause + #:transparent + (pattern (s:sig-id [x:identifier c:expr] ...)) + (pattern s:sig-id ;; allow a non-wrapped sig-id, which is the same as (sig-id) + #:with (x ...) null + #:with (c ...) null)) + (define-syntax-class import-clause #:literals (import) + #:transparent + (pattern (import i:unit/c-clause ...))) + (define-syntax-class export-clause #:literals (export) + #:transparent + (pattern (export e:unit/c-clause ...))) + (syntax-parse stx + [(_ (import i:unit/c-clause ...) + (export e:unit/c-clause ...) bad-expr . rest) + (raise-syntax-error 'unit/c + "extra form" + #'bad-expr)] + [(_ :import-clause :export-clause) + (begin + (define-values (isig tagged-import-sigs import-tagged-infos + import-tagged-sigids import-sigs) + (process-unit-import #'(i.s ...))) + + (define-values (esig tagged-export-sigs export-tagged-infos + export-tagged-sigids export-sigs) + (process-unit-export #'(e.s ...))) + + (check-duplicate-sigs import-tagged-infos isig null null) + + (check-duplicate-subs export-tagged-infos esig) + + (check-unit-ie-sigs import-sigs export-sigs) + + (for-each (λ (sig xs) + (let ([dup (check-duplicate-identifier (syntax->list xs))]) + (when dup + (raise-syntax-error 'unit/c + (format "duplicate identifier found for signature ~a" (syntax->datum sig)) + dup)))) + (syntax->list #'(i.s ... e.s ...)) + (syntax->list #'((i.x ...) ... (e.x ...) ...))) + + (let ([alist (map syntax->list + (syntax->list #'((i.s i.x ...) ...)))]) + (for-each (λ (name sig) + (check-ids name sig alist)) + isig import-sigs)) + + (let ([alist (map syntax->list + (syntax->list #'((e.s e.x ...) ...)))]) + (for-each (λ (name sig) + (check-ids name sig alist)) + esig export-sigs)) + + (with-syntax ([((import-key ...) ...) + (map tagged-info->keys import-tagged-infos)] + [((export-key ...) ...) + (map tagged-info->keys export-tagged-infos)] + [(import-name ...) + (map (lambda (tag/info) (car (siginfo-names (cdr tag/info)))) + import-tagged-infos)] + [(export-name ...) + (map (lambda (tag/info) (car (siginfo-names (cdr tag/info)))) + export-tagged-infos)] + [((new-ci ...) ...) (map generate-temporaries (syntax->list #'((i.c ...) ...)))] + [((new-ce ...) ...) (map generate-temporaries (syntax->list #'((e.c ...) ...)))]) + (quasisyntax/loc stx + (let-values ([(new-ci ...) (values (coerce-contract 'unit/c i.c) ...)] ... + [(new-ce ...) (values (coerce-contract 'unit/c e.c) ...)] ...) + (make-proj-contract + (list 'unit/c + (cons 'import + (list (cons 'i.s + (map list (list 'i.x ...) + (build-compound-type-name new-ci ...))) + ...)) + (cons 'export + (list (cons 'e.s + (map list (list 'e.x ...) + (build-compound-type-name new-ce ...))) + ...))) + (λ (pos neg src-info name) + (λ (unit-tmp) + (unless (unit? unit-tmp) + (raise-contract-error unit-tmp src-info pos name + "value is not a unit")) + (contract-check-sigs + unit-tmp + (vector-immutable + (cons 'import-name + (vector-immutable import-key ...)) ...) + (vector-immutable + (cons 'export-name + (vector-immutable export-key ...)) ...) + src-info pos name) + (make-unit + #f + (vector-immutable (cons 'import-name + (vector-immutable import-key ...)) ...) + (vector-immutable (cons 'export-name + (vector-immutable export-key ...)) ...) + (unit-deps unit-tmp) + (lambda () + (let-values ([(unit-fn export-table) ((unit-go unit-tmp))]) + (values (lambda (import-table) + (let ([import-ctc-table + #,(build-contract-table #t + import-tagged-infos + import-sigs + #'((i.x ...) ...) + #'((new-ci ...) ...))]) + (unit-fn #,(contract-imports + #'import-table + import-tagged-infos + import-sigs + #'import-ctc-table + #'pos + #'neg + #'src-info + #'name)))) + (let ([export-ctc-table + #,(build-contract-table #f + export-tagged-infos + export-sigs + #'((e.x ...) ...) + #'((new-ce ...) ...))]) + #,(contract-exports + #'export-table + export-tagged-infos + export-sigs + #'export-ctc-table + #'pos + #'neg + #'src-info + #'name)))))))) + (λ (v) + (and (unit? v) + (with-handlers ([exn:fail:contract? (λ () #f)]) + (contract-check-sigs + v + (vector-immutable + (cons 'import-name + (vector-immutable import-key ...)) ...) + (vector-immutable + (cons 'export-name + (vector-immutable export-key ...)) ...) + (list #f "not-used") 'not-used null)) + #t)))))))] + [(_ (import i:unit/c-clause ...) bad-e . body) + (raise-syntax-error 'unit/c + "expected an export description" + #'bad-e)] + [(_ (import i:unit/c-clause ...)) + (raise-syntax-error 'unit/c + "expected an export description" + stx)] + [(_ bad-i . rest) + (raise-syntax-error 'unit/c + "expected an import description" + #'bad-i)] + [(_) + (raise-syntax-error 'unit/c + "expected an import description" + stx)]))) + +(define (contract-check-helper sub-sig super-sig import? val src-info blame ctc) + (define t (make-hash)) + (let loop ([i (sub1 (vector-length sub-sig))]) + (when (>= i 0) + (let ([v (cdr (vector-ref sub-sig i))]) + (let loop ([j (sub1 (vector-length v))]) + (when (>= j 0) + (let ([vj (vector-ref v j)]) + (hash-set! t vj + (if (hash-ref t vj #f) + 'amb + #t))) + (loop (sub1 j))))) + (loop (sub1 i)))) + (let loop ([i (sub1 (vector-length super-sig))]) + (when (>= i 0) + (let* ([v0 (vector-ref (cdr (vector-ref super-sig i)) 0)] + [r (hash-ref t v0 #f)]) + (when (not r) + (let ([sub-name (car (vector-ref super-sig i))]) + (raise-contract-error + val src-info blame ctc + (cond + [import? + (format "contract does not list import ~a" sub-name)] + [else + (format "unit must export signature ~a" sub-name)]))))) + (loop (sub1 i))))) + +(define (contract-check-sigs unit expected-imports expected-exports src-info blame ctc) + (contract-check-helper expected-imports (unit-import-sigs unit) #t unit src-info blame ctc) + (contract-check-helper (unit-export-sigs unit) expected-exports #f unit src-info blame ctc)) diff --git a/collects/mzlib/private/unit-utils.ss b/collects/mzlib/private/unit-utils.ss new file mode 100644 index 0000000000..3552e67914 --- /dev/null +++ b/collects/mzlib/private/unit-utils.ss @@ -0,0 +1,107 @@ +#lang mzscheme + +(require (for-syntax "unit-compiletime.ss" + "unit-syntax.ss")) + +(provide (for-syntax build-key + check-duplicate-sigs + check-unit-ie-sigs + iota + process-unit-import + process-unit-export + tagged-info->keys)) + +(provide equal-hash-table + unit-export) + +(define-for-syntax (iota n) + (let loop ((n n) + (acc null)) + (cond + ((= n 0) acc) + (else (loop (sub1 n) (cons (sub1 n) acc)))))) + +(define-syntax-rule (equal-hash-table [k v] ...) + (make-immutable-hash-table (list (cons k v) ...) 'equal)) + +(define-syntax (unit-export stx) + (syntax-case stx () + ((_ ((esig ...) elocs) ...) + (with-syntax ((((kv ...) ...) + (map + (lambda (esigs eloc) + (map + (lambda (esig) #`(#,esig #,eloc)) + (syntax->list esigs))) + (syntax->list #'((esig ...) ...)) + (syntax->list #'(elocs ...))))) + #'(equal-hash-table kv ... ...))))) + +;; check-duplicate-sigs : (listof (cons symbol siginfo)) (listof syntax-object) +;; (listof (cons symbol siginfo)) (listof syntax-object) -> +(define-for-syntax (check-duplicate-sigs tagged-siginfos sources tagged-deps dsources) + (define import-idx (make-hash-table 'equal)) + (for-each + (lambda (tinfo s) + (define key (cons (car tinfo) + (car (siginfo-ctime-ids (cdr tinfo))))) + (when (hash-table-get import-idx key #f) + (raise-stx-err "duplicate import signature" s)) + (hash-table-put! import-idx key #t)) + tagged-siginfos + sources) + (for-each + (lambda (dep s) + (unless (hash-table-get import-idx + (cons (car dep) + (car (siginfo-ctime-ids (cdr dep)))) + #f) + (raise-stx-err "initialization dependency on unknown import" s))) + tagged-deps + dsources)) + +(define-for-syntax (check-unit-ie-sigs import-sigs export-sigs) + (let ([dup (check-duplicate-identifier + (apply append (map sig-int-names import-sigs)))]) + (when dup + (raise-stx-err + (format "~a is imported by multiple signatures" (syntax-e dup))))) + + (let ([dup (check-duplicate-identifier + (apply append (map sig-int-names export-sigs)))]) + (when dup + (raise-stx-err (format "~a is exported by multiple signatures" + (syntax-e dup))))) + + (let ([dup (check-duplicate-identifier + (append + (apply append (map sig-int-names import-sigs)) + (apply append (map sig-int-names export-sigs))))]) + (when dup + (raise-stx-err (format "import ~a is exported" (syntax-e dup)))))) + +(define-for-syntax (process-unit-import/export process) + (lambda (s) + (define x1 (syntax->list s)) + (define x2 (map process x1)) + (values x1 x2 (map car x2) (map cadr x2) (map caddr x2)))) + +(define-for-syntax process-unit-import + (process-unit-import/export process-tagged-import)) + +(define-for-syntax process-unit-export + (process-unit-import/export process-tagged-export)) + +;; build-key : (or symbol #f) identifier -> syntax-object +(define-for-syntax (build-key tag i) + (if tag + #`(cons '#,tag #,i) + i)) + +;; tagged-info->keys : (cons (or symbol #f) siginfo) -> (listof syntax-object) +(define-for-syntax (tagged-info->keys tagged-info) + (define tag (car tagged-info)) + (map (lambda (rid) + (build-key tag (syntax-local-introduce rid))) + (siginfo-rtime-ids (cdr tagged-info)))) + diff --git a/collects/mzlib/unit.ss b/collects/mzlib/unit.ss index 9136ece466..aa5c441af4 100644 --- a/collects/mzlib/unit.ss +++ b/collects/mzlib/unit.ss @@ -12,13 +12,15 @@ (require mzlib/etc mzlib/contract mzlib/stxparam + "private/unit-contract.ss" "private/unit-keywords.ss" - "private/unit-runtime.ss") + "private/unit-runtime.ss" + "private/unit-utils.ss") (provide define-signature-form struct open define-signature provide-signature-elements only except rename import export prefix link tag init-depend extends contracted - unit? + unit? (all-from "private/unit-contract.ss") (rename :unit unit) define-unit compound-unit define-compound-unit compound-unit/infer define-compound-unit/infer invoke-unit define-values/invoke-unit @@ -362,103 +364,11 @@ 'expression (list #'stop) def-ctx)))) - - (define-for-syntax (iota n) - (let loop ((n n) - (acc null)) - (cond - ((= n 0) acc) - (else (loop (sub1 n) (cons (sub1 n) acc)))))) - - (define-syntax-rule (equal-hash-table [k v] ...) - (make-immutable-hash-table (list (cons k v) ...) 'equal)) - - (define-syntax (unit-export stx) - (syntax-case stx () - ((_ ((esig ...) elocs) ...) - (with-syntax ((((kv ...) ...) - (map - (lambda (esigs eloc) - (map - (lambda (esig) #`(#,esig #,eloc)) - (syntax->list esigs))) - (syntax->list #'((esig ...) ...)) - (syntax->list #'(elocs ...))))) - #'(equal-hash-table kv ... ...))))) - - ;; build-key : (or symbol #f) identifier -> syntax-object - (define-for-syntax (build-key tag i) - (if tag - #`(cons '#,tag #,i) - i)) - - ;; tagged-info->keys : (cons (or symbol #f) siginfo) -> (listof syntax-object) - (define-for-syntax (tagged-info->keys tagged-info) - (define tag (car tagged-info)) - (map (lambda (rid) - (build-key tag (syntax-local-introduce rid))) - (siginfo-rtime-ids (cdr tagged-info)))) - - ;; check-duplicate-sigs : (listof (cons symbol siginfo)) (listof syntax-object) - ;; (listof (cons symbol siginfo)) (listof syntax-object) -> - (define-for-syntax (check-duplicate-sigs tagged-siginfos sources tagged-deps dsources) - (define import-idx (make-hash-table 'equal)) - (for-each - (lambda (tinfo s) - (define key (cons (car tinfo) - (car (siginfo-ctime-ids (cdr tinfo))))) - (when (hash-table-get import-idx key #f) - (raise-stx-err "duplicate import signature" s)) - (hash-table-put! import-idx key #t)) - tagged-siginfos - sources) - (for-each - (lambda (dep s) - (unless (hash-table-get import-idx - (cons (car dep) - (car (siginfo-ctime-ids (cdr dep)))) - #f) - (raise-stx-err "initialization dependency on unknown import" s))) - tagged-deps - dsources)) (define-for-syntax (tagged-sigid->tagged-siginfo x) (cons (car x) (signature-siginfo (lookup-signature (cdr x))))) - (define-for-syntax (check-unit-ie-sigs import-sigs export-sigs) - (let ([dup (check-duplicate-identifier - (apply append (map sig-int-names import-sigs)))]) - (when dup - (raise-stx-err - (format "~a is imported by multiple signatures" (syntax-e dup))))) - - (let ([dup (check-duplicate-identifier - (apply append (map sig-int-names export-sigs)))]) - (when dup - (raise-stx-err (format "~a is exported by multiple signatures" - (syntax-e dup))))) - - (let ([dup (check-duplicate-identifier - (append - (apply append (map sig-int-names import-sigs)) - (apply append (map sig-int-names export-sigs))))]) - (when dup - (raise-stx-err (format "import ~a is exported" (syntax-e dup)))))) - - - (define-for-syntax (process-unit-import/export process) - (lambda (s) - (define x1 (syntax->list s)) - (define x2 (map process x1)) - (values x1 x2 (map car x2) (map cadr x2) (map caddr x2)))) - - (define-for-syntax process-unit-import - (process-unit-import/export process-tagged-import)) - - (define-for-syntax process-unit-export - (process-unit-import/export process-tagged-export)) - ;; id->contract-src-info : identifier -> syntax ;; constructs the last argument to the contract, given an identifier (define-for-syntax (id->contract-src-info id) diff --git a/collects/scribblings/reference/units.scrbl b/collects/scribblings/reference/units.scrbl index d3d7d853eb..15de5aec15 100644 --- a/collects/scribblings/reference/units.scrbl +++ b/collects/scribblings/reference/units.scrbl @@ -629,6 +629,40 @@ Expands to a @scheme[provide] of all identifiers implied by the @; ------------------------------------------------------------------------ +@section[#:tag "unitcontracts"]{Unit Contracts} + +@defform/subs[#:literals (import export) + (unit/c (import sig-block ...) (export sig-block ...)) + ([sig-block (sig-id [id ctc] ...)]) + #:contracts ([ctc contract?])]{ + +A @deftech{unit contract} wraps a unit and checks both its imported and +exported identifiers to ensure that they match the appropriate contracts. +This allows the programmer to add contract checks to a single unit value +without adding contracts to the imported and exported signatures. + +The unit value must import a subset of the import signatures and export a +superset of the export signatures listed in the unit contract. Any +identifier which is not listed for a given signature is left alone. + +Here is an example use of @scheme[unit/c]: + +@schememod[scheme/base +(require scheme/unit) +(define-signature odd^ (odd?)) +(define-signature even^ (even?)) +(define-unit E@ + (import odd^) + (export even^) + (define (even? n) + (if (zero? n) #t (odd? (sub1 n))))) +(provide/contract + [E@ (unit/c (import odd^ [odd? (-> number? boolean?)]) + (export even^ [even? (-> number? boolean?)]))]) +]} + +@; ------------------------------------------------------------------------ + @section[#:tag "single-unit"]{Single-Unit Modules} When @schememodname[scheme/unit] is used as a language name with diff --git a/collects/tests/units/test-unit-contracts.ss b/collects/tests/units/test-unit-contracts.ss index 4332adc2b1..859668c181 100644 --- a/collects/tests/units/test-unit-contracts.ss +++ b/collects/tests/units/test-unit-contracts.ss @@ -530,4 +530,66 @@ (link [((S : sig9)) unit53] [() unit55-1 S])) (test-runtime-error exn:fail:contract? "unit55-1 misuses f" - (invoke-unit unit55-2))) \ No newline at end of file + (invoke-unit unit55-2))) + +(module m1 scheme + (define-signature foo^ (x)) + (define-signature bar^ (y)) + (provide foo^ bar^) + + (define-unit U@ + (import foo^) + (export bar^) + (define (y s) + (if (eq? s 'bork) + 3 + (string-append (symbol->string s) " " (if (x 3) "was true on 3" "was not true on 3"))))) + (provide/contract [U@ (unit/c (import (foo^ [x (-> number? boolean?)])) + (export (bar^ [y (-> symbol? string?)])))])) + +(module m2 scheme + (require 'm1) + + (define x zero?) + (define-values/invoke-unit U@ + (import foo^) + (export bar^)) + + (define (z) + (y 'a)) + (define (w) + (y "foo")) + (define (v) + (y 'bork)) + + (provide z w v)) + +(require (prefix-in m2: 'm2)) + +(m2:z) +(test-runtime-error exn:fail:contract? "m2 broke the contract on U@ (string, not symbol)" (m2:w)) +(test-runtime-error exn:fail:contract? "m1 broke the contract on U@ (number, not string)" (m2:v)) + +(test-syntax-error "no y in sig1" + (unit/c (import (sig1 [y number?])) + (export))) +(test-syntax-error "two xs for sig1" + (unit/c (import) + (export (sig1 [x string?] [x number?])))) +(test-syntax-error "no sig called faux^, so import description matching fails" + (unit/c (import faux^) (export))) + +(test-runtime-error exn:fail:contract? "unit bad-export@ does not export sig1" + (let () + (define/contract bad-export@ + (unit/c (import) (export sig1)) + (unit (import) (export))) + bad-export@)) + +(test-runtime-error exn:fail:contract? "contract on bad-import@ does not export sig1" + (let () + (define/contract bad-import@ + (unit/c (import) (export)) + (unit (import sig1) (export) (+ x 1))) + bad-import@)) + \ No newline at end of file