#lang typed/racket/base

;; Constant-folding math operators.
;; Where possible, they simplify their arguments.

(provide
  +: -: *: /:
  ;; Same signature as the racket/base operators,
  ;;  but try to simplify arguments during expansion.
)

(require (for-syntax
  typed/racket/base
  (only-in racket/format ~a)
  (only-in racket/syntax format-id)
  syntax/id-table
  syntax/parse
  trivial/private/common
))

;; =============================================================================

;; (make-numeric-operator f)
;; Define a macro named `f:` (the name of `f` with a ":" suffix) that behaves
;;  like `f`, but tries to fold constant arguments while it is being expanded.
;; - `(f: e ...)` : pass every argument through `expand-expr`, simplify the
;;   results with `reduce/op`, and emit either the folded number or a call
;;   to `f` on the simplified arguments.
;; - `f:` as a bare identifier : emit a reference to `f` itself.
;; Generated syntax takes its source location from the `f:` identifier at the
;;  use site.
(define-syntax make-numeric-operator
  (syntax-parser
   [(_ f:id)
    #:with f: (format-id #'f "~a:" (syntax-e #'f))
    #'(define-syntax f:
        (syntax-parser
         ;; Application. This clause matches every parenthesized use, so no
         ;;  further application clause can ever be reached after it.
         [(g e* (... ...))
          ;; NOTE(review): `expand-expr` comes from trivial/private/common;
          ;;  presumably it expands `e` far enough to expose constants.
          #:with e+* (for/list ([e (in-list (syntax->list #'(e* (... ...))))])
                       (expand-expr e))
          (let ([e++ (reduce/op f (syntax->list #'e+*))])
            (if (list? e++)
              ;; Some arguments remain: splice them into a call to `f`
              (quasisyntax/loc #'g (f #,@e++))
              ;; Everything folded: the result is a literal
              (quasisyntax/loc #'g #,e++)))]
         ;; Bare identifier, e.g. a higher-order use like `(map +: xs ys)`
         [g:id
          (syntax/loc #'g f)]))]))

;; Instantiate the four exported operators: `+:`, `-:`, `*:` and `/:`.
;; Each call defines the macro whose name is the given operator plus ":".
(make-numeric-operator +)
(make-numeric-operator -)
(make-numeric-operator *)
(make-numeric-operator /)

;; -----------------------------------------------------------------------------

;; Simplify the argument list of a variadic arithmetic operator by folding
;;  runs of adjacent constant arguments. Arguments are never reordered.
;;
;; `op` is the phase-1 procedure (`+`, `-`, `*` or `/`) and `e*` is the list
;;  of argument syntax objects. A constant is any argument for which
;;  `quoted-stx-value?` returns a number.
;;
;; Return either:
;; - A numeric value, when every argument was a constant and could be folded
;; - A list of syntax objects and numbers, to be spliced back in the source
;;   code as the arguments of `op`. The list is empty when `e*` is empty, so
;;   that the caller emits `(op)` and gets the operator's own zero-argument
;;   behavior (identity element or arity error).
;;
;; `-` and `/` are not associative, so only the leading run of constants is
;;  folded with `op` itself:
;;    (- 5 2 x)   ==> (- 3 x)
;;  Constants that appear after the first argument are all subtrahends
;;  (resp. divisors), so a run of them is merged with `+` (resp. `*`):
;;    (- x 2 3)   ==> (- x 5)
;;    (/ x 2 3)   ==> (/ x 6)
;; A zero divisor is never folded; it is left in place so the division
;;  behaves at run time exactly as the unsimplified code would.
(define-for-syntax (reduce/op op e*)
  ;; Operator for merging constants that come after the first argument
  (define tail-op
    (cond
     [(eq? - op) +]
     [(eq? / op) *]
     [else op]))
  ;; True if folding `v` in as a divisor would divide by zero
  (define (zero-divisor? v)
    (and (eq? / op) (zero? v)))
  (cond
   [(null? e*)
    ;; No arguments: nothing to fold, let `(op)` decide what happens
    '()]
   [(null? (cdr e*))
    ;; Unary use: `(- v)` negates and `(/ v)` inverts, so apply `op` rather
    ;;  than returning the constant unchanged
    (let ([v (quoted-stx-value? (car e*))])
      (if (and (number? v) (not (zero-divisor? v)))
        (op v)
        e*))]
   [else
    (let loop ([prev #f]   ;; (U #f Number), candidate for reduction
               [acc  '()]  ;; reversed list of arguments already emitted
               [e*   e*])  ;; (Listof Syntax), arguments to process
      (if (null? e*)
        ;; then: finished, return a number (prev) or list of arguments
        (if (null? acc)
          prev
          (reverse (if prev (cons prev acc) acc)))
        ;; else: pop the next argument from e*, fold if it's a constant
        (let ([v (quoted-stx-value? (car e*))])
          (if (number? v)
            (cond
             [(not prev)
              ;; First constant of a new run
              (loop v acc (cdr e*))]
             [(zero-divisor? v)
              ;; Watch for division-by-zero: keep both numbers unfolded
              (loop v (cons prev acc) (cdr e*))]
             [(null? acc)
              ;; `prev` holds the fold of every argument so far, i.e. it is
              ;;  the first operand: combine with `op` itself
              (loop (op prev v) acc (cdr e*))]
             [else
              ;; `prev` follows an emitted argument: combine with `tail-op`
              (loop (tail-op prev v) acc (cdr e*))])
            ;; Not a constant: flush `prev` (if any) and save the argument
            (let ([acc+ (cons (car e*) (if prev (cons prev acc) acc))])
              (loop #f acc+ (cdr e*)))))))]))