diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 8056881..e5584eb 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -6,7 +6,7 @@ ;(reuse [inst sysf:inst] #:from "sysf.rkt") (require (rename-in (only-in "sysf.rkt" inst) [inst sysf:inst])) (provide inst) -(require (only-in "ext-stlc.rkt" →?)) +(require (only-in "ext-stlc.rkt" → →?)) (require (only-in "sysf.rkt" ~∀ ∀ ∀? Λ)) (reuse × tup proj define-type-alias #:from "stlc+rec-iso.rkt") (require (only-in "stlc+rec-iso.rkt" ~× ×?)) ; using current-type=? from here @@ -23,6 +23,9 @@ (require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list))) (require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup))) +(module+ test + (require (for-syntax rackunit))) + (provide → →/test match2 define-type) ;; ML-like language @@ -31,6 +34,25 @@ ;; - pattern matching ;; - (local) type inference +;; creating possibly polymorphic types +;; ?∀ only wraps a type in a forall if there's at least one type variable +(define-syntax ?∀ + (lambda (stx) + (syntax-case stx () + [(?∀ () body) + #'body] + [(?∀ (X ...) body) + #'(∀ (X ...) body)]))) + +;; ?Λ only wraps an expression in a Λ if there's at least one type variable +(define-syntax ?Λ + (lambda (stx) + (syntax-case stx () + [(?Λ () body) + #'body] + [(?Λ (X ...) body) + #'(Λ (X ...) body)]))) + (begin-for-syntax ;; matching possibly polymorphic types (define-syntax ~?∀ @@ -43,26 +65,91 @@ (~parse vars-pat #'()) body-pat))])))) - ;; type inference constraint solving - (define (compute-constraint τ1-τ2) - (syntax-parse τ1-τ2 - [(X:id τ) #'((X τ))] - [((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...)) - #:when (typecheck? #'tycons1 #'tycons2) - (compute-constraints #'((τ1 τ2) ...))] - ; should only be monomorphic? - [((~∀ () (~ext-stlc:→ τ1 ...)) (~∀ () (~ext-stlc:→ τ2 ...))) - (compute-constraints #'((τ1 τ2) ...))] - [_ #'()])) - (define (compute-constraints τs) - (stx-appendmap compute-constraint τs)) - - (define (solve-constraint x-τ) - (syntax-parse x-τ - [(X:id τ) #'((X τ))] - [_ #'()])) - (define (solve-constraints cs) - (stx-appendmap compute-constraint cs)) + ;; add-constraints : + ;; (Listof Id) (Listof (List Id Type)) (Stx-Listof (Stx-List Stx Stx)) -> (Listof (List Id Type)) + ;; Adds a new set of constaints to a substituion, using the type + ;; unification algorithm for local type inference. + (define (add-constraints Xs substs new-cs [orig-cs new-cs]) + (define Xs* (stx->list Xs)) + (define Ys (stx-map stx-car substs)) + (define-syntax-class var + [pattern x:id #:when (member #'x Xs* free-identifier=?)]) + (syntax-parse new-cs + [() substs] + [([a:var b] . rst) + (cond + [(member #'a Ys free-identifier=?) + ;; There are two cases. + ;; Either #'a already maps to #'b or an equivalent type, + ;; or #'a already maps to a type that conflicts with #'b. + ;; In either case, whatever #'a maps to must be equivalent + ;; to #'b, so add that to the constraints. + (add-constraints + Xs + substs + (cons (list (lookup #'a substs) #'b) + #'rst) + orig-cs)] + [else + (add-constraints + Xs* + ;; Add the mapping #'a -> #'b to the substitution, + (cons (list #'a #'b) + (for/list ([subst (in-list (stx->list substs))]) + (list (stx-car subst) + (inst-type (list #'b) (list #'a) (stx-cadr subst))))) + ;; and substitute that in each of the constraints. + (for/list ([c (in-list (syntax->list #'rst))]) + (list (inst-type (list #'b) (list #'a) (stx-car c)) + (inst-type (list #'b) (list #'a) (stx-cadr c)))) + orig-cs)])] + [([a b:var] . rst) + (add-constraints Xs* + substs + #'([b a] . rst) + orig-cs)] + [([a b] . rst) + ;; If #'a and #'b are base types, check that they're equal. + ;; Identifers not within Xs count as base types. + ;; If #'a and #'b are constructed types, check that the + ;; constructors are the same, add the sub-constraints, and + ;; recur. + ;; Otherwise, raise an error. + (cond + [(identifier? #'a) + ;; #'a is an identifier, but not a var, so it is considered + ;; a base type. We also know #'b is not a var, so #'b has + ;; to be the same "identifier base type" as #'a. + (unless (and (identifier? #'b) (free-identifier=? #'a #'b)) + (type-error #:src (get-orig #'a) + #:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a" + (string-join (map type->str (stx-map stx-car orig-cs)) ", ") + (string-join (map type->str (stx-map stx-cadr orig-cs)) ", ")) + #'a #'b)) + (add-constraints Xs* + substs + #'rst + orig-cs)] + [else + (syntax-parse #'[a b] + [_ + #:when (typecheck? #'a #'b) + (add-constraints Xs + substs + #'rst + orig-cs)] + [((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...)) + #:when (typecheck? #'tycons1 #'tycons2) + (add-constraints Xs + substs + #'((τ1 τ2) ... . rst) + orig-cs)] + [else + (type-error #:src (get-orig #'a) + #:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a" + (string-join (map type->str (stx-map stx-car orig-cs)) ", ") + (string-join (map type->str (stx-map stx-cadr orig-cs)) ", ")) + #'a #'b)])])])) (define (lookup x substs) (syntax-parse substs @@ -72,11 +159,11 @@ [(_ . rst) (lookup x #'rst)] [() #f])) - ;; find-unsolved-Xs : (Stx-Listof Id) Constraints -> (Listof Id) - ;; finds the free Xs that aren't constrained by cs - (define (find-unsolved-Xs Xs cs) + ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id) + ;; finds the free Xs in the type + (define (find-free-Xs Xs ty) (for/list ([X (in-list (stx->list Xs))] - #:when (not (lookup X cs))) + #:when (stx-contains-id? ty X)) X)) ;; lookup-Xs/keep-unsolved : (Stx-Listof Id) Constraints -> (Listof Type-Stx) @@ -90,33 +177,43 @@ ;; tyXs = input and output types from fn type ;; ie (typeof e_fn) = (-> . tyXs) ;; It infers the types of arguments from left-to-right, - ;; and it short circuits if it's done early. + ;; and it expands and returns all of the arguments. ;; It returns list of 3 values if successful, else throws a type error - ;; - a list of the arguments that it expanded - ;; - a list of the the un-constrained type variables + ;; - a list of all the arguments, expanded + ;; - a list of all the type variables ;; - the constraints for substituting the types (define (solve Xs tyXs stx) (syntax-parse tyXs [(τ_inX ... τ_outX) ;; generate initial constraints with expected type and τ_outX - #:with expected-ty (get-expected-type stx) + #:with (~?∀ Vs expected-ty) (and (get-expected-type stx) + ((current-type-eval) (get-expected-type stx))) (define initial-cs - (if (syntax-e #'expected-ty) - (compute-constraint (list #'τ_outX ((current-type-eval) #'expected-ty))) + (if (and (syntax-e #'expected-ty) (stx-null? #'Vs)) + (add-constraints Xs '() (list (list #'expected-ty #'τ_outX))) #'())) (syntax-parse stx [(_ e_fn . args) (define-values (as- cs) (for/fold ([as- null] [cs initial-cs]) ([a (in-list (syntax->list #'args))] - [tyXin (in-list (syntax->list #'(τ_inX ...)))] - #:break (empty? (find-unsolved-Xs Xs cs))) - (define/with-syntax [a- ty_a] (infer+erase a)) + [tyXin (in-list (syntax->list #'(τ_inX ...)))]) + (define ty_in (inst-type/cs Xs cs tyXin)) + (define/with-syntax [a- ty_a] + (infer+erase (if (empty? (find-free-Xs Xs ty_in)) + (add-expected-ty a ty_in) + a))) (values (cons #'a- as-) - (stx-append cs (compute-constraint (list tyXin #'ty_a)))))) + (add-constraints Xs cs (list (list ty_in #'ty_a)) + (list (list (inst-type/cs/orig + Xs cs ty_in + (λ (id1 id2) + (equal? (syntax->datum id1) + (syntax->datum id2)))) + #'ty_a)))))) - (list (reverse as-) (find-unsolved-Xs Xs cs) cs)])])) + (list (reverse as-) Xs cs)])])) (define (raise-app-poly-infer-error stx expected-tys given-tys e_fn) (type-error #:src stx @@ -130,6 +227,11 @@ ;; identifier in Xs is associated with the ith type in tys-solved (define (inst-type tys-solved Xs ty) (substs tys-solved Xs ty)) + ;; inst-type/orig : (Listof Type) (Listof Id) Type (Id Id -> Bool) -> Type + ;; like inst-type, but also substitutes within the orig property + (define (inst-type/orig tys-solved Xs ty [var=? free-identifier=?]) + (add-orig (inst-type tys-solved Xs ty) + (substs (stx-map get-orig tys-solved) Xs (get-orig ty) var=?))) ;; inst-type/cs : (Stx-Listof Id) Constraints Type-Stx -> Type-Stx ;; Instantiates ty, substituting each identifier in Xs with its mapping in cs. @@ -141,6 +243,56 @@ (define (inst-types/cs Xs cs tys) (stx-map (lambda (t) (inst-type/cs Xs cs t)) tys)) + ;; inst-type/cs/orig : + ;; (Stx-Listof Id) Constraints Type-Stx (Id Id -> Bool) -> Type-Stx + ;; like inst-type/cs, but also substitutes within the orig property + (define (inst-type/cs/orig Xs cs ty [var=? free-identifier=?]) + (define tys-solved (lookup-Xs/keep-unsolved Xs cs)) + (inst-type/orig tys-solved Xs ty var=?)) + ;; inst-types/cs/orig : + ;; (Stx-Listof Id) Constraints (Stx-Listof Type-Stx) (Id Id -> Bool) -> (Listof Type-Stx) + ;; the plural version of inst-type/cs/orig + (define (inst-types/cs/orig Xs cs tys [var=? free-identifier=?]) + (stx-map (lambda (t) (inst-type/cs/orig Xs cs t var=?)) tys)) + + ;; covariant-Xs? : Type -> Bool + ;; Takes a possibly polymorphic type, and returns true if all of the + ;; type variables are in covariant positions within the type, false + ;; otherwise. + (define (covariant-Xs? ty) + (syntax-parse ((current-type-eval) ty) + [(~?∀ Xs ty) + (for/and ([X (in-list (syntax->list #'Xs))]) + (covariant-X? X #'ty))])) + + ;; find-X-variance : Id Type -> Variance + ;; Returns the variance of X within the type ty + (define (find-X-variance X ty) + (syntax-parse ty + [A:id #:when (free-identifier=? #'A X) covariant] + [(~Any tycons) irrelevant] + [(~?∀ () (~Any tycons τ ...)) + #:when (get-arg-variances #'tycons) + #:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons)) + (for/fold ([acc irrelevant]) + ([τ (in-list (syntax->list #'[τ ...]))] + [arg-variance (in-list (get-arg-variances #'tycons))]) + (variance-join + acc + (variance-compose arg-variance (find-X-variance X τ))))] + [ty #:when (not (stx-contains-id? #'ty X)) irrelevant] + [_ invariant])) + + ;; covariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a covariant position, false otherwise. + (define (covariant-X? X ty) + (variance-covariant? (find-X-variance X ty))) + + ;; contravariant-X? : Id Type -> Bool + ;; Returns true if every place X appears in ty is a contravariant position, false otherwise. + (define (contravariant-X? X ty) + (variance-contravariant? (find-X-variance X ty))) + ;; compute unbound tyvars in one unexpanded type ty (define (compute-tyvar1 ty) (syntax-parse ty @@ -182,8 +334,8 @@ ;; TODO: check that specified return type is correct ;; - currently cannot do it here; to do the check here, need all types of ;; top-lvl fns, since they can call each other - #:with (~and ty_fn_expected (~∀ _ (~ext-stlc:→ _ ... out_expected))) - ((current-type-eval) #'(∀ Ys (ext-stlc:→ τ+orig ...))) + #:with (~and ty_fn_expected (~?∀ _ (~ext-stlc:→ _ ... out_expected))) + ((current-type-eval) #'(?∀ Ys (ext-stlc:→ τ+orig ...))) #`(begin (define-syntax f (make-rename-transformer (⊢ g : ty_fn_expected))) (define g @@ -200,15 +352,15 @@ ;; TODO: check that specified return type is correct ;; - currently cannot do it here; to do the check here, need all types of ;; top-lvl fns, since they can call each other - #:with (~and ty_fn_expected (~∀ _ (~ext-stlc:→ _ ... out_expected))) + #:with (~and ty_fn_expected (~?∀ _ (~ext-stlc:→ _ ... out_expected))) (set-stx-prop/preserved - ((current-type-eval) #'(∀ Ys (ext-stlc:→ τ+orig ...))) + ((current-type-eval) #'(?∀ Ys (ext-stlc:→ τ+orig ...))) 'orig (list #'(→ τ+orig ...))) #`(begin (define-syntax f (make-rename-transformer (⊢ g : ty_fn_expected))) (define g - (Λ Ys (ext-stlc:λ ([x : τ] ...) (ext-stlc:begin e_body ... e_ann)))))]) + (?Λ Ys (ext-stlc:λ ([x : τ] ...) (ext-stlc:begin e_body ... e_ann)))))]) ;; define-type ----------------------------------------------- ;; TODO: should validate τ as part of define-type definition (before it's used) @@ -275,30 +427,44 @@ #'(StructName ...) #'((fld ...) ...)) #:with (Cons? ...) (stx-map mk-? #'(StructName ...)) #:with (exposed-Cons? ...) (stx-map mk-? #'(Cons ...)) + #:do [(define expanded-tys + (for/list ([τ (in-list (syntax->list #'[τ ... ...]))]) + (with-handlers ([exn:fail:syntax? (λ (e) #false)]) + ((current-type-eval) #`(∀ (X ...) #,τ)))))] + #:with [arg-variance ...] + (for/list ([i (in-range (length (syntax->list #'[X ...])))]) + (for/fold ([acc irrelevant]) + ([ty (in-list expanded-tys)]) + (cond [ty + (define/syntax-parse (~?∀ Xs τ) ty) + (define X (list-ref (syntax->list #'Xs) i)) + (variance-join acc (find-X-variance X #'τ))] + [else invariant]))) #`(begin (define-syntax (NameExtraInfo stx) (syntax-parse stx [(_ X ...) #'(('Cons 'StructName Cons? [acc τ] ...) ...)])) (define-type-constructor Name #:arity = #,(stx-length #'(X ...)) + #:arg-variances (λ (stx) (list 'arg-variance ...)) #:extra-info 'NameExtraInfo #:no-provide) (struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ... (define-syntax (exposed-acc stx) ; accessor for records (syntax-parse stx - [_:id (⊢ acc (∀ (X ...) (ext-stlc:→ (Name X ...) τ)))] + [_:id (⊢ acc (?∀ (X ...) (ext-stlc:→ (Name X ...) τ)))] [(o . rst) ; handle if used in fn position #:with app (datum->syntax #'o '#%app) #`(app - #,(assign-type #'acc #'(∀ (X ...) (ext-stlc:→ (Name X ...) τ))) + #,(assign-type #'acc #'(?∀ (X ...) (ext-stlc:→ (Name X ...) τ))) . rst)])) ... ... (define-syntax (exposed-Cons? stx) ; predicates for each variant (syntax-parse stx - [_:id (⊢ Cons? (∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))] + [_:id (⊢ Cons? (?∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))] [(o . rst) ; handle if used in fn position #:with app (datum->syntax #'o '#%app) #`(app - #,(assign-type #'Cons? #'(∀ (X ...) (ext-stlc:→ (Name X ...) Bool))) + #,(assign-type #'Cons? #'(?∀ (X ...) (ext-stlc:→ (Name X ...) Bool))) . rst)])) ... (define-syntax (Cons stx) (syntax-parse stx @@ -319,7 +485,7 @@ (current-continuation-marks))) #:with (NameExpander τ-expected-arg (... ...)) ((current-type-eval) #'τ-expected) #'(C {τ-expected-arg (... ...)})] - [_:id (⊢ StructName (∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))] ; HO fn + [_:id (⊢ StructName (?∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))] ; HO fn [(C τs e_arg ...) #:when (brace? #'τs) ; commit to this clause #:with {~! τ_X:type (... ...)} #'τs @@ -340,7 +506,7 @@ [(C . args) ; no type annotations, must infer instantiation #:with StructName/ty (set-stx-prop/preserved - (⊢ StructName : (∀ (X ...) (ext-stlc:→ τ ... (Name X ...)))) + (⊢ StructName : (?∀ (X ...) (ext-stlc:→ τ ... (Name X ...)))) 'orig (list #'C)) ; stx/loc transfers expected-type @@ -651,19 +817,16 @@ (let ([x- (acc z)] ...) e_c-)] ...)) : τ_out)])])]) -(define-syntax → ; wrapping → - (syntax-parser - [(_ . rst) (set-stx-prop/preserved #'(∀ () (ext-stlc:→ . rst)) 'orig (list #'(→ . rst)))])) ; special arrow that computes free vars; for use with tests ; (because we can't write explicit forall (define-syntax →/test (syntax-parser [(_ (~and Xs (X:id ...)) . rst) #:when (brace? #'Xs) - #'(∀ (X ...) (ext-stlc:→ . rst))] + #'(?∀ (X ...) (ext-stlc:→ . rst))] [(_ . rst) #:with Xs (compute-tyvars #'rst) - #'(∀ Xs (ext-stlc:→ . rst))])) + #'(?∀ Xs (ext-stlc:→ . rst))])) ; redefine these to use lifted → (define-primop + : (→ Int Int Int)) @@ -685,7 +848,7 @@ (define-primop even? : (→ Int Bool)) (define-primop odd? : (→ Int Bool)) -; all λs have type (∀ (X ...) (→ τ_in ... τ_out)), even monomorphic fns +; all λs have type (?∀ (X ...) (→ τ_in ... τ_out)) (define-typed-syntax liftedλ #:export-as λ [(_ (x:id ...+) body) #:with (~?∀ Xs expected) (get-expected-type stx) @@ -696,21 +859,21 @@ (type-error #:src stx #:msg (format "expected a function of ~a arguments, got one with ~a arguments" (stx-length #'[arg-ty ...] #'[x ...]))))] - #`(Λ Xs (ext-stlc:λ ([x : arg-ty] ...) #,(add-expected-ty #'body #'body-ty)))] + #`(?Λ Xs (ext-stlc:λ ([x : arg-ty] ...) #,(add-expected-ty #'body #'body-ty)))] [(_ args body) #:with (~?∀ () (~ext-stlc:→ arg-ty ... body-ty)) (get-expected-type stx) - #`(Λ () (ext-stlc:λ args #,(add-expected-ty #'body #'body-ty)))] + #`(?Λ () (ext-stlc:λ args #,(add-expected-ty #'body #'body-ty)))] [(_ (~and x+tys ([_ (~datum :) ty] ...)) . body) #:with Xs (compute-tyvars #'(ty ...)) ;; TODO is there a way to have λs that refer to ids defined after them? - #'(Λ Xs (ext-stlc:λ x+tys . body))]) + #'(?Λ Xs (ext-stlc:λ x+tys . body))]) ;; #%app -------------------------------------------------- (define-typed-syntax mlish:#%app #:export-as #%app [(_ e_fn . e_args) ;; ) compute fn type (ie ∀ and →) - #:with [e_fn- (~∀ Xs (~ext-stlc:→ . tyX_args))] (infer+erase #'e_fn) + #:with [e_fn- (~?∀ Xs (~ext-stlc:→ . tyX_args))] (infer+erase #'e_fn) (cond [(stx-null? #'Xs) (syntax-parse #'(e_args tyX_args) @@ -722,22 +885,17 @@ #'(ext-stlc:#%app e_fn/ty (add-expected e_arg τ_inX) ...)])] [else ;; ) solve for type variables Xs - (define/with-syntax ((e_arg1- ...) (unsolved-X ...) cs) (solve #'Xs #'tyX_args stx)) + (define/with-syntax ((e_arg- ...) Xs* cs) (solve #'Xs #'tyX_args stx)) ;; ) instantiate polymorphic function type - (syntax-parse (inst-types/cs #'Xs #'cs #'tyX_args) + (syntax-parse (inst-types/cs #'Xs* #'cs #'tyX_args) [(τ_in ... τ_out) ; concrete types + #:with (unsolved-X ...) (find-free-Xs #'Xs* #'τ_out) ;; ) arity check #:fail-unless (stx-length=? #'(τ_in ...) #'e_args) (mk-app-err-msg stx #:expected #'(τ_in ...) #:note "Wrong number of arguments.") - ;; ) compute argument types; re-use args expanded during solve - #:with ([e_arg2- τ_arg2] ...) (let ([n (stx-length #'(e_arg1- ...))]) - (infers+erase - (stx-map add-expected-ty - (stx-drop #'e_args n) (stx-drop #'(τ_in ...) n)))) - #:with (τ_arg1 ...) (stx-map typeof #'(e_arg1- ...)) - #:with (τ_arg ...) #'(τ_arg1 ... τ_arg2 ...) - #:with (e_arg- ...) #'(e_arg1- ... e_arg2- ...) + ;; ) compute argument types + #:with (τ_arg ...) (stx-map typeof #'(e_arg- ...)) ;; ) typecheck args #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...)) (mk-app-err-msg stx @@ -749,14 +907,23 @@ (define new-orig (and old-orig (substs - (stx-map get-orig (lookup-Xs/keep-unsolved #'Xs #'cs)) #'Xs old-orig + (stx-map get-orig (lookup-Xs/keep-unsolved #'Xs* #'cs)) + #'Xs* + old-orig (lambda (x y) (equal? (syntax->datum x) (syntax->datum y)))))) (set-stx-prop/preserved tyin 'orig (list new-orig))) #'(τ_in ...))) #:with τ_out* (if (stx-null? #'(unsolved-X ...)) #'τ_out - (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn)) + (syntax-parse #'τ_out + [(~?∀ (Y ...) τ_out) + (unless (→? #'τ_out) + (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn)) + (for ([X (in-list (syntax->list #'(unsolved-X ...)))]) + (unless (covariant-X? X #'τ_out) + (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn))) + #'(∀ (unsolved-X ... Y ...) τ_out)])) (⊢ (#%app e_fn- e_arg- ...) : τ_out*)])])] [(_ e_fn . e_args) ; err case; e_fn is not a function #:with [e_fn- τ_fn] (infer+erase #'e_fn) @@ -814,7 +981,7 @@ ;; threads (define-typed-syntax thread [(_ th) - #:with (th- (~∀ () (~ext-stlc:→ τ_out))) (infer+erase #'th) + #:with (th- (~?∀ () (~ext-stlc:→ τ_out))) (infer+erase #'th) (⊢ (thread th-) : Thread)]) (define-primop random : (→ Int Int)) @@ -1177,10 +1344,7 @@ [(_ e ty ...) #:with [ee tyty] (infer+erase #'e) #:with [e- ty_e] (infer+erase #'(sysf:inst e ty ...)) - #:with ty_out (if (→? #'ty_e) - #'(∀ () ty_e) - #'ty_e) - (⊢ e- : ty_out)])) + (⊢ e- : ty_e)])) (define-typed-syntax read [(_) @@ -1188,3 +1352,30 @@ (cond [(eof-object? x) ""] [(number? x) (number->string x)] [(symbol? x) (symbol->string x)])) : String)]) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(module+ test + (begin-for-syntax + (check-true (covariant-Xs? #'Int)) + (check-true (covariant-Xs? #'(stlc+box:Ref Int))) + (check-true (covariant-Xs? #'(→ Int Int))) + (check-true (covariant-Xs? #'(∀ (X) X))) + (check-false (covariant-Xs? #'(∀ (X) (stlc+box:Ref X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ X X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ X Int)))) + (check-true (covariant-Xs? #'(∀ (X) (→ Int X)))) + (check-true (covariant-Xs? #'(∀ (X) (→ (→ X Int) X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ (→ (→ X Int) Int) X)))) + (check-false (covariant-Xs? #'(∀ (X) (→ (stlc+box:Ref X) Int)))) + (check-false (covariant-Xs? #'(∀ (X Y) (→ X Y)))) + (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) Y)))) + (check-false (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Y Int))))) + (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Int Y))))) + (check-false (covariant-Xs? #'(∀ (A B) (→ (→ Int (stlc+rec-iso:× A B)) + (→ String (stlc+rec-iso:× A B)) + (stlc+rec-iso:× A B))))) + (check-true (covariant-Xs? #'(∀ (A B) (→ (→ (stlc+rec-iso:× A B) Int) + (→ (stlc+rec-iso:× A B) String) + (stlc+rec-iso:× A B))))) + )) diff --git a/tapl/stlc+tup.rkt b/tapl/stlc+tup.rkt index 03e3726..9d08d13 100644 --- a/tapl/stlc+tup.rkt +++ b/tapl/stlc+tup.rkt @@ -1,5 +1,7 @@ #lang s-exp "typecheck.rkt" (extends "ext-stlc.rkt") + +(require (for-syntax racket/list)) ;; Simply-Typed Lambda Calculus, plus tuples ;; Types: @@ -9,7 +11,9 @@ ;; - terms from ext-stlc.rkt ;; - tup and proj -(define-type-constructor × #:arity >= 0) +(define-type-constructor × #:arity >= 0 + #:arg-variances (λ (stx) + (make-list (stx-length (stx-cdr stx)) covariant))) (define-typed-syntax tup [(_ e ...) diff --git a/tapl/stlc.rkt b/tapl/stlc.rkt index e5c104d..22b8145 100644 --- a/tapl/stlc.rkt +++ b/tapl/stlc.rkt @@ -2,6 +2,8 @@ (provide (for-syntax current-type=? types=?)) (provide (for-syntax mk-app-err-msg)) +(require (for-syntax racket/list)) + ;; Simply-Typed Lambda Calculus ;; - no base types; can't write any terms ;; Types: multi-arg → (1+) @@ -66,7 +68,13 @@ (define-syntax-category type) -(define-type-constructor → #:arity >= 1) +(define-type-constructor → #:arity >= 1 + #:arg-variances (λ (stx) + (syntax-parse stx + [(_ τ_in ... τ_out) + (append + (make-list (stx-length #'[τ_in ...]) contravariant) + (list covariant))]))) (define-typed-syntax λ [(_ bvs:type-ctx e) diff --git a/tapl/stx-utils.rkt b/tapl/stx-utils.rkt index fac72f3..694c731 100644 --- a/tapl/stx-utils.rkt +++ b/tapl/stx-utils.rkt @@ -1,5 +1,5 @@ #lang racket/base -(require syntax/stx racket/list version/utils) +(require syntax/stx syntax/parse racket/list version/utils) (provide (all-defined-out)) (define (stx-cadr stx) (stx-car (stx-cdr stx))) @@ -70,6 +70,9 @@ (define (generate-temporariesss stx) (stx-map generate-temporariess stx)) +;; set-stx-prop/preserved : Stx Any Any -> Stx +;; Returns a new syntax object with the prop property set to val. If preserved +;; syntax properties are supported, this also marks the property as preserved. (define REQUIRED-VERSION "6.5.0.4") (define VERSION (version)) (define PRESERVED-STX-PROP-SUPPORTED? (version<=? REQUIRED-VERSION VERSION)) @@ -78,6 +81,16 @@ (syntax-property stx prop val #t) (syntax-property stx prop val))) +;; stx-contains-id? : Stx Id -> Boolean +;; Returns true if stx contains the identifier x, false otherwise. +(define (stx-contains-id? stx x) + (syntax-parse stx + [a:id (free-identifier=? #'a x)] + [(a . b) + (or (stx-contains-id? #'a x) + (stx-contains-id? #'b x))] + [_ #false])) + ;; based on make-variable-like-transformer from syntax/transformer, ;; but using (#%app id ...) instead of ((#%expression id) ...) (define (make-variable-like-transformer ref-stx) diff --git a/tapl/tests/mlish-tests.rkt b/tapl/tests/mlish-tests.rkt index a818110..f2ed1bc 100644 --- a/tapl/tests/mlish-tests.rkt +++ b/tapl/tests/mlish-tests.rkt @@ -36,7 +36,7 @@ ;; type err (typecheck-fail (Cons 1 1) - #:with-msg (expected "Int, (List Int)" #:given "Int, Int")) + #:with-msg "expected: \\(List Int\\)\n *given: Int") ;; check Nil still available as tyvar (define (f11 [x : Nil] -> Nil) x) @@ -55,7 +55,7 @@ (check-type g2 : (→/test (List Y) (List Y))) (typecheck-fail (g2 1) #:with-msg - (expected "(List Y)" #:given "Int")) + "expected: \\(List Y\\)\n *given: Int") ;; todo? allow polymorphic nil? (check-type (g2 (Nil {Int})) : (List Int) ⇒ (Nil {Int})) @@ -113,7 +113,7 @@ (check-type (map add1 (Cons 1 (Cons 2 (Cons 3 Nil)))) : (List Int) ⇒ (Cons 2 (Cons 3 (Cons 4 Nil)))) (typecheck-fail (map add1 (Cons "1" Nil)) - #:with-msg (expected "Int, (List Int)" #:given "String, (List Int)")) + #:with-msg "expected: Int\n *given: String") (check-type (map (λ ([x : Int]) (+ x 2)) (Cons 1 (Cons 2 (Cons 3 Nil)))) : (List Int) ⇒ (Cons 3 (Cons 4 (Cons 5 Nil)))) ;; ; doesnt work yet: all lambdas need annotations @@ -179,6 +179,24 @@ (check-type (build-list 5 (λ (x) (add1 (add1 x)))) : (List Int) ⇒ (Cons 6 (Cons 5 (Cons 4 (Cons 3 (Cons 2 Nil)))))) +(define (build-list/comp [i : Int] [n : Int] [nf : (→ Int Int)] [f : (→ Int X)] → (List X)) + (if (= i n) + Nil + (Cons (f (nf i)) (build-list/comp (add1 i) n nf f)))) + +(define built-list-1 (build-list/comp 0 3 (λ (x) (* 2 x)) add1)) +(define built-list-2 (build-list/comp 0 3 (λ (x) (* 2 x)) number->string)) +(check-type built-list-1 : (List Int) -> (Cons 1 (Cons 3 (Cons 5 Nil)))) +(check-type built-list-2 : (List String) -> (Cons "0" (Cons "2" (Cons "4" Nil)))) + +(define (~>2 [a : A] [f : (→ A A)] [g : (→ A B)] → B) + (g (f a))) + +(define ~>2-result-1 (~>2 1 (λ (x) (* 2 x)) add1)) +(define ~>2-result-2 (~>2 1 (λ (x) (* 2 x)) number->string)) +(check-type ~>2-result-1 : Int -> 3) +(check-type ~>2-result-2 : String -> "2") + (define (append [lst1 : (List X)] [lst2 : (List X)] → (List X)) (match lst1 with [Nil -> lst2] @@ -242,8 +260,7 @@ (typecheck-fail Nil #:with-msg "add annotations") (typecheck-fail (Cons 1 (Nil {Bool})) #:with-msg - (expected "Int, (List Int)" #:given "Int, (List Bool)" - #:note "Type error applying.*Cons")) + "expected: \\(List Int\\)\n *given: \\(List Bool\\)") (typecheck-fail (Cons {Bool} 1 (Nil {Int})) #:with-msg (expected "Bool, (List Bool)" #:given "Int, (List Int)" @@ -285,6 +302,8 @@ (None) (Some A)) +(define (None* → (Option A)) None) + (check-type (match (tup 1 2) with [a b -> None]) : (Option Int) -> None) (check-type (match (list 1 2) with @@ -380,6 +399,52 @@ (check-type ((inst nn2 Int (List Int) String) 1) : (→ (× Int (→ (List Int) (List Int)) (List String)))) +(define (nn3 [x : X] -> (→ (× X (Option Y) (Option Z)))) + (λ () (tup x None None))) +(check-type (nn3 1) : (→/test (× Int (Option Y) (Option Z)))) +(check-type (nn3 1) : (→ (× Int (Option String) (Option (List Int))))) +(check-type ((nn3 1)) : (× Int (Option String) (Option (List Int)))) +(check-type ((nn3 1)) : (× Int (Option (List Int)) (Option String))) +;; test inst order +(check-type ((inst (nn3 1) String (List Int))) : (× Int (Option String) (Option (List Int)))) +(check-type ((inst (nn3 1) (List Int) String)) : (× Int (Option (List Int)) (Option String))) + +(define (nn4 -> (→ (Option X))) + (λ () (None*))) +(check-type (let ([x (nn4)]) + x) + : (→/test (Option X))) + +(define (nn5 -> (→ (Ref (Option X)))) + (λ () (ref (None {X})))) +(typecheck-fail (let ([x (nn5)]) + x) + #:with-msg "Could not infer instantiation of polymorphic function nn5.") + +(define (nn6 -> (→ (Option X))) + (let ([r (((inst nn5 X)))]) + (λ () (deref r)))) +(check-type (nn6) : (→/test (Option X))) + +;; A is covariant, B is invariant. +(define-type (Cps A B) + (cps (→ (→ A B) B))) +(define (cps* [f : (→ (→ A B) B)] → (Cps A B)) + (cps f)) + +(define (nn7 -> (→ (Cps (Option A) B))) + (let ([r (((inst nn5 A)))]) + (λ () (cps* (λ (k) (k (deref r))))))) +(typecheck-fail (let ([x (nn7)]) + x) + #:with-msg "Could not infer instantiation of polymorphic function nn7.") + +(define (nn8 -> (→ (Cps (Option A) Int))) + (nn7)) +(check-type (let ([x (nn8)]) + x) + : (→/test (Cps (Option A) Int))) + (define-type (Result A B) (Ok A) (Error B)) @@ -389,6 +454,35 @@ (define (error [b : B] → (Result A B)) (Error b)) +(define (ok-fn [a : A] -> (→ (Result A B))) + (λ () (ok a))) +(define (error-fn [b : B] -> (→ (Result A B))) + (λ () (error b))) + +(check-type (let ([x (ok-fn 1)]) + x) + : (→/test (Result Int B))) +(check-type (let ([x (error-fn "bad")]) + x) + : (→/test (Result A String))) + +(define (nn9 [a : A] -> (→ (Result A (Ref B)))) + (ok-fn a)) +(define (nn10 [a : A] -> (→ (Result A (Ref String)))) + (nn9 a)) +(define (nn11 -> (→ (Result (Option A) (Ref String)))) + (nn10 (None*))) + +(typecheck-fail (let ([x (nn9 1)]) + x) + #:with-msg "Could not infer instantiation of polymorphic function nn9.") +(check-type (let ([x (nn10 1)]) + x) + : (→ (Result Int (Ref String)))) +(check-type (let ([x (nn11)]) + x) + : (→/test (Result (Option A) (Ref String)))) + (check-type (if (zero? (random 2)) (ok 0) (error "didn't get a zero")) @@ -453,6 +547,21 @@ (λ (b) (Error (Cons b Nil)))) : (Result (List Int) (List String))) +(define (tup* [a : A] [b : B] -> (× A B)) + (tup a b)) + +(define (nn12 -> (→ (× (Option A) (Option B)))) + (λ () (tup* (None*) (None*)))) +(check-type (let ([x (nn12)]) + x) + : (→/test (× (Option A) (Option B)))) + +(define (nn13 -> (→ (× (Option A) (Option (Ref B))))) + (nn12)) +(typecheck-fail (let ([x (nn13)]) + x) + #:with-msg "Could not infer instantiation of polymorphic function nn13.") + ;; records and automatically-defined accessors and predicates (define-type (RecoTest X Y) (RT1 [x : X] [y : Y] [z : String]) diff --git a/tapl/tests/mlish/alex.mlish b/tapl/tests/mlish/alex.mlish index dc29a17..9e80c23 100644 --- a/tapl/tests/mlish/alex.mlish +++ b/tapl/tests/mlish/alex.mlish @@ -14,3 +14,12 @@ (check-type try : (→/test X (→ X Y) X)) +(define (accept-A×A [pair : (× A A)] → (× A A)) + pair) + +(typecheck-fail (accept-A×A (tup 8 "ate")) + #:with-msg "couldn't unify Int and String\n *expected: \\(× A A\\)\n *given: \\(× Int String\\)") + +(typecheck-fail (ann (accept-A×A (tup 8 "ate")) : (× String String)) + #:with-msg "expected: \\(× String String\\)\n *given: \\(× Int String\\)") + diff --git a/tapl/tests/mlish/match2.mlish b/tapl/tests/mlish/match2.mlish index 85ab5ad..49de5f9 100644 --- a/tapl/tests/mlish/match2.mlish +++ b/tapl/tests/mlish/match2.mlish @@ -62,7 +62,7 @@ (typecheck-fail (match2 (B 1) with [B x -> x]) - #:with-msg (expected "(× X X)" #:given "Int")) + #:with-msg "expected: \\(× X X\\)\n *given: Int") (check-type (match2 (B (tup 2 3)) with diff --git a/tapl/tests/mlish/queens.mlish b/tapl/tests/mlish/queens.mlish index 8814805..d45b4b4 100644 --- a/tapl/tests/mlish/queens.mlish +++ b/tapl/tests/mlish/queens.mlish @@ -46,7 +46,7 @@ (check-type (map add1 (Cons 1 (Cons 2 (Cons 3 Nil)))) : (List Int) ⇒ (Cons 2 (Cons 3 (Cons 4 Nil)))) (typecheck-fail (map add1 (Cons "1" Nil)) - #:with-msg (expected "Int, (List Int)" #:given "String, (List Int)")) + #:with-msg "expected: Int\n *given: String") (check-type (map (λ ([x : Int]) (+ x 2)) (Cons 1 (Cons 2 (Cons 3 Nil)))) : (List Int) ⇒ (Cons 3 (Cons 4 (Cons 5 Nil)))) ;; ; doesnt work yet: all lambdas need annotations diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt index 1391e15..c1db34d 100644 --- a/tapl/typecheck.rkt +++ b/tapl/typecheck.rkt @@ -15,6 +15,9 @@ "stx-utils.rkt")) (for-meta 2 (all-from-out racket/base syntax/parse racket/syntax))) +(module+ test + (require (for-syntax rackunit))) + ;; type checking functions/forms ;; General type checking strategy: @@ -420,6 +423,48 @@ (define (brack? stx) (define paren-shape/#f (syntax-property stx 'paren-shape)) (and paren-shape/#f (char=? paren-shape/#f #\[))) + + (define (iff b1 b2) + (boolean=? b1 b2)) + + ;; Variance is (variance Boolean Boolean) + (struct variance (covariant? contravariant?) #:prefab) + (define irrelevant (variance #true #true)) + (define covariant (variance #true #false)) + (define contravariant (variance #false #true)) + (define invariant (variance #false #false)) + ;; variance-irrelevant? : Variance -> Boolean + (define (variance-irrelevant? v) + (and (variance-covariant? v) (variance-contravariant? v))) + ;; variance-invariant? : Variance -> Boolean + (define (variance-invariant? v) + (and (not (variance-covariant? v)) (not (variance-contravariant? v)))) + ;; variance-join : Variance Variance -> Variance + (define (variance-join v1 v2) + (variance (and (variance-covariant? v1) + (variance-covariant? v2)) + (and (variance-contravariant? v1) + (variance-contravariant? v2)))) + ;; variance-compose : Variance Variance -> Variance + (define (variance-compose v1 v2) + (variance (or (variance-irrelevant? v1) + (variance-irrelevant? v2) + (and (variance-covariant? v1) (variance-covariant? v2)) + (and (variance-contravariant? v1) (variance-contravariant? v2))) + (or (variance-irrelevant? v1) + (variance-irrelevant? v2) + (and (variance-covariant? v1) (variance-contravariant? v2)) + (and (variance-contravariant? v1) (variance-covariant? v2))))) + + ;; add-arg-variances : Id (Listof Variance) -> Id + ;; Takes a type constructor id and adds variance information about the arguments. + (define (add-arg-variances id arg-variances) + (set-stx-prop/preserved id 'arg-variances arg-variances)) + ;; get-arg-variances : Id -> (U False (Listof Variance)) + ;; Takes a type constructor id and gets the argument variance information. + (define (get-arg-variances id) + (syntax-property id 'arg-variances)) + ;; todo: abstract out the common shape of a type constructor, ;; i.e., the repeated pattern code in the functions below (define (get-extra-info t) @@ -482,6 +527,10 @@ #:defaults ([bvs-op #'=][bvs-n #'0])) (~optional (~seq #:arr (~and (~parse has-annotations? #'#t) tycon)) #:defaults ([tycon #'void])) + (~optional (~seq #:arg-variances arg-variances-stx:expr) + #:defaults ([arg-variances-stx + #`(λ (stx-id) (for/list ([arg (in-list (stx->list (stx-cdr stx-id)))]) + invariant))])) (~optional (~seq #:extra-info extra-info) #:defaults ([extra-info #'void])) (~optional (~and #:no-provide (~parse no-provide? #'#t)))) @@ -532,6 +581,7 @@ #:msg "Expected ~a type, got: ~a" #'τ #'any))))]))) + (define arg-variances arg-variances-stx) (define (τ? t) (and (stx-pair? t) (syntax-parse t @@ -565,10 +615,11 @@ #:with k_result (if #,(attribute has-annotations?) #'(tycon k_arg (... ...)) #'#%kind) + #:with τ-internal* (add-arg-variances #'τ-internal (arg-variances stx)) (add-orig (assign-type (syntax/loc stx - (τ-internal (λ bvs- (#%expression extra-info) . τs-))) + (τ-internal* (λ bvs- (#%expression extra-info) . τs-))) #'k_result) #'(τ . args))] ;; else fail with err msg @@ -701,3 +752,48 @@ (define (substs τs xs e [cmp bound-identifier=?]) (stx-fold (lambda (ty x res) (subst ty x res cmp)) e τs xs))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(module+ test + (begin-for-syntax + (test-case "variance-join" + (test-case "joining with irrelevant doesn't change it" + (check-equal? (variance-join irrelevant irrelevant) irrelevant) + (check-equal? (variance-join irrelevant covariant) covariant) + (check-equal? (variance-join irrelevant contravariant) contravariant) + (check-equal? (variance-join irrelevant invariant) invariant)) + (test-case "joining with invariant results in invariant" + (check-equal? (variance-join invariant irrelevant) invariant) + (check-equal? (variance-join invariant covariant) invariant) + (check-equal? (variance-join invariant contravariant) invariant) + (check-equal? (variance-join invariant invariant) invariant)) + (test-case "joining a with a results in a" + (check-equal? (variance-join irrelevant irrelevant) irrelevant) + (check-equal? (variance-join covariant covariant) covariant) + (check-equal? (variance-join contravariant contravariant) contravariant) + (check-equal? (variance-join invariant invariant) invariant)) + (test-case "joining covariant with contravariant results in invariant" + (check-equal? (variance-join covariant contravariant) invariant) + (check-equal? (variance-join contravariant covariant) invariant))) + (test-case "variance-compose" + (test-case "composing with covariant doesn't change it" + (check-equal? (variance-compose covariant irrelevant) irrelevant) + (check-equal? (variance-compose covariant covariant) covariant) + (check-equal? (variance-compose covariant contravariant) contravariant) + (check-equal? (variance-compose covariant invariant) invariant)) + (test-case "composing with irrelevant results in irrelevant" + (check-equal? (variance-compose irrelevant irrelevant) irrelevant) + (check-equal? (variance-compose irrelevant covariant) irrelevant) + (check-equal? (variance-compose irrelevant contravariant) irrelevant) + (check-equal? (variance-compose irrelevant invariant) irrelevant)) + (test-case "otherwise composing with invariant results in invariant" + (check-equal? (variance-compose invariant covariant) invariant) + (check-equal? (variance-compose invariant contravariant) invariant) + (check-equal? (variance-compose invariant invariant) invariant)) + (test-case "composing with with contravariant flips covariant and contravariant" + (check-equal? (variance-compose contravariant covariant) contravariant) + (check-equal? (variance-compose contravariant contravariant) covariant) + (check-equal? (variance-compose contravariant irrelevant) irrelevant) + (check-equal? (variance-compose contravariant invariant) invariant))) + ))