diff --git a/tapl/stlc+occurrence.rkt b/tapl/stlc+occurrence.rkt index 640e364..98ea25d 100644 --- a/tapl/stlc+occurrence.rkt +++ b/tapl/stlc+occurrence.rkt @@ -67,6 +67,7 @@ ;; flatten nested unions (begin-for-syntax + (define τ-eval (current-type-eval)) (define (τ->symbol τ) @@ -84,36 +85,51 @@ [_ (error 'τ->symbol (~a (syntax->datum τ)))])) - (define (∪-eval τ-stx) - (syntax-parse (τ-eval τ-stx) - [(~∪ τ-stx* ...) - ;; Recursively evaluate members - (define τ** - (for/list ([τ (in-list (syntax->list #'(τ-stx* ...)))]) - (let ([τ+ (∪-eval τ)]) - (if (∪? τ+) - (∪->list τ+) - (list τ+))))) - ;; Remove duplicates from the union, sort members - (define τ* - (sort - (remove-duplicates (apply append τ**) (current-type=?)) - symbolsymbol)) - ;; Check for empty & singleton lists - (define τ - (cond - [(null? τ*) - (raise-user-error 'τ-eval "~a (~a:~a) empty union type ~a\n" - (syntax-source τ-stx) (syntax-line τ-stx) (syntax-column τ-stx) - (syntax->datum τ-stx))] - [(null? (cdr τ*)) - #`#,(car τ*)] - [else - #`#,(cons #'∪ τ*)])) - (τ-eval τ)] - [_ - (τ-eval τ-stx)])) + (define ∪-eval + ;; Private helper: check that all functions have unique arities + ;; It's private because it assumes all τ* have been evaluated + (let ([assert-unique-arity-arrows + (lambda (τ*) + (for/fold ([seen '()]) + ([τ (in-list τ*)]) + (syntax-parse τ + [(~→ τ-dom* ... τ-cod) + (define arity (stx-length #'(τ-dom* ...))) + (when (memv arity seen) + (error '∪ (format "Cannot discriminate types in the union ~a. Multiple functions have arity ~a." (cons '∪ (map syntax->datum τ*)) arity))) + (cons arity seen)] + [_ seen])))]) + (lambda (τ-stx) + (syntax-parse (τ-eval τ-stx) + [(~∪ τ-stx* ...) + ;; Recursively evaluate members + (define τ** + (for/list ([τ (in-list (syntax->list #'(τ-stx* ...)))]) + (let ([τ+ (∪-eval τ)]) + (if (∪? τ+) + (∪->list τ+) + (list τ+))))) + ;; Remove duplicates from the union, sort members + (define τ* + (sort + (remove-duplicates (apply append τ**) (current-type=?)) + symbolsymbol)) + ;; Check for empty & singleton lists + (define τ + (cond + [(null? τ*) + (raise-user-error 'τ-eval "~a (~a:~a) empty union type ~a\n" + (syntax-source τ-stx) (syntax-line τ-stx) (syntax-column τ-stx) + (syntax->datum τ-stx))] + [(null? (cdr τ*)) + #`#,(car τ*)] + [else + (assert-unique-arity-arrows τ*) + #`#,(cons #'∪ τ*)])) + (τ-eval τ)] + [_ + (τ-eval τ-stx)])))) (current-type-eval ∪-eval)) ;; ----------------------------------------------------------------------------- diff --git a/tapl/tests/stlc+occurrence-tests.rkt b/tapl/tests/stlc+occurrence-tests.rkt index e573803..f815af4 100644 --- a/tapl/tests/stlc+occurrence-tests.rkt +++ b/tapl/tests/stlc+occurrence-tests.rkt @@ -272,8 +272,8 @@ ;; ----------------------------------------------------------------------------- ;; --- Functions in union -(check-type (λ ([x : (∪ Int (∪ Nat) (∪ (→ Int Int)) (→ (→ (→ Int Int)) Int))]) #t) - : (→ (∪ Int Nat (→ Int Int) (→ (→ (→ Int Int)) Int)) Boolean)) +(check-type (λ ([x : (∪ Int (∪ Nat) (∪ (→ Int Str Int)) (→ (→ (→ Int Int)) Int))]) #t) + : (→ (∪ Int Nat (→ Int Str Int) (→ (→ (→ Int Int)) Int)) Boolean)) (check-type (λ ([x : (∪ Int (→ Int Int))]) #t) : (→ Int Boolean)) @@ -324,8 +324,8 @@ ;; --- disallow same-arity functions (typecheck-fail - (λ ([x : (∪ (→ Int Int) (→ Str Str))]) (x 1)) - #:with-msg "boooo") + (λ ([x : (∪ (→ Int Int) (→ Str Str))]) 1) + #:with-msg "Cannot discriminate") ;; ----------------------------------------------------------------------------- ;; --- TODO Filter values (should do nothing)