diff --git a/collects/typed-racket/base-env/prims.rkt b/collects/typed-racket/base-env/prims.rkt index dd4b29ff..6f42fbba 100644 --- a/collects/typed-racket/base-env/prims.rkt +++ b/collects/typed-racket/base-env/prims.rkt @@ -642,7 +642,6 @@ This file defines two sorts of primitives. All of them are provided into any mod (for/last: for/last) (for/vector: for/vector) (for/flvector: for/flvector) - (for/sum: for/sum) (for/product: for/product)) ;; Unlike with the above, the inferencer can handle any number of #:when @@ -694,6 +693,7 @@ This file defines two sorts of primitives. All of them are provided into any mod 'type-ascription #'(values var.ty ...))])) + (define-syntax (for*: stx) (syntax-parse stx #:literals (:) [(_ (~seq : Void) ... @@ -733,7 +733,6 @@ This file defines two sorts of primitives. All of them are provided into any mod (for*/last: for*/last) (for*/vector: for*/vector) (for*/flvector: for*/flvector) - (for*/sum: for*/sum) (for*/product: for*/product)) ;; Like for/lists: and for/fold:, the inferencer can handle these correctly. @@ -784,6 +783,32 @@ This file defines two sorts of primitives. All of them are provided into any mod 'type-ascription #'(values var.ty ...))])) + +(define-for-syntax (define-for/sum:-variant for/folder) + (lambda (stx) + (syntax-parse stx #:literals (:) + [(_ : ty + (clause:for-clause ...) + c:expr ...) + ;; ty has to include exact 0, the initial value of the accumulator + ;; (to be consistent with Racket semantics). + ;; We can't just change the initial value to be 0.0 if we expect a + ;; Float result. This is problematic in some cases e.g: + ;; (for/sum: : Float ([i : Float '(1.1)] #:when (zero? (random 1))) i) + (quasisyntax/loc stx + (#,for/folder : ty ([acc : ty 0]) + (clause.expand ... ...) + (let ([new (let () c ...)]) + (+ acc new))))]))) +(define-syntax (define-for/sum:-variants stx) + (syntax-parse stx + [(_ (name for/folder) ...) + (quasisyntax/loc stx + (begin (define-syntax name (define-for/sum:-variant #'for/folder)) + ...))])) +(define-for/sum:-variants (for/sum: for/fold:) (for*/sum: for*/fold:)) + + (define-syntax (provide: stx) (syntax-parse stx [(_ [i:id t] ...)