From 5be7a7d98072dfba449cebed93cb80b1f92faeb6 Mon Sep 17 00:00:00 2001 From: Kevin Tew Date: Tue, 17 Apr 2012 15:28:20 -0600 Subject: [PATCH] [Distributed Places] simple MPI layer over distribute places --- collects/racket/place/distributed.rkt | 179 +++++++++++---- collects/racket/place/distributed/RMPI.rkt | 252 +++++++++++++++++++++ 2 files changed, 383 insertions(+), 48 deletions(-) create mode 100644 collects/racket/place/distributed/RMPI.rkt diff --git a/collects/racket/place/distributed.rkt b/collects/racket/place/distributed.rkt index 6955f26c22..f255d8711a 100644 --- a/collects/racket/place/distributed.rkt +++ b/collects/racket/place/distributed.rkt @@ -57,6 +57,7 @@ ll-channel-put write-flush printf/f + displayln/f log-message start-spawned-node-router @@ -79,6 +80,11 @@ (struct-out dcg) ;v3 api + build-distributed-launch-path + build-node-args + *channel-get + *channel-put + send-new-place-channel-to-named-dest ;classes event-container<%> @@ -100,6 +106,7 @@ ) (define-runtime-path distributed-launch-path "distributed/launch.rkt") +(define (build-distributed-launch-path path) (path->string (build-path path "collects/racket/place/distributed/launch.rkt"))) (define DEFAULT-ROUTER-PORT 6340) @@ -147,6 +154,9 @@ (define (write-flush msg [p (current-output-port)]) +; (write msg (current-output-port)) +; (newline) + (flush-output) (write msg p) (flush-output p)) @@ -154,6 +164,11 @@ (apply printf args) (flush-output)) +(define (displayln/f . args) + (apply displayln args) + (flush-output)) + + (define (tcp-connect/backoff rname rport #:times [times 4] #:start-seconds [start-seconds 1]) (let loop ([t 0] [wait-time start-seconds]) @@ -200,9 +215,11 @@ (define DCGM-DPLACE-DIED 7) (define DCGM-TYPE-LOG-TO-PARENT 8) (define DCGM-TYPE-NEW-PLACE 9) -(define DCGM-TYPE-SET-OWNER 10) +(define DCGM-TYPE-NEW-CONNECTION 10) +(define DCGM-TYPE-SET-OWNER 11) (define DCGM-NEW-NODE-CONNECT 50) +(define DCGM-NEW-PLACE-CHANNEL 51) (define DCGM-CONTROL-NEW-NODE 100) (define DCGM-CONTROL-NEW-PLACE 101) @@ -239,6 +256,12 @@ (dcg-send-type c DCGM-TYPE-NEW-DCHANNEL dest (dchannel e2)) (dchannel e1)) +(define (send-new-place-channel-to-named-dest ch src-id dest-list) + (define-values (e1 e2) (place-channel)) + (place-channel-put ch (dcgm DCGM-NEW-PLACE-CHANNEL src-id dest-list e2)) + e1) + + ;; Contract: start-node-router : VectorOf[ (or/c place-channel socket-connection)] -> (void) ;; Purpose: Forward messages between channels and build new point-to-point subchannels ;; Example: @@ -363,8 +386,10 @@ (if (dchannel? pch) (dchannel-ch pch) pch) (lambda (e) (match e - [(dcgm 8 #;(== DCGM-TYPE-LOG-TO-PARENT) _ _ (list severity msg)) + [(dcgm #;8 (== DCGM-TYPE-LOG-TO-PARENT) _ _ (list severity msg)) (send node log-from-child #:severity severity msg)] + [(dcgm #;51 (== DCGM-NEW-PLACE-CHANNEL) _ _ _) + (send node forward-mesg e pch)] [else (put-msg e)]))) nes)) (define/public (get-sc-id) id) @@ -420,7 +445,7 @@ (hash-ref named-places (->string name) #f)) (define (add-place-channel-socket-bridge pch sch id) (add-psb (new place-socket-bridge% [pch pch] [sch sch] [id id] [node this]))) - (define (forward-mesg m src-channel) + (define/public (forward-mesg m src-channel) ;(printf/f "FORWARD MESSAGE ~a ~a\n" src-channel m) (match m [(dcgm 1 #;(== DCGM-TYPE-DIE) src dest "DIE") (exit 1)] @@ -434,7 +459,7 @@ (sconn-write-flush d (dcgm DCGM-TYPE-NEW-INTER-DCHANNEL src dest ch-id))] [(or (place-channel? d) (place? d)) (place-channel-put d m)] - [else (raise (format "Unexpected channel type ~a" d))])] + [else (raise (format "Unexpected channel type1 ~a" d))])] [(dcgm 9 #;(== DCGM-TYPE-NEW-PLACE) -1 (and place-exec (list-rest type rest)) ch-id) (match place-exec [(list 'connect name) @@ -448,6 +473,21 @@ [node this])) (add-sub-ec nc)] + [else + (sconn-write-flush src-channel (dcgm DCGM-TYPE-INTER-DCHANNEL ch-id ch-id + (format "ERROR: name not found ~a" name)))])] + [(list 'channel-connect name src-id) + (define np (named-place-lookup name)) + (cond + [np + (define nc (new connection% + [name-pl np] + [ch-id ch-id] + [sc src-channel] + [node this] + [channel-connection src-id])) + (add-sub-ec nc)] + [else (sconn-write-flush src-channel (dcgm DCGM-TYPE-INTER-DCHANNEL ch-id ch-id (format "ERROR: name not found ~a" name)))])] @@ -477,7 +517,7 @@ (send pch forward msg)] [(th-place-channel? pch) (th-place-channel-put pch msg)] - [else (raise (format "Unexpected channel type ~a" pch))])] + [else (raise (format "Unexpected channel type2 ~a" pch))])] [(dcgm 6 #;(== DCGM-TYPE-SPAWN-REMOTE-PROCESS) src (list node-name node-port mod-path funcname) ch1) (define node (spawn-remote-racket-node node-name #:listen-port node-port)) (for ([x (in-hash-values spawned-nodes)]) @@ -493,11 +533,18 @@ (log-debug (format"PLACE ~a died" ch-id))] [(dcgm 8 #;(== DCGM-TYPE-LOG-TO-PARENT) _ _ (list severity msg)) (log-from-child #:severity severity msg)] - [(dcgm 10 #;(== DCGM-TYPE-SET-OWNER) -1 -1 msg) + [(dcgm 11 #;(== DCGM-TYPE-SET-OWNER) -1 -1 msg) (log-debug (format "RECV DCGM-TYPE-SET-OWNER ~a" src-channel)) (set! owner src-channel)] [(dcgm #;50 (== DCGM-NEW-NODE-CONNECT) -1 -1 (list node-name node-port)) (add-spawned-node (list node-name node-port) (new remote-node% [host-name node-name] [listen-port node-port]))] + + [(dcgm #;51 (== DCGM-NEW-PLACE-CHANNEL) src-id (and dest (list host port name)) pch) + ;(printf/f "DCGM-NEW-PLACE-CHANNEL ~a ~a\n" src-id dest) + (define node (find-spawned-node (list host port))) + (unless node (raise (format "DCGM-CONTROL-NEW-CONNECTION Node ~a not found in ~a" dest spawned-nodes))) + (send node connect-channel src-id name #:one-sided pch)] + [(dcgm #;100 (== DCGM-CONTROL-NEW-NODE) -1 solo (list node-name node-port)) (define node (spawn-remote-racket-node node-name #:listen-port node-port)) (cond @@ -509,12 +556,15 @@ (add-spawned-node (list node-name node-port) node)])] [(dcgm #;101 (== DCGM-CONTROL-NEW-PLACE) dest -1 place-exec) (define node (find-spawned-node dest)) + (unless node (raise (format "DCGM-CONTROL-NEW-PLACE Node ~a not found in ~a" dest spawned-nodes))) (send node launch-place place-exec)] [(dcgm #;102 (== DCGM-CONTROL-NEW-CONNECTION) dest -1 (list name ch)) (define node (find-spawned-node dest)) + (unless node (raise (format "DCGM-CONTROL-NEW-CONNECTION Node ~a not found in ~a" dest spawned-nodes))) (send node remote-connect name #:one-sided ch)] [(dcgm mtype srcs dest msg) +; (printf/f "DEFAULT ACTION ~a ~a ~a ~a\n" mtype srcs dest msg) (define d (vector-ref chan-vec dest)) (cond [(is-a? d socket-connection%) @@ -555,12 +605,12 @@ (cons (cond [(is-a? x socket-connection%) - (sconn-get-forward-event x forward-mesg)] + (sconn-build-forward-event x (lambda (e x) (forward-mesg e x)))] [(or (place-channel? x) (place? x)) (wrap-evt x (lambda (e) (forward-mesg e x)))] [(channel? x) (wrap-evt x (lambda (e) (forward-mesg e x)))] - [else (raise (format "Unexpected channel type ~a" x))]) + [else (raise (format "Unexpected channel type3 ~a" x))]) n)) nes)] @@ -581,12 +631,12 @@ (cons (cond [(is-a? x socket-connection%) - (sconn-get-forward-event x forward-mesg)] + (sconn-build-forward-event x (lambda (e x) (forward-mesg e x)))] [(or (place-channel? x) (place? x)) (wrap-evt x (lambda (e) ;(printf "SOCKET-PORT PLACE MESSAGE ~a\n" e) (forward-mesg e x)))] - [else (raise (format "Unexpected channel type ~a" x))]) + [else (raise (format "Unexpected channel type4 ~a" x))]) n)) nes)] [nes @@ -644,7 +694,7 @@ (define (sconn-lookup-subchannel s ch-id) (send s lookup-subchannel ch-id)) (define (sconn-write-flush s x) (send s _write-flush x)) (define (sconn-remove-subchannel s scid) (send s remove-subchannel scid)) -(define (sconn-get-forward-event s forwarder) (send s get-forward-event forwarder)) +(define (sconn-build-forward-event s forwarder) (send s build-forward-event forwarder)) (define socket-connection% (backlink @@ -661,7 +711,9 @@ [connecting #f] [ch #f]) - (define (forward-mesg x) (void)) + (define (forward-mesg x) + (raise (format "Getting forwarded ~a" x)) + (void)) (define (tcp-connect/retry rname rport #:times [times 10] #:delay [delay 1]) (let loop ([t 0]) @@ -684,8 +736,8 @@ (define (handle-error e) (cond [remote-node => (lambda (n) - (send n tcp-connection-died host port))] - [else (raise (format "TCP connection to ~a:~a failed.\n" host port))])) + (send n tcp-connection-died host port e))] + [else (raise (format "TCP connection to ~a:~a failed ~a\n" host port e))])) (define/public (add-subchannel id pch) (set! subchannels (append subchannels (list (cons id pch))))) @@ -700,7 +752,7 @@ (lambda (x) (and (not (= (car x) id)) x)) subchannels))) (define/public (addresses) (tcp-addresses in #t)) - (define/public (get-forward-event forwarder) + (define/public (build-forward-event forwarder) (when (equal? out #f) (ensure-connected)) (wrap-evt in (lambda (e) (forwarder @@ -799,7 +851,7 @@ (th-place-channel-put pch msg)] [(async-bi-channel? pch) (async-bi-channel-put pch msg)] - [else (raise (format "Unexpected channel type ~a" pch))])] + [else (raise (format "Unexpected channel type5 ~a" pch))])] [(dcgm 8 #;(== DCGM-TYPE-LOG-TO-PARENT) _ _ (list severity msg)) (define parent (send this get-router)) (cond @@ -812,11 +864,11 @@ (log-debug (format "EOF on node socket connection pid to ~a ~a:~a CONNECTION ~a:~a -> ~a:~a" (get-sp-pid) host-name listen-port lh lp rh rp)) (set! sc #f)] - [else (log-debug (format"received message ~a" it))])) + [else (log-debug (format"received message ~a from ~a" it in-port))])) (define/public (get-log-prefix) (format "PLACE ~a:~a" host-name listen-port)) - (define/public (tcp-connection-died host port) - (log-debug (format "TCP connection~a:~a died, restarting node/connection" host-name listen-port)) + (define/public (tcp-connection-died host port e) + (log-debug (format "TCP connection ~a:~a died, ~a, restarting node/connection" host-name listen-port e)) (and sp (send sp kill)) (set! sp #f) (cond @@ -852,7 +904,7 @@ (sconn-remove-subchannel sc scid)) (define/public (launch-place place-exec #:restart-on-exit [restart-on-exit #f] #:one-sided-place? [one-sided-place? #f]) - (define rp (new remote-place% [node this] [place-exec place-exec] [restart-on-exit restart-on-exit] + (define rp (new remote-connection% [node this] [place-exec place-exec] [restart-on-exit restart-on-exit] [one-sided one-sided-place?])) (add-remote-place rp) rp) @@ -863,16 +915,16 @@ (add-remote-place rp) rp) - (define/public (spawn-remote-place place-exec dch) - (define ch-id (nextid)) - (sconn-add-subchannel sc ch-id dch) - (sconn-write-flush sc (dcgm DCGM-TYPE-NEW-PLACE -1 place-exec ch-id)) - (new place-socket-bridge% [pch dch] [sch sc] [id ch-id] [node this])) + (define/public (connect-channel src-id name #:restart-on-exit [restart-on-exit #f] #:one-sided [one-sided #f]) + (define rp (new remote-connection% [node this] [name name] [src-id src-id] [restart-on-exit restart-on-exit] + [one-sided one-sided])) + (add-remote-place rp) + rp) - (define/public (spawn-remote-connection name dch) + (define/public (spawn-remote-connection msg-gen dch) (define ch-id (nextid)) (sconn-add-subchannel sc ch-id dch) - (sconn-write-flush sc (dcgm DCGM-TYPE-NEW-PLACE -1 (list 'connect name) ch-id)) + (sconn-write-flush sc (msg-gen ch-id)) (new place-socket-bridge% [pch dch] [sch sc] [id ch-id] [node this])) (define/public (send-exit) @@ -885,7 +937,7 @@ (let* ([es (if sp (send sp register es) es)] [es (for/fold ([nes es]) ([rp remote-places]) (send rp register nes))] - [es (if sc (cons (sconn-get-forward-event sc on-socket-event) es) es)] + [es (if sc (cons (sconn-build-forward-event sc on-socket-event) es) es)] [es (if (and restart-on-exit (not (equal? restart-on-exit #t))) (send restart-on-exit register es) @@ -918,6 +970,7 @@ (init-field node) (init-field [place-exec #f]) (init-field [name #f]) + (init-field [src-id #f]) (init-field [one-sided #f]) (init-field [restart-on-exit #f]) (init-field [on-channel #f]) @@ -944,13 +997,16 @@ (set! pc pch2)]) (set! psb - (if place-exec - (send node spawn-remote-place place-exec rpc) - (send node spawn-remote-connection name rpc))) + (send node spawn-remote-connection + (cond + [place-exec (lambda (ch-id) (dcgm DCGM-TYPE-NEW-PLACE -1 place-exec ch-id))] + [src-id (lambda (ch-id) (dcgm DCGM-TYPE-NEW-PLACE -1 (list 'channel-connect name src-id) ch-id))] + [else (lambda (ch-id) (dcgm DCGM-TYPE-NEW-PLACE -1 (list 'connect name) ch-id))]) + rpc)) (define (restart-place) (send node drop-sc-id (send psb get-sc-id)) - (set! psb (send node spawn-remote-place place-exec rpc))) + (set! psb (send node spawn-remote-connection (lambda (ch-id) (dcgm DCGM-TYPE-NEW-PLACE -1 place-exec ch-id)) rpc))) (define/public (stop) (void)) (define/public (get-channel) pc) @@ -1074,20 +1130,35 @@ (init-field name-pl ch-id sc - node) + node + [channel-connection #f] + [on-place-dead #f]) (field [psb #f]) (define-values (pch1 pch2) (place-channel)) - (define name-ch (send name-pl get-channel)) + (define forward-ch + (if channel-connection + pch1 + (send name-pl get-channel))) + + (define control-ch + (if channel-connection + (send name-pl get-channel) + #f)) - (init-field [on-place-dead #f]) (sconn-add-subchannel sc ch-id this) (set! psb (new place-socket-bridge% [pch pch1] [sch sc] [id ch-id] [node node])) + + (when channel-connection + (place-channel-put control-ch (list channel-connection pch2))) (define/public (forward msg) - (place-channel-put name-ch (list msg pch2))) + (place-channel-put forward-ch + (if channel-connection + msg + (list msg pch2)))) (define/public (put msg) (sconn-write-flush sc (dcgm DCGM-TYPE-INTER-DCHANNEL ch-id ch-id msg))) @@ -1482,13 +1553,20 @@ ;;; [chan-vec (vector ch)])) ;;; (send mrn sync-events)]))) ;;; (place-channel-put mr (list listen-port))) - (define (*channel-put ch msg) ((cond [(place-channel? ch) place-channel-put] + [(async-bi-channel? ch) async-bi-channel-put] [(channel? ch) channel-put]) ch msg)) +(define (*channel-get ch) + ((cond + [(place-channel? ch) place-channel-get] + [(async-bi-channel? ch) async-bi-channel-get] + [(channel? ch) channel-get]) + ch)) + (define/provide (mr-spawn-remote-node mrch host #:listen-port [listen-port DEFAULT-ROUTER-PORT] #:solo [solo #f]) (*channel-put mrch (dcgm DCGM-CONTROL-NEW-NODE -1 solo (list host listen-port)))) @@ -1501,7 +1579,7 @@ (cond [(channel? mrch) (make-async-bi-channel)] [(place-channel? mrch) (place-channel)] - [else (raise (format "Unexpected channel type ~a" mrch))])) + [else (raise (format "Unexpected channel type6 ~a" mrch))])) (*channel-put mrch (dcgm DCGM-CONTROL-NEW-CONNECTION dest -1 (list name ch2))) ch1) @@ -1518,20 +1596,25 @@ (send mrn sync-events)))) (values mr ch)) -(define (spawn-node-at host #:listen-port [listen-port DEFAULT-ROUTER-PORT]) +(define (spawn-node-at host #:listen-port [listen-port DEFAULT-ROUTER-PORT] + #:racket-path [racketpath (racket-path)] + #:ssh-bin-path [sshpath (ssh-bin-path)] + #:distributed-launch-path [distributedlaunchpath (->module-path-bytes distributed-launch-path)]) + (define ch (make-channel)) (thread - (lambda () (channel-put ch (spawn-remote-racket-node host #:listen-port listen-port)))) + (lambda () (channel-put ch (spawn-remote-racket-node host #:listen-port listen-port + #:racket-path racketpath + #:ssh-bin-path sshpath + #:distributed-launch-path distributedlaunchpath)))) ch) (define/provide (spawn-nodes/join nodes-desc) - (channels-join - (for/list ([n nodes-desc]) - (match-define (list host listen-port) n) - (spawn-node-at host #:listen-port listen-port)))) - -(define/provide (channels-join chs) - (for/list ([x chs]) + (for/list ([x + (for/list ([n nodes-desc]) + (apply keyword-apply spawn-node-at n))]) (channel-get x))) - +(define build-node-args + (make-keyword-procedure (lambda (kws kw-args . rest) + (list kws kw-args rest)))) diff --git a/collects/racket/place/distributed/RMPI.rkt b/collects/racket/place/distributed/RMPI.rkt new file mode 100644 index 0000000000..7b6b415996 --- /dev/null +++ b/collects/racket/place/distributed/RMPI.rkt @@ -0,0 +1,252 @@ +#lang racket/base + +(require racket/place/distributed + racket/match + racket/list + racket/place) + +(provide RMPI-init + RMPI-send + RMPI-recv + RMPI-BCast + RMPI-Reduce + RMPI-AllReduce + RMPI-Barrier + RMPI-id + RMPI-cnt + RMPI-partition + RMPI-BuildDefaultConfig + RMPI-Launch + RMPI-finish) + +(struct RMPI-COMM (id cnt channels) #:transparent) + +(define (RMPI-id comm) (RMPI-COMM-id comm)) +(define (RMPI-cnt comm) (RMPI-COMM-cnt comm)) +(define (RMPI-send comm dest val) (place-channel-put (vector-ref (RMPI-COMM-channels comm) dest) val)) +(define (RMPI-recv comm src) (place-channel-get (vector-ref (RMPI-COMM-channels comm) src))) + +(define (RMPI-init ch) + (match-define (list (list id config) return-ch) (place-channel-get ch)) + (match-define (list args src-ch) (place-channel-get ch)) + (define mpi-comm-vector + (for/vector #:length (length config) ([c config]) + (match-define (list dest dest-port dest-name dest-id) c) + (cond + [(< id dest-id) + ;(printf/f "sending connect to dest-id ~a from id ~a over ~a" dest-id id ch) + (send-new-place-channel-to-named-dest ch id (list dest dest-port dest-name))] + [else null]))) + (for ([i (length config)]) + (cond + [(> id i) + (match-define (list src-id src-ch) (place-channel-get ch)) + ;(printf/f "received connect from id ~a ~a" src-id src-ch) + (vector-set! mpi-comm-vector src-id src-ch)] + [else null])) + (values + (RMPI-COMM id (length config) mpi-comm-vector) + args + )) + + +(define RMPI-BCast + (case-lambda + [(comm src) + (RMPI-BCast comm src (void))] + [(comm src val) + (match-define (RMPI-COMM real-id cnt chs) comm) + (define offset (- cnt src)) + (define id (modulo (+ real-id (- cnt src)) cnt)) + (let loop ([i 0] + [val val]) + (define round (arithmetic-shift 1 i)) + (cond + [(< round cnt) + (loop + (add1 i) + (cond + [(< id (arithmetic-shift round 1)) + (cond + [(not (= 0 (bitwise-and id round))) + (define peer-id (- id round)) + (define real-peer-id (modulo (+ peer-id offset) cnt)) + ;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + (place-channel-get (vector-ref chs real-peer-id)) + ] + [else + (define peer-id (+ id round)) + (define real-peer-id (modulo (+ peer-id offset) cnt)) + ;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + (place-channel-put (vector-ref chs real-peer-id) val) + val])] + [else val]))] + [else val]))])) + +(define (fancy-reducer op recv-val val) + (cond + [(number? recv-val) + (op recv-val val)] + [(vector? recv-val) + (for/vector #:length (vector-length recv-val) + ([a (in-vector recv-val)] + [b (in-vector val)]) + (fancy-reducer op a b))] + [else (raise (format "fancy-reducer error on ~a ~a ~a" op recv-val val))])) + +(define (RMPI-Reduce comm dest op val) + (match-define (RMPI-COMM real-id cnt chs) comm) + (define i + (let loop ([i 0]) + (if (>= (arithmetic-shift 1 i) cnt) + i + (loop (add1 i))))) + + + (define offset (- cnt dest)) + (define (convert v) (modulo (+ v offset) cnt)) + (define id (convert real-id)) + (let loop ([i i] + [val val]) + (cond + [(> i 0) + (define round (arithmetic-shift 1 (sub1 i))) + (loop + (sub1 i) + (cond + [(< id (arithmetic-shift round 1)) + (cond + [(not (= 0 (bitwise-and id round))) + (define peer-id (- id round)) + (define real-peer-id (convert peer-id)) + ;(printf "SEND ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + (place-channel-put (vector-ref chs real-peer-id) val) + val + ] + [else + (define peer-id (+ id round)) + (define real-peer-id (convert peer-id)) + ;(printf "RECV ~a ~a ~a ~a ~a ~a ~a\n" round real-id id peer-id real-peer-id offset val) + (define recv-val (place-channel-get (vector-ref chs real-peer-id))) + ;(define recv-val val) + (fancy-reducer op recv-val val)])]))] + [else val]))) + +(define (RMPI-Barrier comm) + (RMPI-Reduce comm 0 + 1) + (RMPI-BCast comm 0 1)) + +(define (RMPI-AllReduce comm op val) + (define rv (RMPI-Reduce comm 0 op val)) + (RMPI-BCast comm 0 rv)) + +(define (partit num cnt id) + (define-values (quo rem) (quotient/remainder num cnt)) + (values (+ (* id quo) (if (< id rem) id 0)) + (+ quo (if (< id rem) 1 0)))) + +(define (RMPI-partition comm num) + (define id (RMPI-id comm)) + (define cnt (RMPI-cnt comm)) + (partit num cnt id)) + +(define RMPI-BuildDefaultConfig + (make-keyword-procedure (lambda (kws kw-args . rest) + (for/hash ([kw kws] + [kwa kw-args]) +; (displayln (keyword? kw)) + (values kw kwa))))) + +(define (RMPI-Launch default config) + (define (lookup-config-value rest key-str) + (define key + (string->keyword key-str)) + (cond + [(null? rest) + (hash-ref default key #f)] + [else + (hash-ref (car rest) key (lambda () + (hash-ref default key #f)))])) + +; (printf/f "~v\n" default) +; (exit 1) + (define nodes + (spawn-nodes/join + (for/list ([c config]) + (match-define (list-rest host port name id _rest) c) + (define rest + (cond + [(null? _rest) + (list (make-immutable-hash (list (cons (string->keyword "listen-port") port))))] + [else + (list + (hash-set (car _rest) (string->keyword "listen-port") port))])) +; (printf/f "~a\n" rest) + (define-values (k v) + (let loop ([keys (list "racket-path" "listen-port" "distributed-launch-path")] + [k null] + [v null]) + (cond + [(pair? keys) + (cond + [(lookup-config-value rest (car keys)) => (lambda (x) + (loop (cdr keys) + (cons (string->keyword (car keys)) k) + (cons x v)))] + [else + (loop (cdr keys) k v)])] + [else + (values k v)]))) +; (printf/f "~a\n" (list k v (list host))) + (list k v (list host))))) + + (for ([n nodes] + [c config]) + (match-define (list-rest host port name id rest) c) + (supervise-named-dynamic-place-at n + name + (lookup-config-value rest "mpi-module") + (lookup-config-value rest "mpi-func"))) + + (define-values (mrth ch) + (start-message-router/thread + #:nodes nodes)) + + (for ([c config]) + (match-define (list-rest host port name id rest) c) + (define npch (mr-connect-to ch (list host port) name)) + (*channel-put npch (list id config)) + (*channel-put npch (or (lookup-config-value rest "mpi-args") null))) + + (for/first ([c config]) + (match-define (list-rest host port name id rest) c) + (define npch (mr-connect-to ch (list host port) name)) + (*channel-put npch 'done?) + + ;Wait for 'done message from mpi node id 0 + (*channel-get npch)) + ) + + +(define (RMPI-finish comm ch) + (when (= (RMPI-id comm) 0) + (place-channel-put (second (place-channel-get ch)) 'done))) + +(module+ bcast-print-test + (RMPI-BCast (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi") + (RMPI-BCast (RMPI-COMM 3 8 (vector 0 1 2 3 4 5 6 7)) 0) + (RMPI-BCast (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 3) + ) + +(module+ reduce-print-test + (RMPI-Reduce (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 0 + 7) + (RMPI-Reduce (RMPI-COMM 3 8 (vector 0 1 2 3 4 5 6 7)) 0 + 7) + (RMPI-Reduce (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 3 + 7) + ) + +(module+ test + (require tests/eli-tester) + (test + (partit 10 3 0) => (values 0 4) + (partit 10 3 1) => (values 3 3) + (partit 10 3 2) => (values 6 3)))