diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 5cc5a0b..ece16a8 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -186,8 +186,8 @@ (stlc+rec-iso:define-type-alias Name Name2))] [(_ (Name:id X:id ...) ;; constructors must have the form (Cons τ ...) - ;; but the first ~or clause accepts 0-arg constructors as ids - ;; the ~and is required to bind the duplicate Cons ids (see Ryan's email) + ;; but the first ~or clause accepts 0-arg constructors as ids; + ;; the ~and is a workaround to bind the duplicate Cons ids (see Ryan's email) (~and (~or (~and IdCons:id (~parse (Cons [fld (~datum :) τ] ...) #'(IdCons))) (Cons [fld (~datum :) τ] ...) @@ -199,20 +199,29 @@ #:with ((e_arg ...) ...) (stx-map generate-temporaries #'((τ ...) ...)) #:with ((e_arg- ...) ...) (stx-map generate-temporaries #'((τ ...) ...)) #:with ((τ_arg ...) ...) (stx-map generate-temporaries #'((τ ...) ...)) -; #:with ((fld ...) ...) (stx-map generate-temporaries #'((τ ...) ...)) #:with ((acc ...) ...) (stx-map (λ (S fs) (stx-map (λ (f) (format-id S "~a-~a" S f)) fs)) #'(StructName ...) #'((fld ...) ...)) #:with (Cons? ...) (stx-map mk-? #'(StructName ...)) #:with get-Name-info (format-id #'Name "get-~a-info" #'Name) ;; types, but using RecName instead of Name - #:with ((τ/rec ...) ...) (subst-expr #'RecName #'(Name X ...) #'((τ ...) ...)) + #:with ((τ/rec ...) ...) (subst #'RecName #'Name #'((τ ...) ...)) #`(begin (define-type-constructor Name #:arity = #,(stx-length #'(X ...)) #:extra-info (X ...) (λ (RecName) - (let-syntax ([RecName (make-rename-transformer - (assign-type #'RecName #'#%type))]) + (let-syntax + ([RecName + (syntax-parser + [(_ . rst) + ;; - this is a placeholder to break the recursion + ;; - clients, like match, must manually unfold by + ;; replacing the entire (#%plain-app RecName ...) stx + ;; - to preserve polymorphic recursion, the entire stx + ;; should be replaced but with #'rst as the args + ;; in place of args in the input type + ;; (see subst-special in typecheck.rkt) + (assign-type #'(#%plain-app RecName . rst) #'#%type)])]) ('Cons Cons? [acc τ/rec] ...) ...)) #:no-provide) (struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ... @@ -291,6 +300,7 @@ ;; match -------------------------------------------------- (define-syntax (match stx) (syntax-parse stx #:datum-literals (with ->) + ;; TODO: eliminate redundant expansions [(_ e with . clauses) ;; e is tuple #:with [e- ty_e] (infer+erase #'e) @@ -320,7 +330,7 @@ ((~literal let-values) () . info-body))) (get-extra-info #'τ_e) - #:with info-unfolded (subst #'τ_e #'RecName #'info-body) + #:with info-unfolded (subst-special #'τ_e #'RecName #'info-body) #:with ((_ ((~literal quote) ConsAll) . _) ...) #'info-body #:fail-unless (set=? (syntax->datum #'(Clause ...)) (syntax->datum #'(ConsAll ...))) @@ -341,6 +351,7 @@ (equal? Cl (syntax->datum #'C))]) #'info-unfolded)) (syntax->datum #'(Clause ...))) + ;; this commented block experiments with expanding to unsafe ops ;; #:with ((acc ...) ...) (stx-map ;; (lambda (accs) ;; (for/list ([(a i) (in-indexed (syntax->list accs))]) diff --git a/tapl/tests/mlish/polyrecur.mlish b/tapl/tests/mlish/polyrecur.mlish new file mode 100644 index 0000000..65c6ded --- /dev/null +++ b/tapl/tests/mlish/polyrecur.mlish @@ -0,0 +1,28 @@ +#lang s-exp "../../mlish.rkt" +(require "../rackunit-typechecking.rkt") + +;; tests of polymorphic recursion + +;; polymorphic recursion of functions +(define (polyf [lst : (List X)] -> (List X)) + (let ([x (polyf (list 1 2 3))] + [y (polyf (list #t #f))]) + (polyf lst))) + +;; polymorphic recursive type +;; from okasaki, ch10 +(define-type (Seq X) + Nil + (Cons X (Seq (× X X)))) + +(define (size [s : (Seq X)] -> Int) + (match s with + [Nil -> 0] + [Cons x ps -> (add1 (* 2 (size ps)))])) + +(check-type (size (Nil {Int})) : Int -> 0) +(check-type (size (Cons 1 Nil)) : Int -> 1) +(check-type (size (Cons 1 (Cons (tup 2 3) Nil))) : Int -> 3) +(check-type + (size (Cons 1 (Cons (tup 2 3) (Cons (tup (tup 4 5) (tup 6 7)) Nil)))) + : Int -> 7) diff --git a/tapl/tests/run-all-mlish-tests.rkt b/tapl/tests/run-all-mlish-tests.rkt index 5d1fdb5..b4f5dfd 100644 --- a/tapl/tests/run-all-mlish-tests.rkt +++ b/tapl/tests/run-all-mlish-tests.rkt @@ -25,3 +25,6 @@ (require "mlish/bg/basics.mlish") (require "mlish/bg/huffman.mlish") (require "mlish/bg/lambda.rkt") + +;; okasaki, polymorphic recursion +(require "mlish/polyrecur.mlish") diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt index b345d3d..2620646 100644 --- a/tapl/typecheck.rkt +++ b/tapl/typecheck.rkt @@ -422,13 +422,22 @@ (define (brace? stx) (define paren-shape/#f (syntax-property stx 'paren-shape)) (and paren-shape/#f (char=? paren-shape/#f #\{))) + ;; todo: abstract out the common shape of a type constructor, + ;; i.e., the repeated pattern code in the functions below (define (get-extra-info t) (syntax-parse t [((~literal #%plain-app) internal-id ((~literal #%plain-lambda) bvs ((~literal #%expression) extra-info-to-extract) . rst)) #'extra-info-to-extract] - [_ #'void]))) + [_ #'void])) + (define (get-tyargs ty) + (syntax-parse ty + [((~literal #%plain-app) internal-id + ((~literal #%plain-lambda) bvs + xtra-info . rst)) + #'rst]))) + (define-syntax define-basic-checked-id-stx (syntax-parser #:datum-literals (:) @@ -693,16 +702,8 @@ stx)) ; subst τ for y in e, if (bound-id=? x y) (define (subst τ x e) - #'(printf "subst ~a for ~a in ~a\n" - (syntax->datum τ) - x - (syntax->datum e)) (syntax-parse e [y:id #:when (bound-identifier=? e x) - ; #:when (printf "~a = ~a\n" #'y x) -; #:when -; (displayln (syntax-property (syntax-track-origin τ #'y #'y) 'type)) -; #:when (displayln (syntax-property (syntax-property (syntax-track-origin τ #'y #'y) 'type #'#%type) 'type)) ; use syntax-track-origin to transfer 'orig ; but may transfer multiple #%type tags, so merge (merge-type-tags (syntax-track-origin τ #'y #'y))] @@ -714,11 +715,24 @@ (define (substs τs xs e) (stx-fold subst e τs xs)) - ;; subst-expr - ;; used for inferring recursive types + ;; subst-expr: + ;; - like subst except the target can be any stx, rather than just an id + ;; - used for implementing polymorphic recursive types + (define (stx-lam? s) + (syntax-parse s + [((~literal #%plain-lambda) . rst) #t] [_ #f])) + (define (stx-lam=? s1 s2) + (syntax-parse (list s1 s2) + [(((~literal #%plain-lambda) xs . bs1) + ((~literal #%plain-lambda) ys . bs2)) + #:with zs (generate-temporaries #'xs) + (and (stx-length=? #'xs #'ys) + (stx=? (substs #'zs #'xs #'bs1) + (substs #'zs #'ys #'bs2)))])) (define (stx=? s1 s2) (or (and (identifier? s1) (identifier? s2) (free-identifier=? s1 s2)) (and (stx-null? s1) (stx-null? s2)) + (and (stx-lam? s1) (stx-lam? s2) (stx-lam=? s1 s2)) (and (stx-pair? s1) (stx-pair? s2) (stx-length=? s1 s2) (stx-andmap stx=? s1 s2)))) ;; subst e1 for e2 in e3 @@ -727,5 +741,26 @@ [(stx=? e2 e3) e1] [(identifier? e3) e3] [else ; stx-pair - (stx-map (lambda (e) (subst-expr e1 e2 e)) e3)])) + (with-syntax ([result (stx-map (lambda (e) (subst-expr e1 e2 e)) e3)]) + (syntax-track-origin #'result e3 #'here))])) + (define (subst-exprs τs xs e) + (stx-fold subst-expr e τs xs)) + ;; subst-special: + ;; - used for unfolding polymorphic recursive type + ;; subst ty1 for x in ty2 + ;; where ty1 is an applied type constructor type + ;; x is a placeholder for an applied tycons type in ty2 + ;; - subst special first replaces the args of ty1 with that of x + ;; before replacing applications of tycons x with this modified ty1 + (define (subst-special ty1 x ty2) + (cond + [(identifier? ty2) ty2] + [(syntax-parse ty2 [((~literal #%plain-app) tycons:id . _) (free-identifier=? #'tycons x)] [_ #f]) + (syntax-parse ty2 + [((~literal #%plain-app) tycons:id . newargs) +; #:with oldargs (get-tyargs ty1) + (subst-exprs #'newargs (get-tyargs ty1) ty1)])] + [else ; stx-pair + (with-syntax ([result (stx-map (lambda (e) (subst-special ty1 x e)) ty2)]) + (syntax-track-origin #'result ty2 #'here))])) )