diff --git a/collects/racket/place/distributed/rmpi.rkt b/collects/racket/place/distributed/rmpi.rkt index fad60a367d..c634545faa 100644 --- a/collects/racket/place/distributed/rmpi.rkt +++ b/collects/racket/place/distributed/rmpi.rkt @@ -4,7 +4,9 @@ racket/match racket/list racket/place - racket/class) + racket/class + racket/flonum + racket/fixnum) (provide rmpi-init rmpi-send @@ -12,11 +14,14 @@ rmpi-broadcast rmpi-reduce rmpi-allreduce + rmpi-alltoall + rmpi-alltoallv rmpi-barrier rmpi-id rmpi-cnt rmpi-partition rmpi-build-default-config + rmpi-make-localhost-config rmpi-launch rmpi-finish (struct-out rmpi-comm)) @@ -73,27 +78,40 @@ [(not (= 0 (bitwise-and id round))) (define peer-id (- id round)) (define real-peer-id (modulo (+ peer-id offset) cnt)) - ;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + ;(printf "BRECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) (place-channel-get (vector-ref chs real-peer-id)) ] [else (define peer-id (+ id round)) - (define real-peer-id (modulo (+ peer-id offset) cnt)) - ;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) - (place-channel-put (vector-ref chs real-peer-id) val) + (cond + [(< peer-id cnt) + (define real-peer-id (modulo (+ peer-id offset) cnt)) + ;(printf "BSEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (place-channel-put (vector-ref chs real-peer-id) val)]) val])] [else val]))] [else val]))])) (define (fancy-reducer op recv-val val) (cond - [(number? recv-val) + [(or (number? recv-val) + (boolean? recv-val)) (op recv-val val)] [(vector? recv-val) (for/vector #:length (vector-length recv-val) ([a (in-vector recv-val)] [b (in-vector val)]) (fancy-reducer op a b))] + [(fxvector? recv-val) + (for/fxvector #:length (fxvector-length recv-val) + ([a (in-fxvector recv-val)] + [b (in-fxvector val)]) + (fancy-reducer op a b))] + [(flvector? 
recv-val) + (for/flvector #:length (flvector-length recv-val) + ([a (in-flvector recv-val)] + [b (in-flvector val)]) + (fancy-reducer op a b))] [else (raise (format "fancy-reducer error on ~a ~a ~a" op recv-val val))])) (define (rmpi-reduce comm dest op val) @@ -121,17 +139,20 @@ [(not (= 0 (bitwise-and id round))) (define peer-id (- id round)) (define real-peer-id (convert peer-id)) - ;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + ;(printf/f "SEND ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) (place-channel-put (vector-ref chs real-peer-id) val) val ] [else (define peer-id (+ id round)) - (define real-peer-id (convert peer-id)) - ;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) - (define recv-val (place-channel-get (vector-ref chs real-peer-id))) - ;(define recv-val val) - (fancy-reducer op recv-val val)])]))] + (cond + [(< peer-id cnt) + (define real-peer-id (convert peer-id)) + ;(printf/f "RECV ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset) + (define recv-val (place-channel-get (vector-ref chs real-peer-id))) + ;(printf/f "RECVVAL ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset recv-val) + (fancy-reducer op recv-val val)] + [else val])])]))] [else val]))) (define (rmpi-barrier comm) @@ -139,6 +160,7 @@ (rmpi-broadcast comm 0 1)) (define (rmpi-allreduce comm op val) + (define rv (rmpi-reduce comm 0 op val)) (rmpi-broadcast comm 0 rv)) @@ -244,6 +266,81 @@ (when (= (rmpi-id comm) 0) (place-channel-put (second (tc-get 'done? tc)) 'done))) +(define (rmpi-make-localhost-config cnt start-port name) + (for/list ([i cnt]) + (list "localhost" (+ start-port i) (string->symbol (format "~a_~a" (symbol->string name) (number->string i))) + i))) + +(define (rmpi-alltoall comm outvec) + (match-define (rmpi-comm real-id cnt chs) comm) + (define-values (v! vr invec) + (cond + [(vector? outvec) (values vector-set! 
vector-ref (make-vector (vector-length outvec) 0))] + [(fxvector? outvec) (values fxvector-set! fxvector-ref (make-fxvector (fxvector-length outvec) 0))] + [(flvector? outvec) (values flvector-set! flvector-ref (make-flvector (flvector-length outvec) 0.0))] + [else (error (format "Unrecognized type of vector ~a" outvec))])) + (define (send+i i) + (define peer-id (modulo (fx+ real-id i) cnt)) + ;(printf/f "A2ASEND ~a ~a ~a ~a\n" i real-id peer-id (vr outvec peer-id)) + (place-channel-put (vector-ref chs peer-id) (vr outvec peer-id))) + (define (recv+i i) + (define peer-id (modulo (fx- real-id i) cnt)) + (define val (place-channel-get (vector-ref chs peer-id))) + ;(printf/f "A2ARECV ~a ~a ~a ~a\n" i real-id peer-id val) + (v! invec peer-id val)) + (v! invec real-id (vr outvec real-id)) + (for ([i (in-range 1 cnt)]) + (send+i i) + (recv+i i)) + invec) + +(define (rmpi-alltoallv comm outvec send-count send-displ invec recv-count recv-displ) + (match-define (rmpi-comm real-id cnt chs) comm) + (define-values (v! vr mk-v in-v) + (cond + [(vector? outvec) (values vector-set! vector-ref make-vector in-vector)] + [(fxvector? outvec) (values fxvector-set! fxvector-ref make-fxvector in-fxvector)] + [(flvector? outvec) (values flvector-set! flvector-ref make-flvector in-flvector)] + [else (error (format "Unrecognized type of vector ~a" outvec))])) + + ;; convert from outvec to vector of outvectors + (define outvv + (for/vector #:length (fxvector-length send-count) ([i (in-fxvector send-count)] + [d (in-fxvector send-displ)]) + (define vv (mk-v i)) + (for ([ii (in-range d (fx+ d i))] + [iii (in-naturals)]) + (v! vv iii (vr outvec ii))) + vv)) + (define invv (make-vector cnt #f)) + + ;; alltoall + (vector-set! 
invv real-id (vector-ref outvv real-id)) + (define (send+i i) + (define peer-id (modulo (fx+ real-id i) cnt)) + ;(printf/f "ALLTOALLV SENDING TO ~a ~a\n" peer-id real-id) + #;(thread (lambda () + (sleep 1) + (printf/f "WOW!\n"))) + (place-channel-put (vector-ref chs peer-id) (vector-ref outvv peer-id))) + (define (recv+i i) + (define peer-id (modulo (fx- real-id i) cnt)) + ;(printf/f "ALLTOALLV RECVING FROM ~a ~a\n" peer-id real-id) + (vector-set! invv peer-id (place-channel-get (vector-ref chs peer-id)))) + (for ([i (in-range 1 cnt)]) + (send+i i) + (recv+i i)) + + ;; convert from vector of invectors to invector + (for ([v (in-vector invv)] + [i (in-fxvector recv-count)] + [d (in-fxvector recv-displ)]) + (for ([x (in-v v)] + [ii (in-range d (fx+ d i))]) + (v! invec ii x))) + + invec) + (module+ bcast-print-test (rmpi-broadcast (rmpi-comm 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi") (rmpi-broadcast (rmpi-comm 3 8 (vector 0 1 2 3 4 5 6 7)) 0)