add find-X-variance, covariant-X?, and covariant-Xs?

and also allow type constructors to declare the variance of their arguments.

infer variances for non-recursive `define-type` types
This commit is contained in:
AlexKnauth 2016-05-10 18:58:16 -04:00
parent 4347b2eaff
commit 4af0f4e2b4
4 changed files with 193 additions and 3 deletions

View File

@ -23,6 +23,9 @@
(require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list)))
(require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup)))
(module+ test
(require (for-syntax rackunit)))
(provide →/test match2 define-type)
;; ML-like language
@ -252,6 +255,44 @@
(define (inst-types/cs/orig Xs cs tys [var=? free-identifier=?])
(stx-map (lambda (t) (inst-type/cs/orig Xs cs t var=?)) tys))
;; covariant-Xs? : Type -> Bool
;; Takes a possibly polymorphic type, and returns true if all of the
;; type variables are in covariant positions within the type, false
;; otherwise.
(define (covariant-Xs? ty)
(syntax-parse ((current-type-eval) ty)
[(~?∀ Xs ty)
(for/and ([X (in-list (syntax->list #'Xs))])
(covariant-X? X #'ty))]))
;; find-X-variance : Id Type -> Variance
;; Returns the variance of X within the type ty
(define (find-X-variance X ty)
(syntax-parse ty
[A:id #:when (free-identifier=? #'A X) covariant]
[(~Any tycons) irrelevant]
[(~?∀ () (~Any tycons τ ...))
#:when (get-arg-variances #'tycons)
#:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons))
(for/fold ([acc irrelevant])
([τ (in-list (syntax->list #'[τ ...]))]
[arg-variance (in-list (get-arg-variances #'tycons))])
(variance-join
acc
(variance-compose arg-variance (find-X-variance X τ))))]
[ty #:when (not (stx-contains-id? #'ty X)) irrelevant]
[_ invariant]))
;; covariant-X? : Id Type -> Bool
;; Returns true if every place X appears in ty is a covariant position, false otherwise.
(define (covariant-X? X ty)
(variance-covariant? (find-X-variance X ty)))
;; contravariant-X? : Id Type -> Bool
;; Returns true if every place X appears in ty is a contravariant position, false otherwise.
(define (contravariant-X? X ty)
(variance-contravariant? (find-X-variance X ty)))
;; compute unbound tyvars in one unexpanded type ty
(define (compute-tyvar1 ty)
(syntax-parse ty
@ -386,12 +427,26 @@
#'(StructName ...) #'((fld ...) ...))
#:with (Cons? ...) (stx-map mk-? #'(StructName ...))
#:with (exposed-Cons? ...) (stx-map mk-? #'(Cons ...))
#:do [(define expanded-tys
(for/list ([τ (in-list (syntax->list #'[τ ... ...]))])
(with-handlers ([exn:fail:syntax? (λ (e) #false)])
((current-type-eval) #`( (X ...) #,τ)))))]
#:with [arg-variance ...]
(for/list ([i (in-range (length (syntax->list #'[X ...])))])
(for/fold ([acc irrelevant])
([ty (in-list expanded-tys)])
(cond [ty
(define/syntax-parse (~?∀ Xs τ) ty)
(define X (list-ref (syntax->list #'Xs) i))
(variance-join acc (find-X-variance X #'τ))]
[else invariant])))
#`(begin
(define-syntax (NameExtraInfo stx)
(syntax-parse stx
[(_ X ...) #'(('Cons 'StructName Cons? [acc τ] ...) ...)]))
(define-type-constructor Name
#:arity = #,(stx-length #'(X ...))
#:arg-variances (λ (stx) (list 'arg-variance ...))
#:extra-info 'NameExtraInfo
#:no-provide)
(struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ...
@ -1290,3 +1345,30 @@
(cond [(eof-object? x) ""]
[(number? x) (number->string x)]
[(symbol? x) (symbol->string x)])) : String)])
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(module+ test
(begin-for-syntax
(check-true (covariant-Xs? #'Int))
(check-true (covariant-Xs? #'(stlc+box:Ref Int)))
(check-true (covariant-Xs? #'( Int Int)))
(check-true (covariant-Xs? #'( (X) X)))
(check-false (covariant-Xs? #'( (X) (stlc+box:Ref X))))
(check-false (covariant-Xs? #'( (X) ( X X))))
(check-false (covariant-Xs? #'( (X) ( X Int))))
(check-true (covariant-Xs? #'( (X) ( Int X))))
(check-true (covariant-Xs? #'( (X) ( ( X Int) X))))
(check-false (covariant-Xs? #'( (X) ( ( ( X Int) Int) X))))
(check-false (covariant-Xs? #'( (X) ( (stlc+box:Ref X) Int))))
(check-false (covariant-Xs? #'( (X Y) ( X Y))))
(check-true (covariant-Xs? #'( (X Y) ( ( X Int) Y))))
(check-false (covariant-Xs? #'( (X Y) ( ( X Int) ( Y Int)))))
(check-true (covariant-Xs? #'( (X Y) ( ( X Int) ( Int Y)))))
(check-false (covariant-Xs? #'( (A B) ( ( Int (stlc+rec-iso:× A B))
( String (stlc+rec-iso:× A B))
(stlc+rec-iso:× A B)))))
(check-true (covariant-Xs? #'( (A B) ( ( (stlc+rec-iso:× A B) Int)
( (stlc+rec-iso:× A B) String)
(stlc+rec-iso:× A B)))))
))

View File

@ -1,5 +1,7 @@
#lang s-exp "typecheck.rkt"
(extends "ext-stlc.rkt")
(require (for-syntax racket/list))
;; Simply-Typed Lambda Calculus, plus tuples
;; Types:
@ -9,7 +11,9 @@
;; - terms from ext-stlc.rkt
;; - tup and proj
(define-type-constructor × #:arity >= 0)
(define-type-constructor × #:arity >= 0
#:arg-variances (λ (stx)
(make-list (stx-length (stx-cdr stx)) covariant)))
(define-typed-syntax tup
[(_ e ...)

View File

@ -2,6 +2,8 @@
(provide (for-syntax current-type=? types=?))
(provide (for-syntax mk-app-err-msg))
(require (for-syntax racket/list))
;; Simply-Typed Lambda Calculus
;; - no base types; can't write any terms
;; Types: multi-arg → (1+)
@ -66,7 +68,13 @@
(define-syntax-category type)
(define-type-constructor #:arity >= 1)
(define-type-constructor #:arity >= 1
#:arg-variances (λ (stx)
(syntax-parse stx
[(_ τ_in ... τ_out)
(append
(make-list (stx-length #'[τ_in ...]) contravariant)
(list covariant))])))
(define-typed-syntax λ
[(_ bvs:type-ctx e)

View File

@ -15,6 +15,9 @@
"stx-utils.rkt"))
(for-meta 2 (all-from-out racket/base syntax/parse racket/syntax)))
(module+ test
(require (for-syntax rackunit)))
;; type checking functions/forms
;; General type checking strategy:
@ -420,6 +423,48 @@
(define (brack? stx)
(define paren-shape/#f (syntax-property stx 'paren-shape))
(and paren-shape/#f (char=? paren-shape/#f #\[)))
(define (iff b1 b2)
(boolean=? b1 b2))
;; Variance is (variance Boolean Boolean)
(struct variance (covariant? contravariant?) #:prefab)
(define irrelevant (variance #true #true))
(define covariant (variance #true #false))
(define contravariant (variance #false #true))
(define invariant (variance #false #false))
;; variance-irrelevant? : Variance -> Boolean
(define (variance-irrelevant? v)
(and (variance-covariant? v) (variance-contravariant? v)))
;; variance-invariant? : Variance -> Boolean
(define (variance-invariant? v)
(and (not (variance-covariant? v)) (not (variance-contravariant? v))))
;; variance-join : Variance Variance -> Variance
(define (variance-join v1 v2)
(variance (and (variance-covariant? v1)
(variance-covariant? v2))
(and (variance-contravariant? v1)
(variance-contravariant? v2))))
;; variance-compose : Variance Variance -> Variance
(define (variance-compose v1 v2)
(variance (or (variance-irrelevant? v1)
(variance-irrelevant? v2)
(and (variance-covariant? v1) (variance-covariant? v2))
(and (variance-contravariant? v1) (variance-contravariant? v2)))
(or (variance-irrelevant? v1)
(variance-irrelevant? v2)
(and (variance-covariant? v1) (variance-contravariant? v2))
(and (variance-contravariant? v1) (variance-covariant? v2)))))
;; add-arg-variances : Id (Listof Variance) -> Id
;; Takes a type constructor id and adds variance information about the arguments.
(define (add-arg-variances id arg-variances)
(set-stx-prop/preserved id 'arg-variances arg-variances))
;; get-arg-variances : Id -> (U False (Listof Variance))
;; Takes a type constructor id and gets the argument variance information.
(define (get-arg-variances id)
(syntax-property id 'arg-variances))
;; todo: abstract out the common shape of a type constructor,
;; i.e., the repeated pattern code in the functions below
(define (get-extra-info t)
@ -482,6 +527,10 @@
#:defaults ([bvs-op #'=][bvs-n #'0]))
(~optional (~seq #:arr (~and (~parse has-annotations? #'#t) tycon))
#:defaults ([tycon #'void]))
(~optional (~seq #:arg-variances arg-variances-stx:expr)
#:defaults ([arg-variances-stx
#`(λ (stx-id) (for/list ([arg (in-list (stx->list (stx-cdr stx-id)))])
invariant))]))
(~optional (~seq #:extra-info extra-info)
#:defaults ([extra-info #'void]))
(~optional (~and #:no-provide (~parse no-provide? #'#t))))
@ -532,6 +581,7 @@
#:msg
"Expected ~a type, got: ~a"
#'τ #'any))))])))
(define arg-variances arg-variances-stx)
(define (τ? t)
(and (stx-pair? t)
(syntax-parse t
@ -565,10 +615,11 @@
#:with k_result (if #,(attribute has-annotations?)
#'(tycon k_arg (... ...))
#'#%kind)
#:with τ-internal* (add-arg-variances #'τ-internal (arg-variances stx))
(add-orig
(assign-type
(syntax/loc stx
(τ-internal (λ bvs- (#%expression extra-info) . τs-)))
(τ-internal* (λ bvs- (#%expression extra-info) . τs-)))
#'k_result)
#'(τ . args))]
;; else fail with err msg
@ -701,3 +752,48 @@
(define (substs τs xs e [cmp bound-identifier=?])
(stx-fold (lambda (ty x res) (subst ty x res cmp)) e τs xs)))
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(module+ test
(begin-for-syntax
(test-case "variance-join"
(test-case "joining with irrelevant doesn't change it"
(check-equal? (variance-join irrelevant irrelevant) irrelevant)
(check-equal? (variance-join irrelevant covariant) covariant)
(check-equal? (variance-join irrelevant contravariant) contravariant)
(check-equal? (variance-join irrelevant invariant) invariant))
(test-case "joining with invariant results in invariant"
(check-equal? (variance-join invariant irrelevant) invariant)
(check-equal? (variance-join invariant covariant) invariant)
(check-equal? (variance-join invariant contravariant) invariant)
(check-equal? (variance-join invariant invariant) invariant))
(test-case "joining a with a results in a"
(check-equal? (variance-join irrelevant irrelevant) irrelevant)
(check-equal? (variance-join covariant covariant) covariant)
(check-equal? (variance-join contravariant contravariant) contravariant)
(check-equal? (variance-join invariant invariant) invariant))
(test-case "joining covariant with contravariant results in invariant"
(check-equal? (variance-join covariant contravariant) invariant)
(check-equal? (variance-join contravariant covariant) invariant)))
(test-case "variance-compose"
(test-case "composing with covariant doesn't change it"
(check-equal? (variance-compose covariant irrelevant) irrelevant)
(check-equal? (variance-compose covariant covariant) covariant)
(check-equal? (variance-compose covariant contravariant) contravariant)
(check-equal? (variance-compose covariant invariant) invariant))
(test-case "composing with irrelevant results in irrelevant"
(check-equal? (variance-compose irrelevant irrelevant) irrelevant)
(check-equal? (variance-compose irrelevant covariant) irrelevant)
(check-equal? (variance-compose irrelevant contravariant) irrelevant)
(check-equal? (variance-compose irrelevant invariant) irrelevant))
(test-case "otherwise composing with invariant results in invariant"
(check-equal? (variance-compose invariant covariant) invariant)
(check-equal? (variance-compose invariant contravariant) invariant)
(check-equal? (variance-compose invariant invariant) invariant))
(test-case "composing with with contravariant flips covariant and contravariant"
(check-equal? (variance-compose contravariant covariant) contravariant)
(check-equal? (variance-compose contravariant contravariant) covariant)
(check-equal? (variance-compose contravariant irrelevant) irrelevant)
(check-equal? (variance-compose contravariant invariant) invariant)))
))