diff --git a/typecheck.rkt b/typecheck.rkt index 92faf67..2a6af75 100644 --- a/typecheck.rkt +++ b/typecheck.rkt @@ -27,7 +27,8 @@ (define-syntax (define-primop stx) (syntax-parse stx #:datum-literals (:) [(_ op:id : ((~and τ_arg (~not (~literal ...))) ... (~optional (~and ldots (~literal ...))) - arr τ_result)) + (~and arr (~datum →)) + τ_result)) ; #:with lit-→ (datum->syntax stx '→) #:with (~datum →) #'arr #:with op/tc (format-id #'op "~a/tc" #'op) @@ -37,28 +38,30 @@ (syntax-parse stx [f:id ; HO case (⊢ (syntax/loc stx op) -; #,(if (attribute ldots) -; #'#'(τ_arg ... (... (... ...)) arr τ_result) -; #'#'(τ_arg ... arr τ_result)))] + #,(if (attribute ldots) + #'#'(τ_arg ... (... (... ...)) arr τ_result) + #'#'(τ_arg ... arr τ_result)))] ;; TODO: for now, just drop the ... - #'(τ_arg ... arr τ_result))] +; #'(τ_arg ... arr τ_result))] [(_ e (... ...)) #:with es+ (stx-map expand/df #'(e (... ...))) #:with τs #'(τ_arg ...) #:fail-unless (let ([es-len (stx-length #'es+)] - [τs-len (sub1 (stx-length #'τs))]) + [τs-len (stx-length #'τs)]) (or (and #,(if (attribute ldots) #t #f) - (>= (- es-len τs-len) 0)) + (>= (- es-len (sub1 τs-len)) 0)) (= es-len τs-len))) #,(if (attribute ldots) #'(format "Wrong number of arguments, given ~a, expected at least ~a" (stx-length #'es+) (sub1 (stx-length #'τs))) #'(format "Wrong number of arguments, given ~a, expected ~a" (stx-length #'es+) (stx-length #'τs))) - #:with τs-ext (let* ([diff (- (stx-length #'es+) (sub1 (stx-length #'τs)))] - [last-τ (stx-last #'τs)] - [last-τs (build-list diff (λ _ last-τ))]) - (append (drop-right (syntax->list #'τs) 1) last-τs)) + #:with τs-ext #,(if (attribute ldots) + #'(let* ([diff (- (stx-length #'es+) (sub1 (stx-length #'τs)))] + [last-τ (stx-last #'τs)] + [last-τs (build-list diff (λ _ last-τ))]) + (append (drop-right (syntax->list #'τs) 1) last-τs)) + #'#'τs) #:when (stx-andmap assert-type #'es+ #'τs-ext) (⊢ (syntax/loc stx (op . es+)) #'τ_result)])))])) @@ -80,8 +83,50 @@ (and (= (length (syntax->list #'τvars1)) (length (syntax->list #'τvars2))) (type=? (apply-forall #'∀τ1 #'fresh-τvars) (apply-forall #'∀τ2 #'fresh-τvars)))] + ;; ldots on lhs + [(((~and τ_arg1 (~not (~literal ...))) ... τ_repeat (~and ldots (~literal ...)) → τ_result1) + ((~and τ_arg2 (~not (~literal ...))) ... → τ_result2)) + (let ([num-arg1 (stx-length #'(τ_arg1 ...))] + [num-arg2 (stx-length #'(τ_arg2 ...))]) + (define diff (- num-arg2 num-arg1)) + (define extra-τs (build-list diff (λ _ #'τ_repeat))) + (with-syntax ([(τ_arg1/ext ...) (append (syntax->list #'(τ_arg1 ...)) extra-τs)]) + (and (= (length (syntax->list #'(τ_arg1/ext ...))) (length (syntax->list #'(τ_arg2 ...)))) + (stx-andmap type=? #'(τ_arg1/ext ...) #'(τ_arg2 ...)) + (type=? #'τ_result1 #'τ_result2))))] + ;; ldots on rhs + [(((~and τ_arg2 (~not (~literal ...))) ... → τ_result2) + ((~and τ_arg1 (~not (~literal ...))) ... τ_repeat (~and ldots (~literal ...)) → τ_result1)) + (let ([num-arg1 (stx-length #'(τ_arg1 ...))] + [num-arg2 (stx-length #'(τ_arg2 ...))]) + (define diff (- num-arg2 num-arg1)) + (define extra-τs (build-list diff (λ _ #'τ_repeat))) + (with-syntax ([(τ_arg1/ext ...) (append (syntax->list #'(τ_arg1 ...)) extra-τs)]) + (and (= (length (syntax->list #'(τ_arg1/ext ...))) (length (syntax->list #'(τ_arg2 ...)))) + (stx-andmap type=? #'(τ_arg1/ext ...) #'(τ_arg2 ...)) + (type=? #'τ_result1 #'τ_result2))))] + ;; ldots on both lhs and rhs + [(((~and τ_arg1 (~not (~literal ...))) ... τ_repeat1 (~and ldots1 (~literal ...)) → τ_result1) + ((~and τ_arg2 (~not (~literal ...))) ... τ_repeat2 (~and ldots2 (~literal ...)) → τ_result2)) + (let ([num-arg1 (stx-length #'(τ_arg1 ...))] + [num-arg2 (stx-length #'(τ_arg2 ...))]) + (cond [(> num-arg2 num-arg1) + (define diff (- num-arg2 num-arg1)) + (define extra-τs (build-list diff (λ _ #'τ_repeat1))) + (with-syntax ([(τ_arg1/ext ...) (append (syntax->list #'(τ_arg1 ...)) extra-τs)]) + (and (= (length (syntax->list #'(τ_arg1/ext ...))) (length (syntax->list #'(τ_arg2 ...)))) + (stx-andmap type=? #'(τ_arg1/ext ...) #'(τ_arg2 ...)) + (type=? #'τ_result1 #'τ_result2)))] + [else + (define diff (- num-arg1 num-arg2)) + (define extra-τs (build-list diff (λ _ #'τ_repeat2))) + (with-syntax ([(τ_arg2/ext ...) (append (syntax->list #'(τ_arg2 ...)) extra-τs)]) + (and (= (length (syntax->list #'(τ_arg2/ext ...))) (length (syntax->list #'(τ_arg1 ...)))) + (stx-andmap type=? #'(τ_arg2/ext ...) #'(τ_arg1 ...)) + (type=? #'τ_result1 #'τ_result2)))]))] [((τ_arg1 ... → τ_result1) (τ_arg2 ... → τ_result2)) (and (= (length (syntax->list #'(τ_arg1 ...))) (length (syntax->list #'(τ_arg2 ...)))) + (stx-andmap type=? #'(τ_arg1 ...) #'(τ_arg2 ...)) (type=? #'τ_result1 #'τ_result2))] [((tycon1:id τ1 ...) (tycon2:id τ2 ...)) (and (free-identifier=? #'tycon1 #'tycon2)