diff --git a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/private/with-types.rkt b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/private/with-types.rkt index 05d54308..9b94bec6 100644 --- a/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/private/with-types.rkt +++ b/pkgs/typed-racket-pkgs/typed-racket-lib/typed-racket/private/with-types.rkt @@ -1,6 +1,9 @@ #lang racket/base -(require racket/require racket/promise +(require "../utils/utils.rkt" + (utils lift) + (typecheck tc-toplevel) + racket/require racket/promise (for-template (except-in racket/base for for* with-handlers lambda λ define let let* letrec letrec-values let-values @@ -68,15 +71,14 @@ (for ([i (in-syntax fvids)] [ty (in-list fv-types)]) (register-type i ty)) - (define expanded-body - (disarm* - (if expr? - (with-syntax ([body body]) - (local-expand #'(let () . body) ctx null)) - (with-syntax ([(body ...) body] - [(id ...) exids] - [(ty ...) extys]) - (local-expand #'(let () (begin (: id ty) ... body ... (values id ...))) ctx null))))) + (define-values (lifted-definitions expanded-body) + (if expr? + (with-syntax ([body body]) + (wt-expand #'(let () . body) ctx)) + (with-syntax ([(body ...) body] + [(id ...) exids] + [(ty ...) extys]) + (wt-expand #'(let () (begin (: id ty) ... body ... (values id ...))) ctx)))) (parameterize (;; do we report multiple errors [delay-errors? #t] ;; this parameter is just for printing types @@ -94,6 +96,10 @@ ;; for error reporting [orig-module-stx stx] [expanded-module-stx expanded-body]) + ;; we can treat the lifted definitions as top-level forms because they + ;; are only definitions and not forms that have special top-level meaning + ;; to TR + (tc-toplevel-form lifted-definitions) (tc-expr/check expanded-body (if expr? region-tc-result (ret ex-types)))) (report-all-errors) (set-box! typed-context? old-context) @@ -118,14 +124,25 @@ (c:with-contract typed-region #:results (region-cnt ...) #:freevars ([fv.id cnt] ...) + #,lifted-definitions body))) - (syntax/loc stx + (quasisyntax/loc stx (begin (define-values () (begin check-syntax-help (values))) (c:with-contract typed-region ([ex-id ex-cnt] ...) + #,lifted-definitions (define-values (ex-id ...) body)))))))) +;; Syntax (U Symbol List) -> (values Syntax Syntax) +;; local expansion for with-type expressions +(define (wt-expand stx ctx) + (syntax-parse (local-expand/capture* stx ctx null) + #:literal-sets (kernel-literals) + [(begin (define-values (x ...) e ...) ... (let-values () . body)) + (values (disarm* #'(begin (define-values (x ...) e ...) ...)) + (disarm* (local-expand/capture* #'(let-values () . body) ctx null)))])) + (define (wt-core stx) (define-syntax-class typed-id #:description "[id type]" diff --git a/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/succeed/with-type-lift.rkt b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/succeed/with-type-lift.rkt new file mode 100644 index 00000000..8f2d2c3a --- /dev/null +++ b/pkgs/typed-racket-pkgs/typed-racket-test/tests/typed-racket/succeed/with-type-lift.rkt @@ -0,0 +1,18 @@ +#lang racket/base + +;; Test syntax lifting in `with-type` + +(require rackunit typed/racket) + +(with-type #:result Number + (define-syntax (m stx) + (syntax-local-lift-expression #'(+ 1 2))) + (m)) + +(define-syntax (m2 stx) + (syntax-local-lift-expression #'(+ 1 2))) + +(with-type #:result Number (m2)) + +(with-type ([val Number]) (define val (m2))) +(check-equal? val 3)