From 71045e255c0cc8d8f5db08f8cc7931ad21f0c496 Mon Sep 17 00:00:00 2001 From: Matthew Flatt Date: Wed, 25 Jul 2012 20:05:07 -0600 Subject: [PATCH] fix `for/vector' to really stop at size Avoids an infinite loop for an infinite sequence, for example. Also, expand via `for/fold/derived' to improve error reporting. --- collects/racket/private/for.rkt | 77 ++++++++++++++----- collects/racket/private/vector-wraps.rkt | 97 ++++++++++++++++-------- collects/tests/racket/flonum.rktl | 12 +++ collects/tests/racket/for.rktl | 34 +++++++++ 4 files changed, 167 insertions(+), 53 deletions(-) diff --git a/collects/racket/private/for.rkt b/collects/racket/private/for.rkt index 01395110d5..cd12576c58 100644 --- a/collects/racket/private/for.rkt +++ b/collects/racket/private/for.rkt @@ -1360,38 +1360,73 @@ (define-syntax (for/vector stx) (syntax-case stx () [(for/vector (for-clause ...) body ...) - (syntax/loc stx - (list->vector - (for/list (for-clause ...) body ...)))] + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (list->vector + (reverse + (for/fold/derived + orig-stx + ([l null]) + (for-clause ...) + (cons (let () body ...) l))))))] [(for/vector #:length length-expr (for-clause ...) body ...) - (syntax/loc stx - (let ([len length-expr]) - (unless (exact-nonnegative-integer? len) - (raise-argument-error 'for/vector "exact-nonnegative-integer?" len)) - (let ([v (make-vector len)]) - (for/fold ([i 0]) - (for-clause ... #:when (< i len)) - (vector-set! v i (let () body ...)) - (add1 i)) - v)))])) + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (let ([len length-expr]) + (unless (exact-nonnegative-integer? len) + (raise-argument-error 'for/vector "exact-nonnegative-integer?" len)) + (let ([v (make-vector len)]) + (unless (zero? len) + (let ([len-1 (sub1 len)]) + (for/fold/derived + orig-stx + ([vd (void)]) + ([i (stop-after (*in-naturals) (lambda (i) (= i len-1)))] + for-clause ...) + (vector-set! v i (let () body ...)) + (void)))) + v))))])) (define-syntax (for*/vector stx) (syntax-case stx () [(for*/vector (for-clause ...) body ...) - (syntax/loc stx - (list->vector - (for*/list (for-clause ...) body ...)))] + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (list->vector + (reverse + (for*/fold/derived + orig-stx + ([l null]) + (for-clause ...) + (cons (let () body ...) l))))))] [(for*/vector #:length length-expr (for-clause ...) body ...) + (with-syntax ([orig-stx stx] + [(limited-for-clause ...) + (map (lambda (fc) + (syntax-case fc () + [[ids rhs] + (or (identifier? #'ids) + (let ([l (syntax->list #'ids)]) + (and l (andmap identifier? l)))) + (syntax/loc fc [ids (stop-after + rhs + (lambda x + (= i len)))])] + [_ fc])) + (syntax->list #'(for-clause ...)))]) (syntax/loc stx (let ([len length-expr]) (unless (exact-nonnegative-integer? len) (raise-argument-error 'for*/vector "exact-nonnegative-integer?" len)) (let ([v (make-vector len)]) - (for*/fold ([i 0]) - (for-clause ... #:when (< i len)) - (vector-set! v i (let () body ...)) - (add1 i)) - v)))])) + (unless (zero? len) + (for*/fold/derived + orig-stx + ([i 0]) + (limited-for-clause ...) + (vector-set! v i (let () body ...)) + (add1 i))) + v))))])) (define-for-syntax (do-for/lists for/fold-id stx) (syntax-case stx () diff --git a/collects/racket/private/vector-wraps.rkt b/collects/racket/private/vector-wraps.rkt index 79e716f58f..295f7177d6 100644 --- a/collects/racket/private/vector-wraps.rkt +++ b/collects/racket/private/vector-wraps.rkt @@ -41,41 +41,74 @@ (define-syntax (for/fXvector stx) (syntax-case stx () - ((for/fXvector (for-clause ...) body ...) - (syntax/loc stx - (list->fXvector - (for/list (for-clause ...) body ...)))) - ((for/fXvector #:length length-expr (for-clause ...) body ...) - (syntax/loc stx - (let ((len length-expr)) - (unless (exact-nonnegative-integer? len) - (raise-argument-error 'for/fXvector "exact-nonnegative-integer?" len)) - (let ((v (make-fXvector len))) - (for/fold ((i 0)) - (for-clause ... - #:when (< i len)) - (fXvector-set! v i (begin body ...)) - (add1 i)) - v)))))) + [(for/fXvector (for-clause ...) body ...) + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (list->fXvector + (reverse + (for/fold/derived + orig-stx + ([l null]) + (for-clause ...) + (cons (let () body ...) l))))))] + [(for/fXvector #:length length-expr (for-clause ...) body ...) + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (let ([len length-expr]) + (unless (exact-nonnegative-integer? len) + (raise-argument-error 'for/fXvector "exact-nonnegative-integer?" len)) + (let ([v (make-fXvector len)]) + (unless (zero? len) + (let ([len-1 (sub1 len)]) + (for/fold/derived + orig-stx + ([vd (void)]) + ([i (stop-after (in-naturals) (lambda (i) (= i len-1)))] + for-clause ...) + (fXvector-set! v i (let () body ...)) + (void)))) + v))))])) (define-syntax (for*/fXvector stx) (syntax-case stx () - ((for*/fXvector (for-clause ...) body ...) - (syntax/loc stx - (list->fXvector - (for*/list (for-clause ...) body ...)))) - ((for*/fXvector #:length length-expr (for-clause ...) body ...) - (syntax/loc stx - (let ((len length-expr)) - (unless (exact-nonnegative-integer? len) - (raise-argument-error 'for*/fXvector "exact-nonnegative-integer?" len)) - (let ((v (make-fXvector len))) - (for*/fold ((i 0)) - (for-clause ... - #:when (< i len)) - (fXvector-set! v i (begin body ...)) - (add1 i)) - v)))))) + [(for*/fXvector (for-clause ...) body ...) + (with-syntax ([orig-stx stx]) + (syntax/loc stx + (list->fXvector + (reverse + (for*/fold/derived + orig-stx + ([l null]) + (for-clause ...) + (cons (let () body ...) l))))))] + [(for*/fXvector #:length length-expr (for-clause ...) body ...) + (with-syntax ([orig-stx stx] + [(limited-for-clause ...) + (map (lambda (fc) + (syntax-case fc () + [[ids rhs] + (or (identifier? #'ids) + (let ([l (syntax->list #'ids)]) + (and l (andmap identifier? l)))) + (syntax/loc fc [ids (stop-after + rhs + (lambda x + (= i len)))])] + [_ fc])) + (syntax->list #'(for-clause ...)))]) + (syntax/loc stx + (let ([len length-expr]) + (unless (exact-nonnegative-integer? len) + (raise-argument-error 'for*/fXvector "exact-nonnegative-integer?" len)) + (let ([v (make-fXvector len)]) + (unless (zero? len) + (for*/fold/derived + orig-stx + ([i 0]) + (limited-for-clause ...) + (fXvector-set! v i (let () body ...)) + (add1 i))) + v))))])) (define (fXvector-copy flv [start 0] [end (and (fXvector? flv) (fXvector-length flv))]) (unless (fXvector? flv) diff --git a/collects/tests/racket/flonum.rktl b/collects/tests/racket/flonum.rktl index a23ca6d04e..5a61689fed 100644 --- a/collects/tests/racket/flonum.rktl +++ b/collects/tests/racket/flonum.rktl @@ -32,6 +32,18 @@ (test flv 'for*/flvector flv1) (test flv 'for*/flvector-fast flv2)) +;; Stop when a length is specified, even if the sequence continues: +(test (flvector 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0) + 'nat + (for/flvector #:length 10 ([i (in-naturals)]) (exact->inexact i))) +(test (flvector 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0) + 'nats + (for*/flvector #:length 10 ([i (in-naturals)] [j (in-naturals)]) (exact->inexact j))) +(test (flvector 0.0 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 1.0) + 'nat+5 + (for*/flvector #:length 10 ([i (in-naturals)] [j (in-range 5)]) (exact->inexact i))) + + ;; Test for both length too long and length too short (let ((v (make-flvector 3))) (flvector-set! v 0 0.0) diff --git a/collects/tests/racket/for.rktl b/collects/tests/racket/for.rktl index 691e309d5a..a3cd9ddb1f 100644 --- a/collects/tests/racket/for.rktl +++ b/collects/tests/racket/for.rktl @@ -199,6 +199,40 @@ (test (vector 2.0 3.0 4.0) 'for/vector-many-body v2) (test (vector 3.0 4.0 5.0) 'for/vector-length-many-body v3)) +;; Stop when a length is specified, even if the sequence continues: +(test '#(0 1 2 3 4 5 6 7 8 9) + 'nat + (for/vector #:length 10 ([i (in-naturals)]) i)) +(test '#((0 . 0) (1 . 0) (2 . 0) (3 . 0) (4 . 0) (5 . 0) (6 . 0) (7 . 0) (8 . 0) (9 . 0)) + 'nats + (for*/vector #:length 10 ([i (in-naturals)] [j (in-naturals)]) (cons j i))) +(test '#((0 . 0) (1 . 0) (2 . 0) (3 . 0) (4 . 0) (0 . 1) (1 . 1) (2 . 1) (3 . 1) (4 . 1)) + 'nat+5 + (for*/vector #:length 10 ([i (in-naturals)] [j (in-range 5)]) (cons j i))) +(test '#(1 3 5 7 9 11 13 15 17 19) + 'parallel + (for*/vector #:length 10 ([(i j) (in-parallel (in-naturals) + (in-naturals 1))]) + (+ i j))) + +;; Make sure the sequence stops at the length before consuming another element: +(test '(#("1" "2" "3" "4" "5" "6" "7" "8" "9" "10") . 10) + 'producer + (let ([c 0]) + (cons + (for/vector #:length 10 ([i (in-producer (lambda () (set! c (add1 c)) c) #f)]) + (number->string i)) + c))) +(test '(#("1" "2" "3" "4" "5" "6" "7" "8" "9" "10") . 10) + 'producer + (let ([c 0]) + (cons + (for*/vector #:length 10 ([j '(0)] + [i (in-producer (lambda () (set! c (add1 c)) c) #f)]) + (number->string i)) + c))) + + (test #hash((a . 1) (b . 2) (c . 3)) 'mk-hash (for/hash ([v (in-naturals)] [k '(a b c)])