diff --git a/collects/openssl/mzssl2.ss b/collects/openssl/mzssl2.ss index 87208c4fab..bc00374dfb 100644 --- a/collects/openssl/mzssl2.ss +++ b/collects/openssl/mzssl2.ss @@ -2,9 +2,9 @@ ;; This is a re-implementation of "mzssl.c" using `(lib "foreign.ss")'. ;; It will soon replace "mzssl.c". -;; Warn clients: even when a (non-blocking) write fails to write all -;; the data, the stream is committed to writing the given data in -;; the future. (This requirement comes from the SSL library.) +;; Warn clients: when a (non-blocking) write fails to write all the +;; data, the stream is actually committed to writing the given data +;; in the future. (This requirement comes from the SSL library.) (module mzssl2 mzscheme (require (lib "foreign.ss") @@ -131,6 +131,7 @@ (define-mzscheme scheme_start_atomic (-> _void)) (define-mzscheme scheme_end_atomic (-> _void)) + (define-mzscheme scheme_make_custodian (_pointer -> _scheme)) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Error handling @@ -176,7 +177,7 @@ (define-struct ssl-listener (l mzctx)) ;; internal: - (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock refcount)) + (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount)) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Contexts, certificates, etc. @@ -306,16 +307,23 @@ (when (zero? (mzssl-refcount mzssl)) (SSL_free (mzssl-ssl mzssl)))))) - (define (pump-input-once mzssl need-progress?) + (define (pump-input-once mzssl need-progress?/out) (let ([buffer (mzssl-buffer mzssl)] [i (mzssl-i mzssl)] [r-bio (mzssl-r-bio mzssl)]) - (let ([n ((if need-progress? read-bytes-avail! read-bytes-avail!*) buffer i)]) + (let ([n ((if (and need-progress?/out + (not (output-port? need-progress?/out))) + read-bytes-avail! + read-bytes-avail!*) + buffer i)]) (cond [(eof-object? n) (BIO_set_mem_eof_return r-bio 1) eof] - [(zero? n) 0] + [(zero? n) + (when need-progress?/out + (sync need-progress?/out i)) + 0] [else (let ([m (BIO_write r-bio buffer n)]) (unless (= m n) (error 'pump-input-once "couldn't write all bytes to BIO!")) @@ -337,8 +345,7 @@ #f) (begin (write-bytes buffer pipe-w 0 n) - (pump-output-once mzssl need-progress? - output-blocked-result)))) + (pump-output-once mzssl need-progress? output-blocked-result)))) (let ([n ((if need-progress? write-bytes-avail write-bytes-avail*) buffer o 0 n)]) (if (zero? n) output-blocked-result @@ -346,9 +353,14 @@ (port-commit-peeked n (port-progress-evt pipe-r) always-evt pipe-r) #t))))))) + ;; result is #t if there's more data to send out the + ;; underlying output port, but the port is full (define (pump-output mzssl) - (when (pump-output-once mzssl #f #f) - (pump-output mzssl))) + (let ([v (pump-output-once mzssl #f 'blocked)]) + (if (eq? v 'blocked) + #t + (and v + (pump-output mzssl))))) (define (make-ssl-input-port mzssl) (make-input-port/read-to-peek @@ -356,34 +368,46 @@ ;; read proc: (letrec ([do-read (lambda (buffer) - (pump-output mzssl) - (let ([n (SSL_read (mzssl-ssl mzssl) buffer (bytes-length buffer))]) - (if (n . >= . 1) - n - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) - (cond - [(or (= err SSL_ERROR_ZERO_RETURN) - (and (= err SSL_ERROR_SYSCALL) (zero? n))) - ;; We hit the end-of-file - eof] - [(= err SSL_ERROR_WANT_READ) - (let ([n (pump-input-once mzssl #f)]) - (if (eq? n 0) - (wrap-evt (mzssl-i mzssl) (lambda (x) 0)) - (do-read buffer)))] - [(= err SSL_ERROR_WANT_WRITE) - (if (pump-output-once mzssl #f #f) - (do-read buffer) - (wrap-evt (mzssl-o mzssl) (lambda (x) 0)))] - [else - (error 'read-bytes "SSL read failed ~a" - (get-error-message (ERR_get_error)))])))))] + (let ([out-blocked? (pump-output mzssl)]) + (let ([n (SSL_read (mzssl-ssl mzssl) buffer (bytes-length buffer))]) + (if (n . >= . 1) + n + (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (cond + [(or (= err SSL_ERROR_ZERO_RETURN) + (and (= err SSL_ERROR_SYSCALL) (zero? n))) + ;; We hit the end-of-file + eof] + [(= err SSL_ERROR_WANT_READ) + (let ([n (pump-input-once mzssl #f)]) + (if (eq? n 0) + (wrap-evt (choice-evt + (mzssl-i mzssl) + (if out-blocked? + (mzssl-o mzssl) + never-evt)) + (lambda (x) 0)) + (do-read buffer)))] + [(= err SSL_ERROR_WANT_WRITE) + (if (pump-output-once mzssl #f #f) + (do-read buffer) + (wrap-evt (mzssl-o mzssl) (lambda (x) 0)))] + [else + (error 'read-bytes "SSL read failed ~a" + (get-error-message (ERR_get_error)))]))))))] + [top-read + (lambda (buffer) + (if (mzssl-flushing? mzssl) + ;; Flush in progress; try again later: + 0 + (do-read buffer)))] [lock-unavailable - (lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) 0)))]) + (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl)) + (lambda (x) 0)))]) (lambda (buffer) (call-with-semaphore (mzssl-lock mzssl) - do-read + top-read lock-unavailable buffer))) ;; fast peek: @@ -392,89 +416,156 @@ (lambda () (mzssl-release mzssl)))) + (define (flush-ssl mzssl) + ;; Make sure that this SSL connection has said everything that it + ;; wants to say --- that is, move data from the SLL output to the + ;; underlying output port. Depending on the transport, the other end + ;; may be stuck trying to tell us something before it will listen, + ;; so we also have to read in any available information. + (let loop () + (let ([v (pump-input-once mzssl #f)]) + (if (and (number? v) (positive? v)) + ;; Received some input, so start over + (loop) + ;; Try sending output + (let ([v (pump-output-once mzssl #f 'blocked)]) + ;; If we sent something, continue tring in case there's more. + ;; Further, if we blocked on the underlying output, then + ;; wait until either input or output is ready: + (when v + (when (eq? v 'blocked) + (sync (mzssl-o mzssl) (mzssl-i mzssl))) + (loop))))))) + + (define (kernel-thread thunk) + ;; Since we provide #f to scheme_make_custodian, + ;; the custodian is managed directly by the root: + (parameterize ([current-custodian (scheme_make_custodian #f)]) + (thread thunk))) + (define (make-ssl-output-port mzssl) ;; Need a consistent buffer to use with SSL_write ;; across calls to the port's write function. - (let ([xfer-buffer (make-bytes 512)]) + (let ([xfer-buffer (make-bytes 512)] + [buffer-mode (or (file-stream-buffer-mode (mzssl-o mzssl)) 'bloack)] + [flush-ch (make-channel)]) + ;; This thread mkoves data from the SLL stream to the underlying + ;; output port, because this port's write prodcue claims that the + ;; data is flushed if it gets into the SSL stream. In other words, + ;; this flushing thread is analogous to the OS's job of pushing + ;; data from a socket through the actual network device. It therefore + ;; runs with the highest possible custodian: + (kernel-thread (lambda () + (let loop () + (sync flush-ch) + (let flush-loop () + (sync flush-ch) + (semaphore-wait (mzssl-lock mzssl)) + (flush-ssl mzssl) + (set-mzssl-flushing?! mzssl #f) + (semaphore-post (mzssl-lock mzssl)) + (loop))))) + ;; Create the output port: (make-output-port (format "SSL ~a" (object-name (mzssl-o mzssl))) (mzssl-o mzssl) ;; write proc: (letrec ([do-write - (lambda (len block-ok? enable-break?) - (pump-output mzssl) - (if (zero? len) - ;; Flush request; all data is in the the SSL - ;; stream, but how do we know that it's gone - ;; through the ports (which may involve both - ;; output and input)? It seems that making - ;; sure all output is gone is sufficient. - ;; We've already pumped output, but maybe some - ;; is stuck in the bio... - (parameterize-break - enable-break? - (let loop () - (flush-output (mzssl-o mzssl)) - (when (pump-output-once mzssl #f #t) - (loop))) - 0) - ;; Write request; even if blocking is ok, we treat - ;; it as non-blocking and let MzScheme handle blocking - (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) - (if (n . > . 0) - n - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) - (cond - [(= err SSL_ERROR_WANT_READ) - (let ([n (pump-input-once mzssl #f)]) - (if (eq? n 0) - (wrap-evt (mzssl-i mzssl) (lambda (x) #f)) - (do-write len block-ok? enable-break?)))] - [(= err SSL_ERROR_WANT_WRITE) - (if (pump-output-once mzssl #f #f) - (do-write len block-ok? enable-break?) - (wrap-evt (mzssl-o mzssl) (lambda (x) #f)))] - [else - (error 'read-bytes "SSL read failed ~a" - (get-error-message (ERR_get_error)))]))))))] + (lambda (len non-block? enable-break?) + (let ([out-blocked? (pump-output mzssl)]) + (if (zero? len) + ;; Flush request; all data is in the the SSL + ;; stream, but make sure it's gone + ;; through the ports: + (parameterize-break + enable-break? + (flush-ssl mzssl) + 0) + ;; Write request; even if blocking is ok, we treat + ;; it as non-blocking and let MzScheme handle blocking + (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) + (if (n . > . 0) + (begin + ;; Start flush in bg thread, if necessary: + (unless (and (not non-block?) + (eq? buffer-mode 'block)) + (channel-put flush-ch #t)) + n) + (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (cond + [(= err SSL_ERROR_WANT_READ) + (let ([n (pump-input-once mzssl #f)]) + (if (eq? n 0) + (wrap-evt (choice-evt + (mzssl-i mzssl) + (if out-blocked? + (mzssl-o mzssl) + never-evt)) + (lambda (x) #f)) + (do-write len non-block? enable-break?)))] + [(= err SSL_ERROR_WANT_WRITE) + (if (pump-output-once mzssl #f #f) + (do-write len non-block? enable-break?) + (wrap-evt (mzssl-o mzssl) (lambda (x) #f)))] + [else + (error 'read-bytes "SSL read failed ~a" + (get-error-message (ERR_get_error)))])))))))] [top-write - (lambda (buffer s e block-ok? enable-break?) - (bytes-copy! xfer-buffer 0 buffer s e) - (do-write (- e s) block-ok? enable-break?))] + (lambda (buffer s e non-block? enable-break?) + (if (mzssl-flushing? mzssl) + ;; Oops -- wait until flush done + (if (= s e) + ;; Ok, it's as good as flushed: + 0 + ;; Try again later: + (wrap-evt always-evt (lambda (v) #f))) + ;; Normal write (since no flush is active): + (let ([len (min (- e s) (bytes-length xfer-buffer))]) + (bytes-copy! xfer-buffer 0 buffer s (+ s len)) + (do-write len non-block? enable-break?))))] [lock-unavailable - (lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) #f)))]) - (lambda (buffer s e block-ok? enable-break?) + (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl)) + (lambda (x) #f)))]) + (lambda (buffer s e non-block? enable-break?) (call-with-semaphore (mzssl-lock mzssl) top-write lock-unavailable - buffer s e block-ok? enable-break?))) + buffer s e non-block? enable-break?))) ;; close proc: (lambda () ;; issue shutdown (i.e., EOF on read end) - (let loop ([cnt 1]) - (pump-output mzssl) - (let ([n (SSL_shutdown (mzssl-ssl mzssl))]) - (if (= n 0) - ;; 0 seems to be the result in many cases because the socket - ;; is non-blocking, and then neither of the WANTs is returned. - ;; We address this by simply trying 10 times and then giving - ;; up. The two-step shutdown is optional, anyway. - (unless (cnt . >= . 10) - (loop (add1 cnt))) - (unless (= n 1) - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) - (cond - [(= err SSL_ERROR_WANT_READ) - (pump-input-once mzssl #t) - (loop)] - [(= err SSL_ERROR_WANT_WRITE) - (pump-output-once mzssl #t #f) - (loop)] - [else - (error 'read-bytes "SSL shutdown failed ~a" - (get-error-message (ERR_get_error)))])))))) - (mzssl-release mzssl))))) + (let loop ([cnt 0]) + (let ([out-blocked? (pump-output mzssl)]) + (let ([n (SSL_shutdown (mzssl-ssl mzssl))]) + (unless (= n 1) + (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (cond + [(= err SSL_ERROR_WANT_READ) + (pump-input-once mzssl (if out-blocked? (mzssl-o mzssl) #t)) + (loop cnt)] + [(= err SSL_ERROR_WANT_WRITE) + (pump-output-once mzssl #t #f) + (loop cnt)] + [else + (if (= n 0) + ;; When 0 is returned, the SSL object no longer correctly + ;; reports what it wants (e.g., a write). If pumping blocked + ;; or if this is our first time around, then wait on the + ;; underlying output and try again. + (when (or (zero? cnt) out-blocked?) + (sync (mzssl-o mzssl)) + (loop (add1 cnt))) + (error 'read-bytes "SSL shutdown failed ~a" + (get-error-message (ERR_get_error))))])))))) + (mzssl-release mzssl)) + ;; Unimplemented port methods: + #f #f #f #f + void 1 + ;; Buffer mode proc: + (case-lambda + [() buffer-mode] + [(mode) (set! buffer-mode mode)])))) (define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?) (wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?)) @@ -526,28 +617,28 @@ ;; connect/accept: (let-values ([(buffer) (make-bytes 512)] [(pipe-r pipe-w) (make-pipe)]) - (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) 2)]) + (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2)]) (let loop () (let ([status (if connect? (SSL_connect ssl) (SSL_accept ssl))]) - (pump-output mzssl) - (when (status . < . 1) - (let ([err (SSL_get_error ssl status)]) - (cond - [(= err SSL_ERROR_WANT_READ) - (let ([n (pump-input-once mzssl #t)]) - (when (eof-object? n) - (error who "~a failed (input terminated prematurely)" - (if connect? "connect" "accept")))) - (loop)] - [(= err SSL_ERROR_WANT_WRITE) - (pump-output-once mzssl #t #f) - (loop)] - [else - (error who "~a failed ~a" - (if connect? "connect" "accept") - (get-error-message (ERR_get_error)))]))))) + (let ([out-blocked? (pump-output mzssl)]) + (when (status . < . 1) + (let ([err (SSL_get_error ssl status)]) + (cond + [(= err SSL_ERROR_WANT_READ) + (let ([n (pump-input-once mzssl (if out-blocked? o #t))]) + (when (eof-object? n) + (error who "~a failed (input terminated prematurely)" + (if connect? "connect" "accept")))) + (loop)] + [(= err SSL_ERROR_WANT_WRITE) + (pump-output-once mzssl #t #f) + (loop)] + [else + (error who "~a failed ~a" + (if connect? "connect" "accept") + (get-error-message (ERR_get_error)))])))))) ;; Connection complete; make ports (values (make-ssl-input-port mzssl) (make-ssl-output-port mzssl)))))))))))