diff --git a/collects/racket/place/distributed/rmpi.rkt b/collects/racket/place/distributed/rmpi.rkt index c634545faa..bc6b0235c4 100644 --- a/collects/racket/place/distributed/rmpi.rkt +++ b/collects/racket/place/distributed/rmpi.rkt @@ -13,7 +13,9 @@ rmpi-recv rmpi-broadcast rmpi-reduce + rmpi-gather rmpi-allreduce + rmpi-allgather rmpi-alltoall rmpi-alltoallv rmpi-barrier @@ -77,7 +79,7 @@ (cond [(not (= 0 (bitwise-and id round))) (define peer-id (- id round)) - (define real-peer-id (modulo (+ peer-id offset) cnt)) + (define real-peer-id (modulo (+ peer-id src) cnt)) ;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) (place-channel-get (vector-ref chs real-peer-id)) ] @@ -85,7 +87,7 @@ (define peer-id (+ id round)) (cond [(< peer-id cnt) - (define real-peer-id (modulo (+ peer-id offset) cnt)) + (define real-peer-id (modulo (+ peer-id src) cnt)) ;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) (place-channel-put (vector-ref chs real-peer-id) val)]) val])] @@ -125,6 +127,7 @@ (define offset (- cnt dest)) (define (convert v) (modulo (+ v offset) cnt)) + (define (pconvert v) (modulo (+ v dest) cnt)) (define id (convert real-id)) (let loop ([i i] [val val]) @@ -138,8 +141,8 @@ (cond [(not (= 0 (bitwise-and id round))) (define peer-id (- id round)) - (define real-peer-id (convert peer-id)) - ;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (define real-peer-id (pconvert peer-id)) + ;(printf/f "SEND ROUND ~a RID ~a RPID ~a ID ~a PID ~a OFF ~a\n" round real-id real-peer-id id peer-id offset) (place-channel-put (vector-ref chs real-peer-id) val) val ] @@ -147,23 +150,100 @@ (define peer-id (+ id round)) (cond [(< peer-id cnt) - (define real-peer-id (convert peer-id)) - ;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (define real-peer-id (pconvert peer-id)) + ;(printf/f "RECV ROUND ~a RID ~a RPID ~a ID ~a PID ~a OFF ~a\n" round real-id real-peer-id id peer-id offset) (define recv-val (place-channel-get (vector-ref chs real-peer-id))) ;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val) (fancy-reducer op recv-val val)] [else val])])]))] [else val]))) +(define (rmpi-gather comm dest val) + (match-define (rmpi-comm real-id cnt chs) comm) + (define i + (let loop ([i 0]) + (if (>= (arithmetic-shift 1 i) cnt) + i + (loop (add1 i))))) + + + (define offset (- cnt dest)) + (define (convert v) (modulo (+ v offset) cnt)) + (define (pconvert v) (modulo (+ v dest) cnt)) + (define id (convert real-id)) + (define retval + (let loop ([i i] + [val (vector val)]) + (cond + [(> i 0) + (define round (arithmetic-shift 1 (sub1 i))) + (loop + (sub1 i) + (cond + [(< id (arithmetic-shift round 1)) + (cond + [(not (= 0 (bitwise-and id round))) + (define peer-id (- id round)) + (define real-peer-id (pconvert peer-id)) + ;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (place-channel-put (vector-ref chs real-peer-id) val) + val + ] + [else + (define peer-id (+ id round)) + (cond + [(< peer-id cnt) + (define real-peer-id (pconvert peer-id)) + ;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (define recv-val (place-channel-get (vector-ref chs real-peer-id))) + ;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val) + (define (order real-id val real-peer-id pval) + (define vl (vector-length val)) + (define pvl (vector-length pval)) + (define new-val (make-vector (+ vl pvl))) + (define v1 (if (< real-id real-peer-id) val pval)) + (define v2 (if (< real-id real-peer-id) pval val)) + (define v1l (if (< real-id real-peer-id) vl pvl)) + (define v2l (if (< real-id real-peer-id) pvl vl)) + (let loop ([i 0] + [j 0] + [ii 0]) + (cond + [(< i v1l) + (vector-set! new-val ii (vector-ref v1 i)) + (cond + [(< j v2l) + (vector-set! new-val (+ 1 ii) (vector-ref v2 j)) + (loop (+ 1 i) (+ 1 j) (+ 2 ii))] + [else + (loop (+ 1 i) j (+ 1 ii))])] + [(< j v2l) + (vector-set! new-val (+ 1 ii) (vector-ref v2 j)) + (loop i (+ 1 j) (+ 1 ii))] + [else + new-val]))) + (order id val peer-id recv-val)] + [else val])])]))] + [else val]))) + (cond + [(and (not (zero? dest)) (= real-id dest)) + (for/vector #:length cnt ([i cnt]) + (vector-ref retval (modulo (+ i offset) cnt)))] + [else retval])) + + (define (rmpi-barrier comm) (rmpi-reduce comm 0 + 1) (rmpi-broadcast comm 0 1)) (define (rmpi-allreduce comm op val) - (define rv (rmpi-reduce comm 0 op val)) (rmpi-broadcast comm 0 rv)) +(define (rmpi-allgather comm op val) + (define rv (rmpi-gather comm 0 val)) + (rmpi-broadcast comm 0 rv)) + (define (partit num cnt id) (define-values (quo rem) (quotient/remainder num cnt)) (values (+ (* id quo) (if (< id rem) id rem))