Initial implementation of 3x3 and 4x4 Toom-Cook multiplication.

This speeds up `(factorial 1000000)` (using factorial from math/number-theory)
by about 3x, and the conversion of the result to a string by about 2x.

Benchmark:

    #lang racket
    (require math/number-theory)
    (define n (time (factorial 1000000)))
    (define s (time (number->string n)))
    (string-length s)

Current Racket CS:
cpu time: 19135 real time: 19137 gc time: 372
cpu time: 33416 real time: 33418 gc time: 463

Current Racket BC (GMP is really fast):
cpu time: 1465 real time: 1465 gc time: 51
cpu time: 3661 real time: 3659 gc time: 3

This PR:
cpu time: 6173 real time: 6172 gc time: 168
cpu time: 17846 real time: 17847 gc time: 377

Cutoff between Karatsuba and Toom3 estimated by mflatt.
Cutoff between Toom3 & Toom4 guessed.
This commit is contained in:
Sam Tobin-Hochstadt 2021-03-26 16:30:13 -04:00 committed by Sam Tobin-Hochstadt
parent cf8570f59d
commit 50a2cb32cb
2 changed files with 112 additions and 6 deletions

View File

@ -1690,6 +1690,16 @@
(and (exact? x) (exact? y))
(or (inexact? x) (inexact? y)))
(g (+ j 1)))))))))
(let ([sb* (foreign-procedure
"(cs)mul" (scheme-object scheme-object) scheme-object)])
;; (expt 2 100000) is big enough that all multiplication algorithms
;; are exercised
;; we add a power of 3 so that the number isn't too simple
(eqv? (sb* (+ 1 (expt 3 50) (expt 2 100000))
3)
(* (+ 1 (expt 3 50) (expt 2 100000))
3)))
(error? ; #f is not a fixnum
(* 3 #f))
(error? ; #f is not a fixnum

View File

@ -2388,7 +2388,103 @@
[(fx= x 1) (unless (number? y) (nonnumber-error who y)) y]
[else ($negate who y)])]
[else (integer* x y)])
(let ()
(let ([slim 32]
[klim 100]
[t3lim 512])
; both of the following functions were adapted from
; https://github.com/casevh/DecInt/blob/master/DecInt.py#L451
; under the BSD license
;; (toom3 x y) => the product of the exact integers x and y.
;;
;; Size dispatch.  Sizes are whatever $bignum-length reports (the code
;; scales them by (constant bigit-bits) to obtain bit counts); a
;; non-bignum operand counts as size 0:
;;   both operands below slim -> integer*
;;   both operands below klim -> karatsuba
;;   otherwise                -> one Toom-3 step, recurring via toom3
;;
;; Toom-3 step.  With k = (longer size quotient 3) * bigit-bits, split
;;     x = x-hi*2^(2k) + x-mid*2^k + x-lo          (likewise y)
;; and solve for z0..z4 in
;;     x*y = z0*2^(4k) + z1*2^(3k) + z2*2^(2k) + z3*2^k + z4
;; from five smaller products:
;;     z0 = x-hi*y-hi
;;     z4 = x-lo*y-lo
;;     t1 = z0 +   z1 +   z2 +   z3 +    z4
;;     t2 = z0 -   z1 +   z2 -   z3 +    z4
;;     t3 = z0 + 2*z1 + 4*z2 + 8*z3 + 16*z4
(define (toom3 x y)
  (define xl (if (bignum? x) ($bignum-length x) 0))
  (define yl (if (bignum? y) ($bignum-length y) 0))
  (cond
    [(and (fx< xl slim) (fx< yl slim))
     (integer* x y)]
    [(and (fx< xl klim) (fx< yl klim))
     (karatsuba x y)]
    [else
     (let* ([k (fx* (fxquotient (fxmax xl yl) 3) (constant bigit-bits))]
            ;; three pieces per operand; the top piece keeps all
            ;; remaining high bits
            [x-hi (ash x (fx* -2 k))]
            [y-hi (ash y (fx* -2 k))]
            [x-mid (bitwise-bit-field x k (fx* 2 k))]
            [y-mid (bitwise-bit-field y k (fx* 2 k))]
            [x-lo (bitwise-bit-field x 0 k)]
            [y-lo (bitwise-bit-field y 0 k)]
            ;; the five pointwise products
            [z0 (toom3 x-hi y-hi)]
            [z4 (toom3 x-lo y-lo)]
            [t1 (toom3 (+ x-hi x-mid x-lo) (+ y-hi y-mid y-lo))]
            [t2 (toom3 (+ (- x-hi x-mid) x-lo) (+ (- y-hi y-mid) y-lo))]
            ;; recur through toom3 like the other four products rather
            ;; than re-entering the generic `*` dispatch
            [t3 (toom3 (+ x-hi (ash x-mid 1) (ash x-lo 2))
                       (+ y-hi (ash y-mid 1) (ash y-lo 2)))]
            ;; interpolation
            [z2 (- (ash (+ t1 t2) -1) z0 z4)]     ; t1+t2 = 2*(z0+z2+z4)
            [t4 (- t3 z0 (ash z2 2) (ash z4 4))]  ; = 2*z1 + 8*z3
            [z3 (quotient (+ (- t4 t1) t2) 6)]    ; t4-t1+t2 = 6*z3
            [z1 (- (ash (- t1 t2) -1) z3)])       ; t1-t2 = 2*(z1+z3)
       ;; recombine: z_i is the coefficient of 2^((4-i)*k)
       (+ (ash z0 (* k 4))
          (ash z1 (* k 3))
          (ash z2 (* k 2))
          (ash z3 (* k 1))
          (ash z4 (* k 0))))]))
;; (toom4 x y) => the product of the exact integers x and y.
;;
;; Size dispatch.  Sizes are whatever $bignum-length reports (the code
;; scales them by (constant bigit-bits) to obtain bit counts); a
;; non-bignum operand counts as size 0:
;;   both operands below slim  -> integer*
;;   both operands below klim  -> karatsuba
;;   both operands below t3lim -> toom3
;;   otherwise                 -> one Toom-4 step, recurring via toom4
;;
;; Toom-4 step.  With k = (longer size quotient 4) * bigit-bits, split
;;     x = x0*2^(3k) + x1*2^(2k) + x2*2^k + x3     (likewise y)
;; and solve for z0..z6 in
;;     x*y = z0*2^(6k) + z1*2^(5k) + ... + z5*2^k + z6
;; from seven smaller products.  Besides z0 = x0*y0 and z6 = x3*y3,
;; the products below are, before the known z0/z6 terms are removed:
;;     z0 +   z1 +   z2 +    z3 +    z4 +     z5 +     z6    (-> t1)
;;     z0 -   z1 +   z2 -    z3 +    z4 -     z5 +     z6    (-> t2)
;;     z0 + 2*z1 + 4*z2 +  8*z3 + 16*z4 +  32*z5 +  64*z6    (-> t3)
;;     z0 - 2*z1 + 4*z2 -  8*z3 + 16*z4 -  32*z5 +  64*z6    (-> t4)
;;     z0 + 3*z1 + 9*z2 + 27*z3 + 81*z4 + 243*z5 + 729*z6    (-> t5)
(define (toom4 x y)
  (define xl (if (bignum? x) ($bignum-length x) 0))
  (define yl (if (bignum? y) ($bignum-length y) 0))
  (cond
    [(and (fx< xl slim) (fx< yl slim))
     (integer* x y)]
    [(and (fx< xl klim) (fx< yl klim))
     (karatsuba x y)]
    [(and (fx< xl t3lim) (fx< yl t3lim))
     (toom3 x y)]
    [else
     (let* ([k (fx* (fxquotient (fxmax xl yl) 4) (constant bigit-bits))]
            ;; four pieces per operand, x0/y0 most significant; the top
            ;; piece keeps all remaining high bits
            [x0 (ash x (fx* -3 k))]
            [y0 (ash y (fx* -3 k))]
            [x1 (bitwise-bit-field x (fx* 2 k) (fx* 3 k))]
            [y1 (bitwise-bit-field y (fx* 2 k) (fx* 3 k))]
            [x2 (bitwise-bit-field x (fx* 1 k) (fx* 2 k))]
            [y2 (bitwise-bit-field y (fx* 1 k) (fx* 2 k))]
            [x3 (bitwise-bit-field x 0 k)]
            [y3 (bitwise-bit-field y 0 k)]
            ;; outermost coefficients come straight from the end pieces
            [z0 (toom4 x0 y0)]
            [z6 (toom4 x3 y3)]
            ;; first pair of products; t0 is the part already known
            [t0 (+ z0 z6)]
            [xeven (+ x0 x2)]
            [xodd (+ x1 x3)]
            [yeven (+ y0 y2)]
            [yodd (+ y1 y3)]
            [t1 (- (toom4 (+ xeven xodd) (+ yeven yodd)) t0)]
            [t2 (- (toom4 (- xeven xodd) (- yeven yodd)) t0)]
            ;; second pair, with weights 1,2,4,8; let* deliberately
            ;; rebinds xeven/xodd/yeven/yodd/t0 from here on
            [xeven (+ x0 (ash x2 2))]
            [xodd (+ (ash x1 1) (ash x3 3))]
            [yeven (+ y0 (ash y2 2))]
            [yodd (+ (ash y1 1) (ash y3 3))]
            [t0 (+ z0 (ash z6 6))]
            [t3 (- (toom4 (+ xeven xodd) (+ yeven yodd)) t0)]
            [t4 (- (toom4 (- xeven xodd) (- yeven yodd)) t0)]
            ;; seventh product, weights 1,3,9,27; recur through toom4
            ;; like the other six products rather than re-entering the
            ;; generic `*` dispatch (the small scalar multiplications
            ;; by 3, 9, 27 and 729 stay on `*`)
            [t5 (- (toom4 (+ x0 (* 3 x1) (* 9 x2) (* 27 x3))
                          (+ y0 (* 3 y1) (* 9 y2) (* 27 y3)))
                   (+ z0 (* 729 z6)))]
            ;; interpolation: even-index coefficients first
            [t6 (+ t1 t2)]                            ; = 2*z2 + 2*z4
            [t7 (+ t3 t4)]                            ; = 8*z2 + 32*z4
            [z4 (quotient (- t7 (ash t6 2)) 24)]
            [z2 (- (ash t6 -1) z4)]
            ;; then odd-index coefficients
            [t8 (- t1 z2 z4)]                         ; = z1 + z3 + z5
            [t9 (- t3 (ash z2 2) (ash z4 4))]         ; = 2*z1 + 8*z3 + 32*z5
            [t10 (- t5 (* 9 z2) (* 81 z4))]           ; = 3*z1 + 27*z3 + 243*z5
            [t11 (- t10 (* 3 t8))]                    ; = 24*z3 + 240*z5
            [t12 (- t9 (ash t8 1))]                   ; = 6*z3 + 30*z5
            [z5 (quotient (- t11 (ash t12 2)) 120)]
            [z3 (quotient (- (ash t12 3) t11) 24)]
            [z1 (- t8 z3 z5)])
       ;; recombine: z_i is the coefficient of 2^((6-i)*k)
       (+ (ash z0 (* k 6))
          (ash z1 (* k 5))
          (ash z2 (* k 4))
          (ash z3 (* k 3))
          (ash z4 (* k 2))
          (ash z5 (* k 1))
          (ash z6 (* k 0))))]))
;; _Modern Computer Arithmetic_, Brent and Zimmermann
(define (karatsuba x y)
(define xl (if (bignum? x) ($bignum-length x) 0))
@ -2400,8 +2496,8 @@
(let* ([k (fx* (fxquotient (fxmax xl yl) 2) (constant bigit-bits))]
[x-hi (ash x (fx- k))]
[y-hi (ash y (fx- k))]
[x-lo (- x (ash x-hi k))]
[y-lo (- y (ash y-hi k))]
[x-lo (bitwise-bit-field x 0 k)]
[y-lo (bitwise-bit-field y 0 k)]
[c0 (karatsuba x-lo y-lo)]
[c1 (karatsuba x-hi y-hi)]
[c1-c2 (cond
@ -2424,10 +2520,10 @@
[yz (if (bignum? y) ($bignum-trailing-zero-bits y) 0)])
(let ([z (fx+ xz yz)])
(if (fx= z 0)
(karatsuba x y)
(toom4 x y)
(bitwise-arithmetic-shift-left
(karatsuba (bitwise-arithmetic-shift-right x xz)
(bitwise-arithmetic-shift-right y yz))
(toom4 (bitwise-arithmetic-shift-right x xz)
(bitwise-arithmetic-shift-right y yz))
z))))))]
[(ratnum?) (/ (* x ($ratio-numerator y)) ($ratio-denominator y))]
[($exactnum? $inexactnum?)