Fixed broadcast and reduce for non powers of 2. Added alltoall and alltoallv

This commit is contained in:
Kevin Tew 2012-12-11 11:35:19 -07:00
parent b6a4a48474
commit 0876466a08

View File

@ -4,7 +4,9 @@
racket/match
racket/list
racket/place
racket/class)
racket/class
racket/flonum
racket/fixnum)
(provide rmpi-init
rmpi-send
@ -12,11 +14,14 @@
rmpi-broadcast
rmpi-reduce
rmpi-allreduce
rmpi-alltoall
rmpi-alltoallv
rmpi-barrier
rmpi-id
rmpi-cnt
rmpi-partition
rmpi-build-default-config
rmpi-make-localhost-config
rmpi-launch
rmpi-finish
(struct-out rmpi-comm))
@ -73,27 +78,40 @@
[(not (= 0 (bitwise-and id round)))
(define peer-id (- id round))
(define real-peer-id (modulo (+ peer-id offset) cnt))
;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-get (vector-ref chs real-peer-id))
]
[else
(define peer-id (+ id round))
(define real-peer-id (modulo (+ peer-id offset) cnt))
;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
(place-channel-put (vector-ref chs real-peer-id) val)
(cond
[(< peer-id cnt)
(define real-peer-id (modulo (+ peer-id offset) cnt))
;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-put (vector-ref chs real-peer-id) val)])
val])]
[else val]))]
[else val]))]))
(define (fancy-reducer op recv-val val)
(cond
[(number? recv-val)
[(or (number? recv-val)
(boolean? recv-val))
(op recv-val val)]
[(vector? recv-val)
(for/vector #:length (vector-length recv-val)
([a (in-vector recv-val)]
[b (in-vector val)])
(fancy-reducer op a b))]
[(fxvector? recv-val)
(for/fxvector #:length (fxvector-length recv-val)
([a (in-fxvector recv-val)]
[b (in-fxvector val)])
(fancy-reducer op a b))]
[(flvector? recv-val)
(for/flvector #:length (flvector-length recv-val)
([a (in-flvector recv-val)]
[b (in-flvector val)])
(fancy-reducer op a b))]
[else (raise (format "fancy-reducer error on ~a ~a ~a" op recv-val val))]))
(define (rmpi-reduce comm dest op val)
@ -121,17 +139,20 @@
[(not (= 0 (bitwise-and id round)))
(define peer-id (- id round))
(define real-peer-id (convert peer-id))
;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(place-channel-put (vector-ref chs real-peer-id) val)
val
]
[else
(define peer-id (+ id round))
(define real-peer-id (convert peer-id))
;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val)
(define recv-val (place-channel-get (vector-ref chs real-peer-id)))
;(define recv-val val)
(fancy-reducer op recv-val val)])]))]
(cond
[(< peer-id cnt)
(define real-peer-id (convert peer-id))
;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset)
(define recv-val (place-channel-get (vector-ref chs real-peer-id)))
;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val)
(fancy-reducer op recv-val val)]
[else val])])]))]
[else val])))
(define (rmpi-barrier comm)
@ -139,6 +160,7 @@
(rmpi-broadcast comm 0 1))
(define (rmpi-allreduce comm op val)
(define rv (rmpi-reduce comm 0 op val))
(rmpi-broadcast comm 0 rv))
@ -244,6 +266,81 @@
(when (= (rmpi-id comm) 0)
(place-channel-put (second (tc-get 'done? tc)) 'done)))
(define (rmpi-make-localhost-config cnt start-port name)
(for/list ([i cnt])
(list "localhost" (+ start-port i) (string->symbol (format "~a_~a" (symbol->string name) (number->string i)))
i)))
(define (rmpi-alltoall comm outvec)
(match-define (rmpi-comm real-id cnt chs) comm)
(define-values (v! vr invec)
(cond
[(vector? outvec) (values vector-set! vector-ref (make-vector (vector-length outvec) 0))]
[(fxvector? outvec) (values fxvector-set! fxvector-ref (make-fxvector (fxvector-length outvec) 0))]
[(flvector? outvec) (values flvector-set! flvector-ref (make-flvector (flvector-length outvec) 0.0))]
[else (error (format "Unrecognized type of vector ~a" outvec))]))
(define (send+i i)
(define peer-id (modulo (fx+ real-id i) cnt))
;(printf/f "A2ASEND ~a ~a ~a ~a\n" i real-id peer-id (vr outvec peer-id))
(place-channel-put (vector-ref chs peer-id) (vr outvec peer-id)))
(define (recv+i i)
(define peer-id (modulo (fx- real-id i) cnt))
(define val (place-channel-get (vector-ref chs peer-id)))
;(printf/f "A2ARECV ~a ~a ~a ~a\n" i real-id peer-id val)
(v! invec peer-id val))
(v! invec real-id (vr outvec real-id))
(for ([i (in-range 1 cnt)])
(send+i i)
(recv+i i))
invec)
(define (rmpi-alltoallv comm outvec send-count send-displ invec recv-count recv-displ)
(match-define (rmpi-comm real-id cnt chs) comm)
(define-values (v! vr mk-v in-v)
(cond
[(vector? outvec) (values vector-set! vector-ref make-vector in-vector)]
[(fxvector? outvec) (values fxvector-set! fxvector-ref make-fxvector in-fxvector)]
[(flvector? outvec) (values flvector-set! flvector-ref make-flvector in-flvector)]
[else (error (format "Unrecognized type of vector ~a" outvec))]))
;; convert from outvec to vector of outvectors
(define outvv
(for/vector #:length (fxvector-length send-count) ([i (in-fxvector send-count)]
[d (in-fxvector send-displ)])
(define vv (mk-v i))
(for ([ii (in-range d (fx+ d i))]
[iii (in-naturals)])
(v! vv iii (vr outvec ii)))
vv))
(define invv (make-vector cnt #f))
;; alltoall
(vector-set! invv real-id (vector-ref outvv real-id))
(define (send+i i)
(define peer-id (modulo (fx+ real-id i) cnt))
;(printf/f "ALLTOALLV SENDING TO ~a ~a\n" peer-id real-id)
#;(thread (lambda ()
(sleep 1)
(printf/f "WOW!\n")))
(place-channel-put (vector-ref chs peer-id) (vector-ref outvv peer-id)))
(define (recv+i i)
(define peer-id (modulo (fx- real-id i) cnt))
;(printf/f "ALLTOALLV RECVING FROM ~a ~a\n" peer-id real-id)
(vector-set! invv peer-id (place-channel-get (vector-ref chs peer-id))))
(for ([i (in-range 1 cnt)])
(send+i i)
(recv+i i))
;; convert form vector of invectors to invector
(for ([v (in-vector invv)]
[i (in-fxvector recv-count)]
[d (in-fxvector recv-displ)])
(for ([x (in-v v)]
[ii (in-range d (fx+ d i))])
(v! invec ii x)))
invec)
(module+ bcast-print-test
(rmpi-broadcast (rmpi-comm 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi")
(rmpi-broadcast (rmpi-comm 3 8 (vector 0 1 2 3 4 5 6 7)) 0)