[Distributed Places] fix kmeans hang

This commit is contained in:
Kevin Tew 2012-05-02 15:34:10 -06:00
parent 5be7a7d980
commit 1fab365129
2 changed files with 57 additions and 16 deletions

View File

@ -86,6 +86,12 @@
*channel-put *channel-put
send-new-place-channel-to-named-dest send-new-place-channel-to-named-dest
mr-spawn-remote-node
mr-supervise-named-dynamic-place-at
mr-connect-to
start-message-router/thread
spawn-nodes/join
;classes ;classes
event-container<%> event-container<%>
spawned-process% spawned-process%
@ -103,10 +109,16 @@
;re-provides ;re-provides
quote-module-path quote-module-path
named-place-typed-channel%
tc-get
) )
(define-runtime-path distributed-launch-path "distributed/launch.rkt") (define-runtime-path distributed-launch-path "distributed/launch.rkt")
(define (build-distributed-launch-path path) (path->string (build-path path "collects/racket/place/distributed/launch.rkt"))) (define (build-distributed-launch-path [collects-path
(simplify-path (find-executable-path (find-system-path 'exec-file)
(find-system-path 'collects-dir)))])
(path->string (build-path collects-path "racket/place/distributed/launch.rkt")))
(define DEFAULT-ROUTER-PORT 6340) (define DEFAULT-ROUTER-PORT 6340)
@ -258,7 +270,7 @@
(define (send-new-place-channel-to-named-dest ch src-id dest-list) (define (send-new-place-channel-to-named-dest ch src-id dest-list)
(define-values (e1 e2) (place-channel)) (define-values (e1 e2) (place-channel))
(place-channel-put ch (dcgm DCGM-NEW-PLACE-CHANNEL src-id dest-list e2)) (place-channel-put ch (dcgm DCGM-NEW-PLACE-CHANNEL (list 'new-place-channel src-id) dest-list e2))
e1) e1)
@ -1557,14 +1569,16 @@
((cond ((cond
[(place-channel? ch) place-channel-put] [(place-channel? ch) place-channel-put]
[(async-bi-channel? ch) async-bi-channel-put] [(async-bi-channel? ch) async-bi-channel-put]
[(channel? ch) channel-put]) [(channel? ch) channel-put]
[else (raise (format "unknown channel type ~a" ch))])
ch msg)) ch msg))
(define (*channel-get ch) (define (*channel-get ch)
((cond ((cond
[(place-channel? ch) place-channel-get] [(place-channel? ch) place-channel-get]
[(async-bi-channel? ch) async-bi-channel-get] [(async-bi-channel? ch) async-bi-channel-get]
[(channel? ch) channel-get]) [(channel? ch) channel-get]
[else (raise (format "unknown channel type ~a" ch))])
ch)) ch))
(define/provide (mr-spawn-remote-node mrch host #:listen-port [listen-port DEFAULT-ROUTER-PORT] (define/provide (mr-spawn-remote-node mrch host #:listen-port [listen-port DEFAULT-ROUTER-PORT]
@ -1618,3 +1632,29 @@
(define build-node-args (define build-node-args
(make-keyword-procedure (lambda (kws kw-args . rest) (make-keyword-procedure (lambda (kws kw-args . rest)
(list kws kw-args rest)))) (list kws kw-args rest))))
(define named-place-typed-channel%
(class*
object% ()
(init-field ch)
(field [msgs null])
(define/public (get type)
(let loop ([l msgs]
[nl null])
(cond
[(null? l)
(define nm (place-channel-get ch))
;(printf/f "NM ~a ~a\n" type nm)
(set! msgs (append msgs (list nm)))
(loop msgs null)]
[(equal? type (caaar l))
(set! msgs (append (reverse nl) (cdr l)))
(car l)]
[else
(loop (cdr l) (cons (car l) nl))])))
(super-new)
))
(define (tc-get type ch) (send ch get type))

View File

@ -3,7 +3,8 @@
(require racket/place/distributed (require racket/place/distributed
racket/match racket/match
racket/list racket/list
racket/place) racket/place
racket/class)
(provide RMPI-init (provide RMPI-init
RMPI-send RMPI-send
@ -27,8 +28,9 @@
(define (RMPI-recv comm src) (place-channel-get (vector-ref (RMPI-COMM-channels comm) src))) (define (RMPI-recv comm src) (place-channel-get (vector-ref (RMPI-COMM-channels comm) src)))
(define (RMPI-init ch) (define (RMPI-init ch)
(match-define (list (list id config) return-ch) (place-channel-get ch)) (define tc (new named-place-typed-channel% [ch ch]))
(match-define (list args src-ch) (place-channel-get ch)) (match-define (list (list 'mpi-id id config) return-ch) (tc-get 'mpi-id tc))
(match-define (list (list 'args args) src-ch) (tc-get 'args tc))
(define mpi-comm-vector (define mpi-comm-vector
(for/vector #:length (length config) ([c config]) (for/vector #:length (length config) ([c config])
(match-define (list dest dest-port dest-name dest-id) c) (match-define (list dest dest-port dest-name dest-id) c)
@ -40,13 +42,14 @@
(for ([i (length config)]) (for ([i (length config)])
(cond (cond
[(> id i) [(> id i)
(match-define (list src-id src-ch) (place-channel-get ch)) (match-define (list (list 'new-place-channel src-id) src-ch) (tc-get 'new-place-channel tc))
;(printf/f "received connect from id ~a ~a" src-id src-ch) ;(printf/f "received connect from id ~a ~a" src-id src-ch)
(vector-set! mpi-comm-vector src-id src-ch)] (vector-set! mpi-comm-vector src-id src-ch)]
[else null])) [else null]))
(values (values
(RMPI-COMM id (length config) mpi-comm-vector) (RMPI-COMM id (length config) mpi-comm-vector)
args args
tc
)) ))
@ -215,22 +218,20 @@
(for ([c config]) (for ([c config])
(match-define (list-rest host port name id rest) c) (match-define (list-rest host port name id rest) c)
(define npch (mr-connect-to ch (list host port) name)) (define npch (mr-connect-to ch (list host port) name))
(*channel-put npch (list id config)) (*channel-put npch (list 'mpi-id id config))
(*channel-put npch (or (lookup-config-value rest "mpi-args") null))) (*channel-put npch (list 'args (or (lookup-config-value rest "mpi-args") null))))
(for/first ([c config]) (for/first ([c config])
(match-define (list-rest host port name id rest) c) (match-define (list-rest host port name id rest) c)
(define npch (mr-connect-to ch (list host port) name)) (define npch (mr-connect-to ch (list host port) name))
(*channel-put npch 'done?) (*channel-put npch (list 'done?))
;Wait for 'done message from mpi node id 0 ;Wait for 'done message from mpi node id 0
(*channel-get npch)) (*channel-get npch)))
)
(define (RMPI-finish comm ch) (define (RMPI-finish comm tc)
(when (= (RMPI-id comm) 0) (when (= (RMPI-id comm) 0)
(place-channel-put (second (place-channel-get ch)) 'done))) (place-channel-put (second (tc-get 'done? tc)) 'done)))
(module+ bcast-print-test (module+ bcast-print-test
(RMPI-BCast (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi") (RMPI-BCast (RMPI-COMM 0 8 (vector 0 1 2 3 4 5 6 7)) 0 "Hi")