Allow type annotations after for clauses in the for macros.

Cf. GH issue #751.

original commit: 8de77f596f6337531656aac8e5aed980d4d51267
This commit is contained in:
Vincent St-Amour 2014-08-21 16:44:09 -04:00
parent 77c61e5eaf
commit 38c64fcdc6
2 changed files with 295 additions and 114 deletions

View File

@ -23,7 +23,7 @@ This file defines two sorts of primitives. All of them are provided into any mod
(provide (except-out (all-defined-out) dtsi* dtsi/exec* -let-internal define-for-variants define-for*-variants
with-handlers: for/annotation for*/annotation define-for/acc:-variants base-for/flvector: base-for/vector
with-handlers: define-for/acc:-variants base-for/flvector: base-for/vector
-lambda -define -do -let -let* -let*-values -let-values -let/cc -let/ec -letrec -letrec-values -struct)
;; provide the contracted bindings as primitives
(all-from-out "base-contracted.rkt")
@ -844,18 +844,6 @@ This file defines two sorts of primitives. All of them are provided into any mod
(ann #,do-stx #,(attribute ty)))
do-stx)]))
;; wrap the original for with a type annotation
(define-syntax (for/annotation stx)
(syntax-parse stx
[(_ x ...)
(quasisyntax/loc stx
(ann #,(syntax/loc stx (for x ...)) Void))]))
(define-syntax (for*/annotation stx)
(syntax-parse stx
[(_ x ...)
(syntax/loc stx
(ann (for* x ...) Void))]))
;; we need handle #:when clauses manually because we need to annotate
;; the type of each nested for
(define-syntax (for: stx)
@ -863,8 +851,10 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; the annotation is not necessary (always of Void type), but kept
;; for consistency with the other for: macros
[(_ (~optional (~seq : Void))
clauses
(~optional (~seq : Void)) ; can be either before or after
;; c is not always an expression, could be a break-clause
clauses c ...) ; no need to annotate the type, it's always Void
c ...) ; no need to annotate the type, it's always Void
(let ((body #'(; break-clause ...
c ...)))
(let loop ((clauses #'clauses))
@ -926,15 +916,17 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; If the only #:when clause is the last clause, inference should work.
(define-for-syntax (define-for-variant name)
(lambda (stx)
(syntax-parse stx #:literals (:)
[(_ a:optional-standalone-annotation*
(syntax-parse stx
[(_ a1:optional-standalone-annotation*
clause:for-clauses
a2:optional-standalone-annotation* ; can be either before or after
c ...) ; c is not always an expression, can be a break-clause
((attribute a.annotate)
(quasisyntax/loc stx
(#,name
(clause.expand ... ...)
c ...)))])))
((attribute a1.annotate)
((attribute a2.annotate)
(quasisyntax/loc stx
(#,name
(clause.expand ... ...)
c ...))))])))
(define-syntax (define-for-variants stx)
(syntax-parse stx
@ -955,19 +947,11 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; Unlike with the above, the inferencer can handle any number of #:when
;; clauses with these 2.
(define-syntax (for/lists: stx)
(syntax-parse stx #:literals (:)
[(_ : ty
(syntax-parse stx
[(_ a1:optional-standalone-annotation*
(var:optionally-annotated-formal ...)
clause:for-clauses
c ...) ; c is not always an expression, can be a break-clause
(add-ann
(quasisyntax/loc stx
(for/lists (var.ann-name ...)
(clause.expand ... ...)
c ...))
#'ty)]
[(_ (var:optionally-annotated-formal ...)
clause:for-clauses
a2:optional-standalone-annotation*
c ...)
(define all-typed? (andmap values (attribute var.ty)))
(define for-stx
@ -975,25 +959,19 @@ This file defines two sorts of primitives. All of them are provided into any mod
(for/lists (var.ann-name ...)
(clause.expand ... ...)
c ...)))
(if all-typed?
(add-ann
for-stx
#'(values var.ty ...))
for-stx)]))
((attribute a1.annotate)
((attribute a2.annotate)
(if all-typed?
(add-ann
for-stx
#'(values var.ty ...))
for-stx)))]))
(define-syntax (for/fold: stx)
(syntax-parse stx #:literals (:)
[(_ : ty
((var:optionally-annotated-name init:expr) ...)
clause:for-clauses
c ...) ; c is not always an expression, can be a break-clause
(add-ann
(quasisyntax/loc stx
(for/fold ((var.ann-name init) ...)
(clause.expand ... ...)
c ...))
#'ty)]
[(_ accum:accumulator-bindings
(syntax-parse stx
[(_ a1:optional-standalone-annotation*
accum:accumulator-bindings
clause:for-clauses
a2:optional-standalone-annotation*
c ...)
(define all-typed? (andmap values (attribute accum.ty)))
(define for-stx
@ -1001,16 +979,19 @@ This file defines two sorts of primitives. All of them are provided into any mod
(for/fold ((accum.ann-name accum.init) ...)
(clause.expand ... ...)
c ...)))
(if all-typed?
(add-ann
for-stx
#'(values accum.ty ...))
for-stx)]))
((attribute a1.annotate)
((attribute a2.annotate)
(if all-typed?
(add-ann
for-stx
#'(values accum.ty ...))
for-stx)))]))
(define-syntax (for*: stx)
(syntax-parse stx #:literals (:)
[(_ (~seq : Void) ...
(syntax-parse stx #:literals (: Void)
[(_ (~optional (~seq : Void))
clause:for-clauses
(~optional (~seq : Void))
c ...) ; c is not always an expression, can be a break-clause
(quasisyntax/loc stx
(for: (clause.expand* ... ...)
@ -1019,14 +1000,16 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; These currently only typecheck in very limited cases.
(define-for-syntax (define-for*-variant name)
(lambda (stx)
(syntax-parse stx #:literals (:)
[(_ a:optional-standalone-annotation*
(syntax-parse stx
[(_ a1:optional-standalone-annotation*
clause:for-clauses
a2:optional-standalone-annotation*
c ...) ; c is not always an expression, can be a break-clause
((attribute a.annotate)
(quasisyntax/loc stx
(#,name (clause.expand ... ...)
c ...)))])))
((attribute a1.annotate)
((attribute a2.annotate)
(quasisyntax/loc stx
(#,name (clause.expand ... ...)
c ...))))])))
(define-syntax (define-for*-variants stx)
(syntax-parse stx
[(_ (name no-colon-name) ...)
@ -1042,56 +1025,47 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; Like for/lists: and for/fold:, the inferencer can handle these correctly.
(define-syntax (for*/lists: stx)
(syntax-parse stx #:literals (:)
[(_ : ty
(syntax-parse stx
[(_ a1:optional-standalone-annotation*
((var:optionally-annotated-name) ...)
clause:for-clauses
c ...) ; c is not always an expression, can be a break-clause
(add-ann
(quasisyntax/loc stx
(for/lists (var.ann-name ...)
(clause.expand* ... ...)
c ...))
#'ty)]
[(_ ((var:annotated-name) ...)
clause:for-clauses
a2:optional-standalone-annotation*
c ...)
(add-ann
(quasisyntax/loc stx
(for/lists (var.ann-name ...)
(clause.expand* ... ...)
c ...))
#'(values var.ty ...))]))
((attribute a1.annotate)
((attribute a2.annotate)
(add-ann
(quasisyntax/loc stx
(for/lists (var.ann-name ...)
(clause.expand* ... ...)
c ...))
#'(values var.ty ...))))]))
(define-syntax (for*/fold: stx)
(syntax-parse stx #:literals (:)
[(_ : ty
[(_ a1:optional-standalone-annotation*
((var:optionally-annotated-name init:expr) ...)
clause:for-clauses
c ...) ; c is not always an expression, can be a break-clause
(add-ann
(quasisyntax/loc stx
(for/fold ((var.ann-name init) ...)
(clause.expand* ... ...)
c ...))
#'ty)]
[(_ ((var:annotated-name init:expr) ...)
clause:for-clauses
a2:optional-standalone-annotation*
c ...)
(add-ann
(quasisyntax/loc stx
(for/fold ((var.ann-name init) ...)
(clause.expand* ... ...)
c ...))
#'(values var.ty ...))]))
((attribute a1.annotate)
((attribute a2.annotate)
(add-ann
(quasisyntax/loc stx
(for/fold ((var.ann-name init) ...)
(clause.expand* ... ...)
c ...))
#'(values var.ty ...))))]))
(define-for-syntax (define-for/acc:-variant for*? for/folder: for/folder op initial final)
(lambda (stx)
(syntax-parse stx #:literals (:)
[(_ a:optional-standalone-annotation*
[(_ a1:optional-standalone-annotation*
clause:for-clauses
a2:optional-standalone-annotation*
c ...) ; c is not always an expression, can be a break-clause
(define a.ty (or (attribute a2.ty)
(attribute a1.ty)))
(cond
[(attribute a.ty)
[a.ty
;; ty has to include exact 0, exact 1, null (sum/product/list respectively),
;; the initial value of the accumulator
;; (to be consistent with Racket semantics).
@ -1100,7 +1074,7 @@ This file defines two sorts of primitives. All of them are provided into any mod
;; (for/sum: : Float ([i : Float '(1.1)] #:when (zero? (random 1))) i)
(quasisyntax/loc stx
(#,final
(#,for/folder: : a.ty ([acc : a.ty #,initial])
(#,for/folder: : #,a.ty ([acc : #,a.ty #,initial])
(clause.expand ... ...)
(let ([new (let () c ...)])
(#,op acc new)))))]
@ -1128,20 +1102,21 @@ This file defines two sorts of primitives. All of them are provided into any mod
(define-for-syntax (define-for/hash:-variant hash-maker)
(lambda (stx)
(syntax-parse stx
#:literals (:)
[(_ (~seq : return-annotation:expr)
[(_ a1:optional-standalone-annotation*
clause:for-clauses
a2:optional-standalone-annotation*
body ...) ; body is not always an expression, can be a break-clause
(quasisyntax/loc stx
(for/fold: : return-annotation
((return-hash : return-annotation (ann (#,hash-maker null) return-annotation)))
(clause.expand ... ...)
(let-values (((key val) (let () body ...)))
(hash-set return-hash key val))))]
[(_ clause:for-clauses body ...)
(syntax/loc stx
(for/hash (clause.expand ... ...)
body ...))])))
(define a.ty (or (attribute a2.ty) (attribute a1.ty)))
(if a.ty
(quasisyntax/loc stx
(for/fold: : #,a.ty
((return-hash : #,a.ty (ann (#,hash-maker null) #,a.ty)))
(clause.expand ... ...)
(let-values (((key val) (let () body ...)))
(hash-set return-hash key val))))
(syntax/loc stx
(for/hash (clause.expand ... ...)
body ...)))])))
(define-syntax for/hash: (define-for/hash:-variant #'make-immutable-hash))
(define-syntax for/hasheq: (define-for/hash:-variant #'make-immutable-hasheq))
@ -1151,15 +1126,19 @@ This file defines two sorts of primitives. All of them are provided into any mod
(lambda (stx)
(syntax-parse stx
#:literals (:)
((_ (~seq : return-annotation:expr)
[(_ a1:optional-standalone-annotation*
clause:for-clauses
a2:optional-standalone-annotation*
body ...) ; body is not always an expression, can be a break-clause
(define a.ty (or (attribute a2.ty) (attribute a1.ty)))
(quasisyntax/loc stx
(for*/fold: : return-annotation
((return-hash : return-annotation (ann (#,hash-maker null) return-annotation)))
(for*/fold: #,@(if a.ty #`(: #,a.ty) #'())
#,(if a.ty
#`((return-hash : #,a.ty (ann (#,hash-maker null) #,a.ty)))
#`((return-hash (#,hash-maker null))))
(clause.expand* ... ...)
(let-values (((key val) (let () body ...)))
(hash-set return-hash key val))))))))
(hash-set return-hash key val))))])))
(define-syntax for*/hash: (define-for*/hash:-variant #'make-immutable-hash))
(define-syntax for*/hasheq: (define-for*/hash:-variant #'make-immutable-hasheq))

View File

@ -7,12 +7,21 @@
#t
(error (format "Check (~a ~a ~a) failed" f a b))))
;; Each test is there twice, once with the type annotation before the for
;; clauses, and once after.
(check string=?
(with-output-to-string
(lambda ()
(for: : Void ([i : Integer (in-range 10)])
(display i))))
"0123456789")
(check string=?
(with-output-to-string
(lambda ()
(for: ([i : Integer (in-range 10)]) : Void
(display i))))
"0123456789")
(check string=?
(with-output-to-string
@ -25,10 +34,24 @@
#:when k)
(display (list i j k)))))
"(1 a #t)(1 a #t)(3 c #t)(3 c #t)")
(check string=?
(with-output-to-string
(lambda ()
(for: ((i : Integer '(1 2 3))
(j : Char "abc")
#:when (odd? i)
(k : Boolean #(#t #t))
#:when k)
: Void
(display (list i j k)))))
"(1 a #t)(1 a #t)(3 c #t)(3 c #t)")
(check equal?
(for/list: : (Listof Integer) ([i : Integer (in-range 10)]) i)
'(0 1 2 3 4 5 6 7 8 9))
(check equal?
(for/list: ([i : Integer (in-range 10)]) : (Listof Integer) i)
'(0 1 2 3 4 5 6 7 8 9))
(check equal?
(for/list: : (Listof Integer)
@ -37,6 +60,14 @@
#:when (odd? i))
(+ i j 10))
'(21 43))
(check equal?
(for/list: ((i : Integer '(1 2 3))
(j : Integer '(10 20 30))
#:when (odd? i))
: (Listof Integer)
(+ i j 10))
'(21 43))
(check equal?
(for/list: : (Listof Integer)
((i : Integer '(1 2 3))
@ -44,12 +75,24 @@
#:unless (odd? i))
(+ i j 10))
'(32))
(check equal?
(for/list: ((i : Integer '(1 2 3))
(j : Integer '(10 20 30))
#:unless (odd? i))
: (Listof Integer)
(+ i j 10))
'(32))
(check equal?
(for/or: : Boolean
((i : Integer '(1 2 3)))
(>= i 3))
#t)
(check equal?
(for/or: ((i : Integer '(1 2 3)))
: Boolean
(>= i 3))
#t)
(check equal?
(for/or: : Boolean
@ -57,6 +100,12 @@
(j : Integer '(2 1 3)))
(>= i j))
#t)
(check equal?
(for/or: ((i : Integer '(1 2 3))
(j : Integer '(2 1 3)))
: Boolean
(>= i j))
#t)
(check equal?
(let-values: ([([x : (Listof Integer)] [y : (Listof Integer)])
@ -70,6 +119,18 @@
(values i j))])
(append x y))
'(1 1 2 2 3 3 20 30 20 30 20 30))
(check equal?
(let-values: ([([x : (Listof Integer)] [y : (Listof Integer)])
(for/lists: ((x : (Listof Integer))
(y : (Listof Integer)))
((i : Integer '(1 2 3))
#:when #t
(j : Integer '(10 20 30))
#:when (> j 12))
: (values (Listof Integer) (Listof Integer))
(values i j))])
(append x y))
'(1 1 2 2 3 3 20 30 20 30 20 30))
(check =
(for/fold: : Integer
@ -78,6 +139,13 @@
(j : Integer '(10 20 30)))
(+ acc i j))
66)
(check =
(for/fold: ((acc : Integer 0))
((i : Integer '(1 2 3))
(j : Integer '(10 20 30)))
: Integer
(+ acc i j))
66)
(check =
(for/fold: : Integer
@ -89,6 +157,16 @@
(k : Integer '(100 200 300)))
(+ acc i j k))
1998)
(check =
(for/fold: ((acc : Integer 0))
((i : Integer '(1 2 3))
#:when (even? i)
(j : Integer '(10 20 30))
#:when #t
(k : Integer '(100 200 300)))
: Integer
(+ acc i j k))
1998)
(check string=?
(with-output-to-string
@ -98,6 +176,14 @@
(j : Integer '(10 20 30)))
(display (list i j)))))
"(1 10)(1 20)(1 30)(2 10)(2 20)(2 30)(3 10)(3 20)(3 30)")
(check string=?
(with-output-to-string
(lambda ()
(for*: ((i : Integer '(1 2 3))
(j : Integer '(10 20 30)))
: Void
(display (list i j)))))
"(1 10)(1 20)(1 30)(2 10)(2 20)(2 30)(3 10)(3 20)(3 30)")
(check equal?
(let-values: ([([x : (Listof Integer)] [y : (Listof Integer)])
@ -110,6 +196,17 @@
(values i j))])
(append x y))
'(1 1 2 2 3 3 20 30 20 30 20 30))
(check equal?
(let-values: ([([x : (Listof Integer)] [y : (Listof Integer)])
(for*/lists: ((x : (Listof Integer))
(y : (Listof Integer)))
((i : Integer '(1 2 3))
(j : Integer '(10 20 30))
#:when (> j 12))
: (values (Listof Integer) (Listof Integer))
(values i j))])
(append x y))
'(1 1 2 2 3 3 20 30 20 30 20 30))
(check =
(for*/fold: : Integer
@ -120,6 +217,16 @@
(k : Integer '(100 200 300)))
(+ acc i j k))
1998)
(check =
(for*/fold: ((acc : Integer 0))
((i : Integer '(1 2 3))
#:when (even? i)
(j : Integer '(10 20 30))
(k : Integer '(100 200 300)))
: Integer
(+ acc i j k))
1998)
(check =
(for*/fold: : Integer
((acc : Integer 0))
@ -129,35 +236,74 @@
(k : Integer '(100 200 300)))
(+ acc i j k))
3996)
(check =
(for*/fold: ((acc : Integer 0))
((i : Integer '(1 2 3))
#:unless (even? i)
(j : Integer '(10 20 30))
(k : Integer '(100 200 300)))
: Integer
(+ acc i j k))
3996)
(check =
(for/sum: : Integer
([i : Integer (in-range 10)])
i)
45)
(check =
(for/sum: ([i : Integer (in-range 10)])
: Integer
i)
45)
(check =
(for/sum: : Integer
([i : Integer (in-range 10)]
[j : Integer (in-range 10)])
(+ i j))
90)
(check =
(for/sum: ([i : Integer (in-range 10)]
[j : Integer (in-range 10)])
: Integer
(+ i j))
90)
(check =
(for/product: : Integer
([i : Integer (in-range 10)])
i)
0)
(check =
(for/product: ([i : Integer (in-range 10)])
: Integer
i)
0)
(check =
(for/product: : Integer
([i : Integer (in-range 1 10)])
i)
362880)
(check =
(for/product: ([i : Integer (in-range 1 10)])
: Integer
i)
362880)
(check =
(for/product: : Integer
([i : Integer (in-range 1 10)]
[j : Integer (in-range 1 10)])
(+ i j))
185794560)
(check =
(for/product: ([i : Integer (in-range 1 10)]
[j : Integer (in-range 1 10)])
: Integer
(+ i j))
185794560)
;; for/product: had problems with Real due to an unannotated accumulator
(check =
@ -165,6 +311,11 @@
([i (in-list (list 1.2 -1.0 0.5))])
i)
-0.6)
(check =
(for/product: ([i (in-list (list 1.2 -1.0 0.5))])
: Real
i)
-0.6)
;; multiclause versions of these don't currently work properly
(check =
@ -172,17 +323,33 @@
([i : Integer (in-range 10)])
i)
45)
(check =
(for*/sum: ([i : Integer (in-range 10)])
: Integer
i)
45)
(check =
(for*/product: : Integer
([i : Integer (in-range 10)])
i)
0)
(check =
(for*/product: ([i : Integer (in-range 10)])
: Integer
i)
0)
(check =
(for*/product: : Integer
([i : Integer (in-range 1 10)])
i)
362880)
(check =
(for*/product: ([i : Integer (in-range 1 10)])
: Integer
i)
362880)
;; Integers as sequences.
@ -191,21 +358,44 @@
([i : Byte 4])
i)
6)
(check =
(for/sum: ([i : Byte 4])
: Integer
i)
6)
(check =
(for/sum: : Integer
([i : Index (ann 4 Index)])
i)
6)
(check =
(for/sum: ([i : Index (ann 4 Index)])
: Integer
i)
6)
(check =
(for/sum: : Integer
([i : Nonnegative-Fixnum (ann 4 Fixnum)])
i)
6)
(check =
(for/sum: ([i : Nonnegative-Fixnum (ann 4 Fixnum)])
: Integer
i)
6)
(check =
(for/sum: : Integer
([i : Natural (ann 4 Integer)])
i)
6)
(check =
(for/sum: ([i : Natural (ann 4 Integer)])
: Integer
i)
6)
(check string=?
(with-output-to-string
@ -216,6 +406,9 @@
(check equal?
(for/hasheq: : (HashTable Integer String) ([k (list 2 3 4)]) (values k "val"))
#hasheq((2 . "val") (3 . "val") (4 . "val")))
(check equal?
(for/hasheq: ([k (list 2 3 4)]) : (HashTable Integer String) (values k "val"))
#hasheq((2 . "val") (3 . "val") (4 . "val")))
(check equal?
(for/vector: ([i : Natural (in-range 3)]) 5)
@ -224,6 +417,9 @@
(check equal?
(for/vector: : (Vectorof Number) ([i : Natural (in-range 3)]) 5)
(vector 5 5 5))
(check equal?
(for/vector: ([i : Natural (in-range 3)]) : Number 5)
(vector 5 5 5))
(check equal?
@ -232,6 +428,12 @@
(j : Natural (and (in-range 5))))
(+ i j))
(list 0 2 4 6 8))
(check equal?
(for/list: ((i : Natural (and (in-naturals)))
(j : Natural (and (in-range 5))))
: (Listof Natural)
(+ i j))
(list 0 2 4 6 8))
;; break and final clauses
;; TODO typechecker can't handle these