Fixed broadcast and reduce from nodes other that 0. Added gather and allgather

This commit is contained in:
Kevin Tew 2012-12-13 21:05:00 -07:00
parent 0876466a08
commit e30fdf0db6

View File

@ -13,7 +13,9 @@
rmpi-recv rmpi-recv
rmpi-broadcast rmpi-broadcast
rmpi-reduce rmpi-reduce
rmpi-gather
rmpi-allreduce rmpi-allreduce
rmpi-allgather
rmpi-alltoall rmpi-alltoall
rmpi-alltoallv rmpi-alltoallv
rmpi-barrier rmpi-barrier
@ -77,7 +79,7 @@
(cond (cond
[(not (= 0 (bitwise-and id round))) [(not (= 0 (bitwise-and id round)))
(define peer-id (- id round)) (define peer-id (- id round))
(define real-peer-id (modulo (+ peer-id offset) cnt)) (define real-peer-id (modulo (+ peer-id src) cnt))
;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) ;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-get (vector-ref chs real-peer-id)) (place-channel-get (vector-ref chs real-peer-id))
] ]
@ -85,7 +87,7 @@
(define peer-id (+ id round)) (define peer-id (+ id round))
(cond (cond
[(< peer-id cnt) [(< peer-id cnt)
(define real-peer-id (modulo (+ peer-id offset) cnt)) (define real-peer-id (modulo (+ peer-id src) cnt))
;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) ;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-put (vector-ref chs real-peer-id) val)]) (place-channel-put (vector-ref chs real-peer-id) val)])
val])] val])]
@ -125,6 +127,7 @@
(define offset (- cnt dest)) (define offset (- cnt dest))
(define (convert v) (modulo (+ v offset) cnt)) (define (convert v) (modulo (+ v offset) cnt))
(define (pconvert v) (modulo (+ v dest) cnt))
(define id (convert real-id)) (define id (convert real-id))
(let loop ([i i] (let loop ([i i]
[val val]) [val val])
@ -138,8 +141,8 @@
(cond (cond
[(not (= 0 (bitwise-and id round))) [(not (= 0 (bitwise-and id round)))
(define peer-id (- id round)) (define peer-id (- id round))
(define real-peer-id (convert peer-id)) (define real-peer-id (pconvert peer-id))
;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) ;(printf/f "SEND ROUND ~a RID ~a RPID ~a ID ~a PID ~a OFF ~a\n" round real-id real-peer-id id peer-id offset)
(place-channel-put (vector-ref chs real-peer-id) val) (place-channel-put (vector-ref chs real-peer-id) val)
val val
] ]
@ -147,23 +150,100 @@
(define peer-id (+ id round)) (define peer-id (+ id round))
(cond (cond
[(< peer-id cnt) [(< peer-id cnt)
(define real-peer-id (convert peer-id)) (define real-peer-id (pconvert peer-id))
;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) ;(printf/f "RECV ROUND ~a RID ~a RPID ~a ID ~a PID ~a OFF ~a\n" round real-id real-peer-id id peer-id offset)
(define recv-val (place-channel-get (vector-ref chs real-peer-id))) (define recv-val (place-channel-get (vector-ref chs real-peer-id)))
;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val) ;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val)
(fancy-reducer op recv-val val)] (fancy-reducer op recv-val val)]
[else val])])]))] [else val])])]))]
[else val]))) [else val])))
(define (rmpi-gather comm dest val)
(match-define (rmpi-comm real-id cnt chs) comm)
(define i
(let loop ([i 0])
(if (>= (arithmetic-shift 1 i) cnt)
i
(loop (add1 i)))))
(define offset (- cnt dest))
(define (convert v) (modulo (+ v offset) cnt))
(define (pconvert v) (modulo (+ v dest) cnt))
(define id (convert real-id))
(define retval
(let loop ([i i]
[val (vector val)])
(cond
[(> i 0)
(define round (arithmetic-shift 1 (sub1 i)))
(loop
(sub1 i)
(cond
[(< id (arithmetic-shift round 1))
(cond
[(not (= 0 (bitwise-and id round)))
(define peer-id (- id round))
(define real-peer-id (pconvert peer-id))
;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-put (vector-ref chs real-peer-id) val)
val
]
[else
(define peer-id (+ id round))
(cond
[(< peer-id cnt)
(define real-peer-id (pconvert peer-id))
;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(define recv-val (place-channel-get (vector-ref chs real-peer-id)))
;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val)
(define (order real-id val real-peer-id pval)
(define vl (vector-length val))
(define pvl (vector-length pval))
(define new-val (make-vector (+ vl pvl)))
(define v1 (if (< real-id real-peer-id) val pval))
(define v2 (if (< real-id real-peer-id) pval val))
(define v1l (if (< real-id real-peer-id) vl pvl))
(define v2l (if (< real-id real-peer-id) pvl vl))
(let loop ([i 0]
[j 0]
[ii 0])
(cond
[(< i v1l)
(vector-set! new-val ii (vector-ref v1 i))
(cond
[(< j v2l)
(vector-set! new-val (+ 1 ii) (vector-ref v2 j))
(loop (+ 1 i) (+ 1 j) (+ 2 ii))]
[else
(loop (+ 1 i) j (+ 1 ii))])]
[(< j v2l)
(vector-set! new-val (+ 1 ii) (vector-ref v2 j))
(loop i (+ 1 j) (+ 1 ii))]
[else
new-val])))
(order id val peer-id recv-val)]
[else val])])]))]
[else val])))
(cond
[(and (not (zero? dest)) (= real-id dest))
(for/vector #:length cnt ([i cnt])
(vector-ref retval (modulo (+ i offset) cnt)))]
[else retval]))
(define (rmpi-barrier comm) (define (rmpi-barrier comm)
(rmpi-reduce comm 0 + 1) (rmpi-reduce comm 0 + 1)
(rmpi-broadcast comm 0 1)) (rmpi-broadcast comm 0 1))
(define (rmpi-allreduce comm op val) (define (rmpi-allreduce comm op val)
(define rv (rmpi-reduce comm 0 op val)) (define rv (rmpi-reduce comm 0 op val))
(rmpi-broadcast comm 0 rv)) (rmpi-broadcast comm 0 rv))
(define (rmpi-allgather comm op val)
(define rv (rmpi-gather comm 0 val))
(rmpi-broadcast comm 0 rv))
(define (partit num cnt id) (define (partit num cnt id)
(define-values (quo rem) (quotient/remainder num cnt)) (define-values (quo rem) (quotient/remainder num cnt))
(values (+ (* id quo) (if (< id rem) id rem)) (values (+ (* id quo) (if (< id rem) id rem))