for/set: same body handling as for/list, etc.

Change `for/set` to use `split-for-body`. Also, adjust the
documentation of `for/fold/derived` to recommend using
`split-for-body`.

Closes #3351
This commit is contained in:
Matthew Flatt 2020-08-16 16:29:44 -06:00
parent 413106413d
commit 7d8a95a943
3 changed files with 76 additions and 17 deletions

View File

@ -1,5 +1,6 @@
#lang scribble/doc
@(require "mz.rkt")
@(require "mz.rkt"
(for-label syntax/for-body))
@title[#:tag "for"]{Iterations and Comprehensions: @racket[for], @racket[for/list], ...}
@ -560,18 +561,25 @@ Like @racket[for/list], etc., but with the implicit nesting of
Like @racket[for/fold], but the extra @racket[orig-datum] is used as the
source for all syntax errors.
A macro that expands to @racket[for/fold/derived] should typically use
@racket[split-for-body] to handle the possibility of macros and other
definitions mixed with keywords like @racket[#:break].
@mz-examples[#:eval for-eval
(require (for-syntax syntax/for-body))
(define-syntax (for/digits stx)
(syntax-case stx ()
[(_ clauses body ... tail-expr)
(with-syntax ([original stx])
(with-syntax ([original stx]
[((pre-body ...) (post-body ...))
(split-for-body stx #'(body ... tail-expr))])
#'(let-values
([(n k)
(for/fold/derived
original ([n 0] [k 1])
clauses
body ...
(values (+ n (* tail-expr k)) (* k 10)))])
pre-body ...
(values (+ n (* (let () post-body ...) k)) (* k 10)))])
n))]))
@code:comment{If we misuse for/digits, we can get good error reporting}
@ -592,12 +600,14 @@ source for all syntax errors.
(define-syntax (for/max stx)
(syntax-case stx ()
[(_ clauses body ... tail-expr)
(with-syntax ([original stx])
(with-syntax ([original stx]
[((pre-body ...) (post-body ...))
(split-for-body stx #'(body ... tail-expr))])
#'(for/fold/derived original
([current-max -inf.0])
clauses
body ...
(define maybe-new-max tail-expr)
pre-body ...
(define maybe-new-max (let () post-body ...))
(if (> maybe-new-max current-max)
maybe-new-max
current-max)))]))
@ -614,16 +624,19 @@ source for all syntax errors.
Like @racket[for*/fold], but the extra @racket[orig-datum] is used as the source for all syntax errors.
@mz-examples[#:eval for-eval
(require (for-syntax syntax/for-body))
(define-syntax (for*/digits stx)
(syntax-case stx ()
[(_ clauses body ... tail-expr)
(with-syntax ([original stx])
(with-syntax ([original stx]
[((pre-body ...) (post-body ...))
(split-for-body stx #'(body ... tail-expr))])
#'(let-values
([(n k)
(for*/fold/derived original ([n 0] [k 1])
clauses
body ...
(values (+ n (* tail-expr k)) (* k 10)))])
pre-body ...
(values (+ n (* (let () post-body ...) k)) (* k 10)))])
n))]))
(eval:error

View File

@ -563,6 +563,50 @@
#:final (= i 2)
(add1 i)))
;; ----------------------------------------
(test (set 0) 'non-expression-last-form
(for/set ([x '(1)])
(begin
(define-syntax (m stx) #'0)
m)))
(test (set 10) 'non-expression-last-form
(for/set ([x '(1)])
(define (f x) (g x))
(define-syntax-rule (m g)
(begin
(define (g x) 10)
(f 1)))
(m g)))
(test (mutable-set 0) 'non-expression-last-form
(for/mutable-set ([x '(1)])
(begin
(define-syntax (m stx) #'0)
m)))
(test (set 0) 'non-expression-last-form
(for*/set ([x '(1)])
(begin
(define-syntax (m stx) #'0)
m)))
(test (set 10) 'non-expression-last-form
(for*/set ([x '(1)])
(define (f x) (g x))
(define-syntax-rule (m g)
(begin
(define (g x) 10)
(f 1)))
(m g)))
(test (mutable-set 0) 'non-expression-last-form
(for*/mutable-set ([x '(1)])
(begin
(define-syntax (m stx) #'0)
m)))
;; ----------------------------------------
;; chaperone-hash-set tests

View File

@ -8,7 +8,7 @@
racket/unsafe/ops
(only-in racket/syntax format-symbol)
(only-in racket/generic exn:fail:support)
(for-syntax racket/base racket/syntax))
(for-syntax racket/base racket/syntax syntax/for-body))
(provide set seteq seteqv
weak-set weak-seteq weak-seteqv
@ -1045,13 +1045,14 @@
(lambda (stx)
(syntax-case stx ()
[(form clauses body ... expr)
(with-syntax ([original stx])
(with-syntax ([original stx]
[((pre-body ...) (post-body ...)) (split-for-body stx #'(body ... expr))])
(syntax-protect
#'(immutable-custom-set
(begin0 #f (dprintf "~a\n" 'form))
(for_/fold/derived original ([table (make-table)]) clauses
body ...
(hash-set table expr #t)))))]))))
pre-body ...
(hash-set table (let () post-body ...) #t)))))]))))
(define (immutable-fors table-id)
(values (immutable-for #'for/fold/derived table-id)
@ -1064,13 +1065,14 @@
(lambda (stx)
(syntax-case stx ()
[(form clauses body ... expr)
(with-syntax ([original stx])
(with-syntax ([original stx]
[((pre-body ...) (post-body ...)) (split-for-body stx #'(body ... expr))])
(syntax-protect
#'(let ([table (make-table)])
(dprintf "~a\n" 'form)
(for_/fold/derived original () clauses
body ...
(hash-set! table expr #t)
pre-body ...
(hash-set! table (let () post-body ...) #t)
(values))
(make-set #f table))))]))))