From fca145bbd74db307a3d4f7334cc7a96dfb71e13c Mon Sep 17 00:00:00 2001 From: Stephen Chang Date: Tue, 29 Mar 2016 17:14:47 -0400 Subject: [PATCH] start match2, with support for nested matches --- dont use yet --- tapl/mlish.rkt | 105 ++++++++++++++++++++++++++++++++-- tapl/tests/mlish/match2.mlish | 71 +++++++++++++++++++++++ tapl/typecheck.rkt | 4 +- 3 files changed, 172 insertions(+), 8 deletions(-) create mode 100644 tapl/tests/mlish/match2.mlish diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt index 20ea4a0..4678584 100644 --- a/tapl/mlish.rkt +++ b/tapl/mlish.rkt @@ -11,15 +11,19 @@ (require (only-in "sysf.rkt" ~∀ ∀ ∀? Λ)) (reuse × tup proj define-type-alias #:from "stlc+rec-iso.rkt") (require (only-in "stlc+rec-iso.rkt" ~× ×?)) -(provide → define-type match) +(provide → define-type) (provide (rename-out [ext-stlc:and and] [ext-stlc:#%datum #%datum])) (reuse member length reverse list-ref cons nil isnil head tail list #:from "stlc+cons.rkt") +(require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list cons nil))) (require (only-in "stlc+cons.rkt" ~List List? List)) (provide List) (reuse ref deref := Ref #:from "stlc+box.rkt") (require (rename-in (only-in "stlc+reco+var.rkt" tup proj ×) [tup rec] [proj get] [× ××])) (provide rec get ××) +;; for pattern matching +(require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list))) +(require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup))) ;; ML-like language ;; - top level recursive functions @@ -202,7 +206,7 @@ #:with ((acc ...) ...) (stx-map (λ (S fs) (stx-map (λ (f) (format-id S "~a-~a" S f)) fs)) #'(StructName ...) #'((fld ...) ...)) #:with (Cons? ...) (stx-map mk-? #'(StructName ...)) - #:with get-Name-info (format-id #'Name "get-~a-info" #'Name) +; #:with get-Name-info (format-id #'Name "get-~a-info" #'Name) ;; types, but using RecName instead of Name #:with ((τ/rec ...) ...) (subst #'RecName #'Name #'((τ ...) ...)) #`(begin @@ -222,7 +226,7 @@ ;; in place of args in the input type ;; (see subst-special in typecheck.rkt) (assign-type #'(#%plain-app RecName . rst) #'#%type)])]) - ('Cons Cons? [acc τ/rec] ...) ...)) + ('Cons 'StructName Cons? [acc τ/rec] ...) ...)) #:no-provide) (struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ... (define-syntax (Cons stx) @@ -298,8 +302,97 @@ ...)])) ;; match -------------------------------------------------- -(define-syntax (match stx) +(begin-for-syntax + (define (get-ctx pat ty) + (unify-pat+ty (list pat ty))) + (define (unify-pat+ty pat+ty) + (syntax-parse pat+ty + [((~datum _) ty) #'()] + [(~literal stlc+cons:nil) ; nil + #'()] + [(x:id ty) + #'((x ty))] + [(((~literal stlc+tup:tup) p ...) ty) ; tup + #:with (~× t ...) #'ty + (unifys #'([p t] ...))] + [(((~literal stlc+cons:list) p ...) ty) ; known length list + #:with (~List t) #'ty + (unifys #'([p t] ...))] + [(((~literal stlc+cons:cons) p ps) ty) ; arb length list + #:with (~List t) #'ty + (unifys #'([p t] [ps ty]))] + [((Name p ...) ty) + #:with ((~literal #%plain-lambda) (RecName) + ((~literal let-values) () + ((~literal let-values) () + . info-body))) + (get-extra-info #'ty) + #:with ((_ ((~literal quote) ConsAll) . _) ...) #'info-body + #:with info-unfolded (subst-special #'τ_e #'RecName #'info-body) + #:with (_ ((~literal quote) Cons) ((~literal quote) StructName) Cons? [_ acc τ] ...) + (stx-findf + (syntax-parser + [((~literal #%plain-app) 'C . rst) + (equal? (syntax->datum #'Name) (syntax->datum #'C))]) + #'info-unfolded) + (unifys #'([p τ] ...))] + [p+t #:fail-when #t (format "could not unify ~a" (syntax->datum #'p+t)) #'()])) + (define (unifys p+tys) (stx-appendmap unify-pat+ty p+tys)) + + (define (compile-pat p ty) + (syntax-parse p + [(~datum _) #'_] + [(~literal stlc+cons:nil) ; nil + #'(list)] + [x:id p] + [((~literal stlc+tup:tup) p ...) + #:with (~× t ...) ty + #:with (p- ...) (stx-map (lambda (p t) (compile-pat p t)) #'(p ...) #'(t ...)) + #'(list p- ...)] + [((~literal stlc+cons:list) p ...) + #:with (~List t) ty + #:with (p- ...) (stx-map (lambda (p) (compile-pat p #'t)) #'(p ...)) + #'(list p- ...)] + [((~literal stlc+cons:cons) p ps) + #:with (~List t) ty + #:with p- (compile-pat #'p #'t) + #:with ps- (compile-pat #'ps ty) + #'(cons p- ps-)] + [(Name p ...) + #:with ((~literal #%plain-lambda) (RecName) + ((~literal let-values) () + ((~literal let-values) () + . info-body))) + (get-extra-info ty) + #:with ((_ ((~literal quote) ConsAll) . _) ...) #'info-body + #:with info-unfolded (subst-special #'τ_e #'RecName #'info-body) + #:with (_ ((~literal quote) Cons) ((~literal quote) StructName) Cons? [_ acc τ] ...) + (stx-findf + (syntax-parser + [((~literal #%plain-app) 'C . rst) + (equal? (syntax->datum #'Name) (syntax->datum #'C))]) + #'info-unfolded) + #:with (p- ...) (stx-map compile-pat #'(p ...) #'(τ ...)) + #'(StructName p- ...)])) + ) + +(provide match2) +(define-syntax (match2 stx) (syntax-parse stx #:datum-literals (with) + [(_ e with . clauses) + #:fail-when (null? (syntax->list #'clauses)) "no clauses" + #:with [e- τ_e] (infer+erase #'e) + (syntax-parse #'clauses #:datum-literals (->) + [([pat -> e_body] ...) + #:with ((~and ctx ([x ty] ...)) ...) (stx-map (lambda (p) (get-ctx p #'τ_e)) #'(pat ...)) + #:with ([(x- ...) e_body- ty_body] ...) (stx-map infer/ctx+erase #'(ctx ...) #'(e_body ...)) + #:with (pat- ...) (stx-map (lambda (p) (compile-pat p #'τ_e)) #'(pat ...)) + #:with τ_out (stx-car #'(ty_body ...)) + (⊢ (match e- [pat- (let ([x- x] ...) e_body-)] ...) : τ_out) + ])])) + +(define-typed-syntax match #:datum-literals (with) +; (syntax-parse stx #:datum-literals (with) [(_ e with . clauses) #:fail-when (null? (syntax->list #'clauses)) "no clauses" #:with [e- τ_e] (infer+erase #'e) @@ -371,7 +464,7 @@ (syntax->datum #'(ConsAll ...)) (syntax->datum #'(Clause ...)))) ", ")) - #:with ((_ ((~literal quote) Cons) Cons? [_ acc τ] ...) ...) + #:with ((_ ((~literal quote) Cons) ((~literal quote) StructName) Cons? [_ acc τ] ...) ...) (map ; ok to compare symbols since clause names can't be rebound (lambda (Cl) (stx-findf @@ -405,7 +498,7 @@ [(and (Cons? z) (let ([x- (acc z)] ...) e_guard-)) (let ([x- (acc z)] ...) e_c-)] ...)) - : τ_out)])])])) + : τ_out)])])]) (define-syntax → ; wrapping → (syntax-parser diff --git a/tapl/tests/mlish/match2.mlish b/tapl/tests/mlish/match2.mlish new file mode 100644 index 0000000..95b830c --- /dev/null +++ b/tapl/tests/mlish/match2.mlish @@ -0,0 +1,71 @@ +#lang s-exp "../../mlish.rkt" +(require "../rackunit-typechecking.rkt") + +;; alternate match that supports nested patterns + +(define-type (Test X) + (A X) + (B (× X X)) + (C (× X (× X X)))) + +(check-type + (match2 (B (tup 2 3)) with + [(B x) -> x]) : (× Int Int) -> (list 2 3)) + +(check-type + (match2 (A (tup 2 3)) with + [(A x) -> x]) : (× Int Int) -> (list 2 3)) + +(check-type + (match2 (A 1) with + [(A x) -> x]) : Int -> 1) + +(typecheck-fail + (match2 (B 1) with + [(B x) -> x]) + #:with-msg "Could not infer instantiation of polymorphic function B") + +(check-type + (match2 (B (tup 2 3)) with + [(B (tup x y)) -> (+ x y)]) : Int -> 5) + +(check-type + (match2 (C (tup 2 (tup 3 4))) with + [(C (tup x (tup y z))) -> (+ x (+ y z))]) : Int -> 9) + +(check-type + (match2 (C (tup 2 (tup 3 4))) with + [(A x) -> x] + [_ -> 100]) : Int -> 100) + + + +;; lists + +(check-type + (match2 (list 1) with + [(list x) -> x]) : Int -> 1) + +(check-type + (match2 (list 1 2) with + [(list x y) -> (+ x y)]) : Int -> 3) + +(check-type + (match2 (list 1 2) with + [(list) -> 0] + [(list x y) -> (+ x y)]) : Int -> 3) + +(check-type + (match2 (list (list 3 4) (list 5 6)) with + [(list) -> 0] + [(list (list w x) (list y z)) -> (+ (+ x y) (+ z w))]) : Int -> 18) + +(check-type + (match2 (list (tup 3 4) (tup 5 6)) with + [(list) -> 0] + [(list (tup w x) (tup y z)) -> (+ (+ x y) (+ z w))]) : Int -> 18) + +#;(check-type + (match2 (nil {Int}) with + [nil -> 0] + [(list x y) -> (+ x y)]) : Int -> 0) diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt index 634adde..3c3a24d 100644 --- a/tapl/typecheck.rkt +++ b/tapl/typecheck.rkt @@ -6,9 +6,9 @@ "stx-utils.rkt") (for-meta 2 racket/base syntax/parse racket/syntax syntax/stx "stx-utils.rkt") (for-meta 3 racket/base syntax/parse racket/syntax) - racket/bool racket/provide racket/require) + racket/bool racket/provide racket/require racket/match) (provide - symbol=? + symbol=? match (except-out (all-from-out racket/base) #%module-begin) (for-syntax (all-defined-out)) (all-defined-out) (for-syntax