diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 406d22a..751f829 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -23,6 +23,9 @@ (require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list))) (require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup))) +(module+ test + (require (for-syntax rackunit))) + (provide → →/test match2 define-type) ;; ML-like language @@ -252,6 +255,44 @@ (define (inst-types/cs/orig Xs cs tys [var=? free-identifier=?]) (stx-map (lambda (t) (inst-type/cs/orig Xs cs t var=?)) tys)) + ;; covariant-Xs? : Type -> Bool + ;; Takes a possibly polymorphic type, and returns true if all of the + ;; type variables are in covariant positions within the type, false + ;; otherwise. + (define (covariant-Xs? ty) + (syntax-parse ((current-type-eval) ty) + [(~?∀ Xs ty) + (for/and ([X (in-list (syntax->list #'Xs))]) + (covariant-X? X #'ty))])) + + ;; find-X-variance : Id Type -> Variance + ;; Returns the variance of X within the type ty + (define (find-X-variance X ty) + (syntax-parse ty + [A:id #:when (free-identifier=? #'A X) covariant] + [(~Any tycons) irrelevant] + [(~?∀ () (~Any tycons τ ...)) + #:when (get-arg-variances #'tycons) + #:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons)) + (for/fold ([acc irrelevant]) + ([τ (in-list (syntax->list #'[τ ...]))] + [arg-variance (in-list (get-arg-variances #'tycons))]) + (variance-join + acc + (variance-compose arg-variance (find-X-variance X τ))))] + [ty #:when (not (stx-contains-id? #'ty X)) irrelevant] + [_ invariant])) + + ;; covariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a covariant position, false otherwise. + (define (covariant-X? X ty) + (variance-covariant? (find-X-variance X ty))) + + ;; contravariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a contravariant position, false otherwise. + (define (contravariant-X? X ty) + (variance-contravariant? (find-X-variance X ty))) + ;; compute unbound tyvars in one unexpanded type ty (define (compute-tyvar1 ty) (syntax-parse ty @@ -386,12 +427,26 @@ #'(StructName ...) #'((fld ...) ...)) #:with (Cons? ...) (stx-map mk-? #'(StructName ...)) #:with (exposed-Cons? ...) (stx-map mk-? #'(Cons ...)) + #:do [(define expanded-tys + (for/list ([τ (in-list (syntax->list #'[τ ... ...]))]) + (with-handlers ([exn:fail:syntax? (λ (e) #false)]) + ((current-type-eval) #`(∀ (X ...) #,τ)))))] + #:with [arg-variance ...] + (for/list ([i (in-range (length (syntax->list #'[X ...])))]) + (for/fold ([acc irrelevant]) + ([ty (in-list expanded-tys)]) + (cond [ty + (define/syntax-parse (~?∀ Xs τ) ty) + (define X (list-ref (syntax->list #'Xs) i)) + (variance-join acc (find-X-variance X #'τ))] + [else invariant]))) #`(begin (define-syntax (NameExtraInfo stx) (syntax-parse stx [(_ X ...) #'(('Cons 'StructName Cons? [acc τ] ...) ...)])) (define-type-constructor Name #:arity = #,(stx-length #'(X ...)) + #:arg-variances (λ (stx) (list 'arg-variance ...)) #:extra-info 'NameExtraInfo #:no-provide) (struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ... @@ -1290,3 +1345,30 @@ (cond [(eof-object? x) ""] [(number? x) (number->string x)] [(symbol? x) (symbol->string x)])) : String)]) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(module+ test + (begin-for-syntax + (check-true (covariant-Xs? #'Int)) + (check-true (covariant-Xs? #'(stlc+box:Ref Int))) + (check-true (covariant-Xs? #'(→ Int Int))) + (check-true (covariant-Xs? #'(∀ (X) X))) + (check-false (covariant-Xs? #'(∀ (X) (stlc+box:Ref X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ X X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ X Int)))) + (check-true (covariant-Xs? #'(∀ (X) (→ Int X)))) + (check-true (covariant-Xs? #'(∀ (X) (→ (→ X Int) X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ (→ (→ X Int) Int) X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ (stlc+box:Ref X) Int)))) + (check-false (covariant-Xs? #'(∀ (X Y) (→ X Y)))) + (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) Y)))) + (check-false (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Y Int))))) + (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Int Y))))) + (check-false (covariant-Xs? #'(∀ (A B) (→ (→ Int (stlc+rec-iso:× A B)) + (→ String (stlc+rec-iso:× A B)) + (stlc+rec-iso:× A B))))) + (check-true (covariant-Xs? #'(∀ (A B) (→ (→ (stlc+rec-iso:× A B) Int) + (→ (stlc+rec-iso:× A B) String) + (stlc+rec-iso:× A B))))) + )) diff --git a/tapl/stlc+tup.rkt b/tapl/stlc+tup.rkt index 03e3726..9d08d13 100644 --- a/tapl/stlc+tup.rkt +++ b/tapl/stlc+tup.rkt @@ -1,5 +1,7 @@ #lang s-exp "typecheck.rkt" (extends "ext-stlc.rkt") + +(require (for-syntax racket/list)) ;; Simply-Typed Lambda Calculus, plus tuples ;; Types: @@ -9,7 +11,9 @@ ;; - terms from ext-stlc.rkt ;; - tup and proj -(define-type-constructor × #:arity >= 0) +(define-type-constructor × #:arity >= 0 + #:arg-variances (λ (stx) + (make-list (stx-length (stx-cdr stx)) covariant))) (define-typed-syntax tup [(_ e ...) diff --git a/tapl/stlc.rkt b/tapl/stlc.rkt index e5c104d..22b8145 100644 --- a/tapl/stlc.rkt +++ b/tapl/stlc.rkt @@ -2,6 +2,8 @@ (provide (for-syntax current-type=? types=?)) (provide (for-syntax mk-app-err-msg)) +(require (for-syntax racket/list)) + ;; Simply-Typed Lambda Calculus ;; - no base types; can't write any terms ;; Types: multi-arg → (1+) @@ -66,7 +68,13 @@ (define-syntax-category type) -(define-type-constructor → #:arity >= 1) +(define-type-constructor → #:arity >= 1 + #:arg-variances (λ (stx) + (syntax-parse stx + [(_ τ_in ... τ_out) + (append + (make-list (stx-length #'[τ_in ...]) contravariant) + (list covariant))]))) (define-typed-syntax λ [(_ bvs:type-ctx e) diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt index 1391e15..c1db34d 100644 --- a/tapl/typecheck.rkt +++ b/tapl/typecheck.rkt @@ -15,6 +15,9 @@ "stx-utils.rkt")) (for-meta 2 (all-from-out racket/base syntax/parse racket/syntax))) +(module+ test + (require (for-syntax rackunit))) + ;; type checking functions/forms ;; General type checking strategy: @@ -420,6 +423,48 @@ (define (brack? stx) (define paren-shape/#f (syntax-property stx 'paren-shape)) (and paren-shape/#f (char=? paren-shape/#f #\[))) + + (define (iff b1 b2) + (boolean=? b1 b2)) + + ;; Variance is (variance Boolean Boolean) + (struct variance (covariant? contravariant?) #:prefab) + (define irrelevant (variance #true #true)) + (define covariant (variance #true #false)) + (define contravariant (variance #false #true)) + (define invariant (variance #false #false)) + ;; variance-irrelevant? : Variance -> Boolean + (define (variance-irrelevant? v) + (and (variance-covariant? v) (variance-contravariant? v))) + ;; variance-invariant? : Variance -> Boolean + (define (variance-invariant? v) + (and (not (variance-covariant? v)) (not (variance-contravariant? v)))) + ;; variance-join : Variance Variance -> Variance + (define (variance-join v1 v2) + (variance (and (variance-covariant? v1) + (variance-covariant? v2)) + (and (variance-contravariant? v1) + (variance-contravariant? v2)))) + ;; variance-compose : Variance Variance -> Variance + (define (variance-compose v1 v2) + (variance (or (variance-irrelevant? v1) + (variance-irrelevant? v2) + (and (variance-covariant? v1) (variance-covariant? v2)) + (and (variance-contravariant? v1) (variance-contravariant? v2))) + (or (variance-irrelevant? v1) + (variance-irrelevant? v2) + (and (variance-covariant? v1) (variance-contravariant? v2)) + (and (variance-contravariant? v1) (variance-covariant? v2))))) + + ;; add-arg-variances : Id (Listof Variance) -> Id + ;; Takes a type constructor id and adds variance information about the arguments. + (define (add-arg-variances id arg-variances) + (set-stx-prop/preserved id 'arg-variances arg-variances)) + ;; get-arg-variances : Id -> (U False (Listof Variance)) + ;; Takes a type constructor id and gets the argument variance information. + (define (get-arg-variances id) + (syntax-property id 'arg-variances)) + ;; todo: abstract out the common shape of a type constructor, ;; i.e., the repeated pattern code in the functions below (define (get-extra-info t) @@ -482,6 +527,10 @@ #:defaults ([bvs-op #'=][bvs-n #'0])) (~optional (~seq #:arr (~and (~parse has-annotations? #'#t) tycon)) #:defaults ([tycon #'void])) + (~optional (~seq #:arg-variances arg-variances-stx:expr) + #:defaults ([arg-variances-stx + #`(λ (stx-id) (for/list ([arg (in-list (stx->list (stx-cdr stx-id)))]) + invariant))])) (~optional (~seq #:extra-info extra-info) #:defaults ([extra-info #'void])) (~optional (~and #:no-provide (~parse no-provide? #'#t)))) @@ -532,6 +581,7 @@ #:msg "Expected ~a type, got: ~a" #'τ #'any))))]))) + (define arg-variances arg-variances-stx) (define (τ? t) (and (stx-pair? t) (syntax-parse t @@ -565,10 +615,11 @@ #:with k_result (if #,(attribute has-annotations?) #'(tycon k_arg (... ...)) #'#%kind) + #:with τ-internal* (add-arg-variances #'τ-internal (arg-variances stx)) (add-orig (assign-type (syntax/loc stx - (τ-internal (λ bvs- (#%expression extra-info) . τs-))) + (τ-internal* (λ bvs- (#%expression extra-info) . τs-))) #'k_result) #'(τ . args))] ;; else fail with err msg @@ -701,3 +752,48 @@ (define (substs τs xs e [cmp bound-identifier=?]) (stx-fold (lambda (ty x res) (subst ty x res cmp)) e τs xs))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(module+ test + (begin-for-syntax + (test-case "variance-join" + (test-case "joining with irrelevant doesn't change it" + (check-equal? (variance-join irrelevant irrelevant) irrelevant) + (check-equal? (variance-join irrelevant covariant) covariant) + (check-equal? (variance-join irrelevant contravariant) contravariant) + (check-equal? (variance-join irrelevant invariant) invariant)) + (test-case "joining with invariant results in invariant" + (check-equal? (variance-join invariant irrelevant) invariant) + (check-equal? (variance-join invariant covariant) invariant) + (check-equal? (variance-join invariant contravariant) invariant) + (check-equal? (variance-join invariant invariant) invariant)) + (test-case "joining a with a results in a" + (check-equal? (variance-join irrelevant irrelevant) irrelevant) + (check-equal? (variance-join covariant covariant) covariant) + (check-equal? (variance-join contravariant contravariant) contravariant) + (check-equal? (variance-join invariant invariant) invariant)) + (test-case "joining covariant with contravariant results in invariant" + (check-equal? (variance-join covariant contravariant) invariant) + (check-equal? (variance-join contravariant covariant) invariant))) + (test-case "variance-compose" + (test-case "composing with covariant doesn't change it" + (check-equal? (variance-compose covariant irrelevant) irrelevant) + (check-equal? (variance-compose covariant covariant) covariant) + (check-equal? (variance-compose covariant contravariant) contravariant) + (check-equal? (variance-compose covariant invariant) invariant)) + (test-case "composing with irrelevant results in irrelevant" + (check-equal? (variance-compose irrelevant irrelevant) irrelevant) + (check-equal? (variance-compose irrelevant covariant) irrelevant) + (check-equal? (variance-compose irrelevant contravariant) irrelevant) + (check-equal? (variance-compose irrelevant invariant) irrelevant)) + (test-case "otherwise composing with invariant results in invariant" + (check-equal? (variance-compose invariant covariant) invariant) + (check-equal? (variance-compose invariant contravariant) invariant) + (check-equal? (variance-compose invariant invariant) invariant)) + (test-case "composing with with contravariant flips covariant and contravariant" + (check-equal? (variance-compose contravariant covariant) contravariant) + (check-equal? (variance-compose contravariant contravariant) covariant) + (check-equal? (variance-compose contravariant irrelevant) irrelevant) + (check-equal? (variance-compose contravariant invariant) invariant))) + ))