Fixed broadcast and reduce for non powers of 2. Added alltoall and alltoallv
This commit is contained in:
parent
b6a4a48474
commit
0876466a08
|
@ -4,7 +4,9 @@
|
|||
racket/match
|
||||
racket/list
|
||||
racket/place
|
||||
racket/class)
|
||||
racket/class
|
||||
racket/flonum
|
||||
racket/fixnum)
|
||||
|
||||
(provide rmpi-init
|
||||
rmpi-send
|
||||
|
@ -12,11 +14,14 @@
|
|||
rmpi-broadcast
|
||||
rmpi-reduce
|
||||
rmpi-allreduce
|
||||
rmpi-alltoall
|
||||
rmpi-alltoallv
|
||||
rmpi-barrier
|
||||
rmpi-id
|
||||
rmpi-cnt
|
||||
rmpi-partition
|
||||
rmpi-build-default-config
|
||||
rmpi-make-localhost-config
|
||||
rmpi-launch
|
||||
rmpi-finish
|
||||
(struct-out rmpi-comm))
|
||||
|
@ -73,27 +78,40 @@
|
|||
[(not (= 0 (bitwise-and id round)))
|
||||
(define peer-id (- id round))
|
||||
(define real-peer-id (modulo (+ peer-id offset) cnt))
|
||||
;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
|
||||
;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
|
||||
(place-channel-get (vector-ref chs real-peer-id))
|
||||
]
|
||||
[else
|
||||
(define peer-id (+ id round))
|
||||
(define real-peer-id (modulo (+ peer-id offset) cnt))
|
||||
;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
|
||||
(place-channel-put (vector-ref chs real-peer-id) val)
|
||||
(cond
|
||||
[(< peer-id cnt)
|
||||
(define real-peer-id (modulo (+ peer-id offset) cnt))
|
||||
;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
|
||||
(place-channel-put (vector-ref chs real-peer-id) val)])
|
||||
val])]
|
||||
[else val]))]
|
||||
[else val]))]))
|
||||
|
||||
(define (fancy-reducer op recv-val val)
|
||||
(cond
|
||||
[(number? recv-val)
|
||||
[(or (number? recv-val)
|
||||
(boolean? recv-val))
|
||||
(op recv-val val)]
|
||||
[(vector? recv-val)
|
||||
(for/vector #:length (vector-length recv-val)
|
||||
([a (in-vector recv-val)]
|
||||
[b (in-vector val)])
|
||||
(fancy-reducer op a b))]
|
||||
[(fxvector? recv-val)
|
||||
(for/fxvector #:length (fxvector-length recv-val)
|
||||
([a (in-fxvector recv-val)]
|
||||
[b (in-fxvector val)])
|
||||
(fancy-reducer op a b))]
|
||||
[(flvector? recv-val)
|
||||
(for/flvector #:length (flvector-length recv-val)
|
||||
([a (in-flvector recv-val)]
|
||||
[b (in-flvector val)])
|
||||
(fancy-reducer op a b))]
|
||||
[else (raise (format "fancy-reducer error on ~a ~a ~a" op recv-val val))]))
|
||||
|
||||
(define (rmpi-reduce comm dest op val)
|
||||
|
@ -121,17 +139,20 @@
|
|||
[(not (= 0 (bitwise-and id round)))
|
||||
(define peer-id (- id round))
|
||||
(define real-peer-id (convert peer-id))
|
||||
;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
|
||||
;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
|
||||
(place-channel-put (vector-ref chs real-peer-id) val)
|
||||
val
|
||||
]
|
||||
[else
|
||||
(define peer-id (+ id round))
|
||||
(define real-peer-id (convert peer-id))
|
||||
;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
|
||||
(define recv-val (place-channel-get (vector-ref chs real-peer-id)))
|
||||
;(define recv-val val)
|
||||
(fancy-reducer op recv-val val)])]))]
|
||||
(cond
|
||||
[(< peer-id cnt)
|
||||
(define real-peer-id (convert peer-id))
|
||||
;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
|
||||
(define recv-val (place-channel-get (vector-ref chs real-peer-id)))
|
||||
;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val)
|
||||
(fancy-reducer op recv-val val)]
|
||||
[else val])])]))]
|
||||
[else val])))
|
||||
|
||||
(define (rmpi-barrier comm)
|
||||
|
@ -139,6 +160,7 @@
|
|||
(rmpi-broadcast comm 0 1))
|
||||
|
||||
(define (rmpi-allreduce comm op val)
|
||||
|
||||
(define rv (rmpi-reduce comm 0 op val))
|
||||
(rmpi-broadcast comm 0 rv))
|
||||
|
||||
|
@ -244,6 +266,81 @@
|
|||
(when (= (rmpi-id comm) 0)
|
||||
(place-channel-put (second (tc-get 'done? tc)) 'done)))
|
||||
|
||||
(define (rmpi-make-localhost-config cnt start-port name)
|
||||
(for/list ([i cnt])
|
||||
(list "localhost" (+ start-port i) (string->symbol (format "~a_~a" (symbol->string name) (number->string i)))
|
||||
i)))
|
||||
|
||||
(define (rmpi-alltoall comm outvec)
|
||||
(match-define (rmpi-comm real-id cnt chs) comm)
|
||||
(define-values (v! vr invec)
|
||||
(cond
|
||||
[(vector? outvec) (values vector-set! vector-ref (make-vector (vector-length outvec) 0))]
|
||||
[(fxvector? outvec) (values fxvector-set! fxvector-ref (make-fxvector (fxvector-length outvec) 0))]
|
||||
[(flvector? outvec) (values flvector-set! flvector-ref (make-flvector (flvector-length outvec) 0.0))]
|
||||
[else (error (format "Unrecognized type of vector ~a" outvec))]))
|
||||
(define (send+i i)
|
||||
(define peer-id (modulo (fx+ real-id i) cnt))
|
||||
;(printf/f "A2ASEND ~a ~a ~a ~a\n" i real-id peer-id (vr outvec peer-id))
|
||||
(place-channel-put (vector-ref chs peer-id) (vr outvec peer-id)))
|
||||
(define (recv+i i)
|
||||
(define peer-id (modulo (fx- real-id i) cnt))
|
||||
(define val (place-channel-get (vector-ref chs peer-id)))
|
||||
;(printf/f "A2ARECV ~a ~a ~a ~a\n" i real-id peer-id val)
|
||||
(v! invec peer-id val))
|
||||
(v! invec real-id (vr outvec real-id))
|
||||
(for ([i (in-range 1 cnt)])
|
||||
(send+i i)
|
||||
(recv+i i))
|
||||
invec)
|
||||
|
||||
(define (rmpi-alltoallv comm outvec send-count send-displ invec recv-count recv-displ)
|
||||
(match-define (rmpi-comm real-id cnt chs) comm)
|
||||
(define-values (v! vr mk-v in-v)
|
||||
(cond
|
||||
[(vector? outvec) (values vector-set! vector-ref make-vector in-vector)]
|
||||
[(fxvector? outvec) (values fxvector-set! fxvector-ref make-fxvector in-fxvector)]
|
||||
[(flvector? outvec) (values flvector-set! flvector-ref make-flvector in-flvector)]
|
||||
[else (error (format "Unrecognized type of vector ~a" outvec))]))
|
||||
|
||||
;; convert from outvec to vector of outvectors
|
||||
(define outvv
|
||||
(for/vector #:length (fxvector-length send-count) ([i (in-fxvector send-count)]
|
||||
[d (in-fxvector send-displ)])
|
||||
(define vv (mk-v i))
|
||||
(for ([ii (in-range d (fx+ d i))]
|
||||
[iii (in-naturals)])
|
||||
(v! vv iii (vr outvec ii)))
|
||||
vv))
|
||||
(define invv (make-vector cnt #f))
|
||||
|
||||
;; alltoall
|
||||
(vector-set! invv real-id (vector-ref outvv real-id))
|
||||
(define (send+i i)
|
||||
(define peer-id (modulo (fx+ real-id i) cnt))
|
||||
;(printf/f "ALLTOALLV SENDING TO ~a ~a\n" peer-id real-id)
|
||||
#;(thread (lambda ()
|
||||
(sleep 1)
|
||||
(printf/f "WOW!\n")))
|
||||
(place-channel-put (vector-ref chs peer-id) (vector-ref outvv peer-id)))
|
||||
(define (recv+i i)
|
||||
(define peer-id (modulo (fx- real-id i) cnt))
|
||||
;(printf/f "ALLTOALLV RECVING FROM ~a ~a\n" peer-id real-id)
|
||||
(vector-set! invv peer-id (place-channel-get (vector-ref chs peer-id))))
|
||||
(for ([i (in-range 1 cnt)])
|
||||
(send+i i)
|
||||
(recv+i i))
|
||||
|
||||
;; convert form vector of invectors to invector
|
||||
(for ([v (in-vector invv)]
|
||||
[i (in-fxvector recv-count)]
|
||||
[d (in-fxvector recv-displ)])
|
||||
(for ([x (in-v v)]
|
||||
[ii (in-range d (fx+ d i))])
|
||||
(v! invec ii x)))
|
||||
|
||||
invec)
|
||||
|
||||
(module+ bcast-print-test
|
||||
(rmpi-broadcast (rmpi-comm 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi")
|
||||
(rmpi-broadcast (rmpi-comm 3 8 (vector 0 1 2 3 4 5 6 7)) 0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user