corrected buffering/flushing behavior of new ssl binding

svn: r2732
This commit is contained in:
Matthew Flatt 2006-04-21 14:23:59 +00:00
parent 1c753c3948
commit 63f30b1df1

View File

@ -2,9 +2,9 @@
;; This is a re-implementation of "mzssl.c" using `(lib "foreign.ss")'. ;; This is a re-implementation of "mzssl.c" using `(lib "foreign.ss")'.
;; It will soon replace "mzssl.c". ;; It will soon replace "mzssl.c".
;; Warn clients: even when a (non-blocking) write fails to write all ;; Warn clients: when a (non-blocking) write fails to write all the
;; the data, the stream is committed to writing the given data in ;; data, the stream is actually committed to writing the given data
;; the future. (This requirement comes from the SSL library.) ;; in the future. (This requirement comes from the SSL library.)
(module mzssl2 mzscheme (module mzssl2 mzscheme
(require (lib "foreign.ss") (require (lib "foreign.ss")
@ -131,6 +131,7 @@
(define-mzscheme scheme_start_atomic (-> _void)) (define-mzscheme scheme_start_atomic (-> _void))
(define-mzscheme scheme_end_atomic (-> _void)) (define-mzscheme scheme_end_atomic (-> _void))
(define-mzscheme scheme_make_custodian (_pointer -> _scheme))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Error handling ;; Error handling
@ -176,7 +177,7 @@
(define-struct ssl-listener (l mzctx)) (define-struct ssl-listener (l mzctx))
;; internal: ;; internal:
(define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock refcount)) (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Contexts, certificates, etc. ;; Contexts, certificates, etc.
@ -306,16 +307,23 @@
(when (zero? (mzssl-refcount mzssl)) (when (zero? (mzssl-refcount mzssl))
(SSL_free (mzssl-ssl mzssl)))))) (SSL_free (mzssl-ssl mzssl))))))
(define (pump-input-once mzssl need-progress?) (define (pump-input-once mzssl need-progress?/out)
(let ([buffer (mzssl-buffer mzssl)] (let ([buffer (mzssl-buffer mzssl)]
[i (mzssl-i mzssl)] [i (mzssl-i mzssl)]
[r-bio (mzssl-r-bio mzssl)]) [r-bio (mzssl-r-bio mzssl)])
(let ([n ((if need-progress? read-bytes-avail! read-bytes-avail!*) buffer i)]) (let ([n ((if (and need-progress?/out
(not (output-port? need-progress?/out)))
read-bytes-avail!
read-bytes-avail!*)
buffer i)])
(cond (cond
[(eof-object? n) [(eof-object? n)
(BIO_set_mem_eof_return r-bio 1) (BIO_set_mem_eof_return r-bio 1)
eof] eof]
[(zero? n) 0] [(zero? n)
(when need-progress?/out
(sync need-progress?/out i))
0]
[else (let ([m (BIO_write r-bio buffer n)]) [else (let ([m (BIO_write r-bio buffer n)])
(unless (= m n) (unless (= m n)
(error 'pump-input-once "couldn't write all bytes to BIO!")) (error 'pump-input-once "couldn't write all bytes to BIO!"))
@ -337,8 +345,7 @@
#f) #f)
(begin (begin
(write-bytes buffer pipe-w 0 n) (write-bytes buffer pipe-w 0 n)
(pump-output-once mzssl need-progress? (pump-output-once mzssl need-progress? output-blocked-result))))
output-blocked-result))))
(let ([n ((if need-progress? write-bytes-avail write-bytes-avail*) buffer o 0 n)]) (let ([n ((if need-progress? write-bytes-avail write-bytes-avail*) buffer o 0 n)])
(if (zero? n) (if (zero? n)
output-blocked-result output-blocked-result
@ -346,9 +353,14 @@
(port-commit-peeked n (port-progress-evt pipe-r) always-evt pipe-r) (port-commit-peeked n (port-progress-evt pipe-r) always-evt pipe-r)
#t))))))) #t)))))))
;; result is #t if there's more data to send out the
;; underlying output port, but the port is full
(define (pump-output mzssl) (define (pump-output mzssl)
(when (pump-output-once mzssl #f #f) (let ([v (pump-output-once mzssl #f 'blocked)])
(pump-output mzssl))) (if (eq? v 'blocked)
#t
(and v
(pump-output mzssl)))))
(define (make-ssl-input-port mzssl) (define (make-ssl-input-port mzssl)
(make-input-port/read-to-peek (make-input-port/read-to-peek
@ -356,7 +368,7 @@
;; read proc: ;; read proc:
(letrec ([do-read (letrec ([do-read
(lambda (buffer) (lambda (buffer)
(pump-output mzssl) (let ([out-blocked? (pump-output mzssl)])
(let ([n (SSL_read (mzssl-ssl mzssl) buffer (bytes-length buffer))]) (let ([n (SSL_read (mzssl-ssl mzssl) buffer (bytes-length buffer))])
(if (n . >= . 1) (if (n . >= . 1)
n n
@ -369,7 +381,12 @@
[(= err SSL_ERROR_WANT_READ) [(= err SSL_ERROR_WANT_READ)
(let ([n (pump-input-once mzssl #f)]) (let ([n (pump-input-once mzssl #f)])
(if (eq? n 0) (if (eq? n 0)
(wrap-evt (mzssl-i mzssl) (lambda (x) 0)) (wrap-evt (choice-evt
(mzssl-i mzssl)
(if out-blocked?
(mzssl-o mzssl)
never-evt))
(lambda (x) 0))
(do-read buffer)))] (do-read buffer)))]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(if (pump-output-once mzssl #f #f) (if (pump-output-once mzssl #f #f)
@ -377,13 +394,20 @@
(wrap-evt (mzssl-o mzssl) (lambda (x) 0)))] (wrap-evt (mzssl-o mzssl) (lambda (x) 0)))]
[else [else
(error 'read-bytes "SSL read failed ~a" (error 'read-bytes "SSL read failed ~a"
(get-error-message (ERR_get_error)))])))))] (get-error-message (ERR_get_error)))]))))))]
[top-read
(lambda (buffer)
(if (mzssl-flushing? mzssl)
;; Flush in progress; try again later:
0
(do-read buffer)))]
[lock-unavailable [lock-unavailable
(lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) 0)))]) (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl))
(lambda (x) 0)))])
(lambda (buffer) (lambda (buffer)
(call-with-semaphore (call-with-semaphore
(mzssl-lock mzssl) (mzssl-lock mzssl)
do-read top-read
lock-unavailable lock-unavailable
buffer))) buffer)))
;; fast peek: ;; fast peek:
@ -392,89 +416,156 @@
(lambda () (lambda ()
(mzssl-release mzssl)))) (mzssl-release mzssl))))
(define (flush-ssl mzssl)
;; Make sure that this SSL connection has said everything that it
;; wants to say --- that is, move data from the SLL output to the
;; underlying output port. Depending on the transport, the other end
;; may be stuck trying to tell us something before it will listen,
;; so we also have to read in any available information.
(let loop ()
(let ([v (pump-input-once mzssl #f)])
(if (and (number? v) (positive? v))
;; Received some input, so start over
(loop)
;; Try sending output
(let ([v (pump-output-once mzssl #f 'blocked)])
;; If we sent something, continue tring in case there's more.
;; Further, if we blocked on the underlying output, then
;; wait until either input or output is ready:
(when v
(when (eq? v 'blocked)
(sync (mzssl-o mzssl) (mzssl-i mzssl)))
(loop)))))))
(define (kernel-thread thunk)
;; Since we provide #f to scheme_make_custodian,
;; the custodian is managed directly by the root:
(parameterize ([current-custodian (scheme_make_custodian #f)])
(thread thunk)))
(define (make-ssl-output-port mzssl) (define (make-ssl-output-port mzssl)
;; Need a consistent buffer to use with SSL_write ;; Need a consistent buffer to use with SSL_write
;; across calls to the port's write function. ;; across calls to the port's write function.
(let ([xfer-buffer (make-bytes 512)]) (let ([xfer-buffer (make-bytes 512)]
[buffer-mode (or (file-stream-buffer-mode (mzssl-o mzssl)) 'bloack)]
[flush-ch (make-channel)])
;; This thread mkoves data from the SLL stream to the underlying
;; output port, because this port's write prodcue claims that the
;; data is flushed if it gets into the SSL stream. In other words,
;; this flushing thread is analogous to the OS's job of pushing
;; data from a socket through the actual network device. It therefore
;; runs with the highest possible custodian:
(kernel-thread (lambda ()
(let loop ()
(sync flush-ch)
(let flush-loop ()
(sync flush-ch)
(semaphore-wait (mzssl-lock mzssl))
(flush-ssl mzssl)
(set-mzssl-flushing?! mzssl #f)
(semaphore-post (mzssl-lock mzssl))
(loop)))))
;; Create the output port:
(make-output-port (make-output-port
(format "SSL ~a" (object-name (mzssl-o mzssl))) (format "SSL ~a" (object-name (mzssl-o mzssl)))
(mzssl-o mzssl) (mzssl-o mzssl)
;; write proc: ;; write proc:
(letrec ([do-write (letrec ([do-write
(lambda (len block-ok? enable-break?) (lambda (len non-block? enable-break?)
(pump-output mzssl) (let ([out-blocked? (pump-output mzssl)])
(if (zero? len) (if (zero? len)
;; Flush request; all data is in the the SSL ;; Flush request; all data is in the the SSL
;; stream, but how do we know that it's gone ;; stream, but make sure it's gone
;; through the ports (which may involve both ;; through the ports:
;; output and input)? It seems that making
;; sure all output is gone is sufficient.
;; We've already pumped output, but maybe some
;; is stuck in the bio...
(parameterize-break (parameterize-break
enable-break? enable-break?
(let loop () (flush-ssl mzssl)
(flush-output (mzssl-o mzssl))
(when (pump-output-once mzssl #f #t)
(loop)))
0) 0)
;; Write request; even if blocking is ok, we treat ;; Write request; even if blocking is ok, we treat
;; it as non-blocking and let MzScheme handle blocking ;; it as non-blocking and let MzScheme handle blocking
(let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)])
(if (n . > . 0) (if (n . > . 0)
n (begin
;; Start flush in bg thread, if necessary:
(unless (and (not non-block?)
(eq? buffer-mode 'block))
(channel-put flush-ch #t))
n)
(let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) (let ([err (SSL_get_error (mzssl-ssl mzssl) n)])
(cond (cond
[(= err SSL_ERROR_WANT_READ) [(= err SSL_ERROR_WANT_READ)
(let ([n (pump-input-once mzssl #f)]) (let ([n (pump-input-once mzssl #f)])
(if (eq? n 0) (if (eq? n 0)
(wrap-evt (mzssl-i mzssl) (lambda (x) #f)) (wrap-evt (choice-evt
(do-write len block-ok? enable-break?)))] (mzssl-i mzssl)
(if out-blocked?
(mzssl-o mzssl)
never-evt))
(lambda (x) #f))
(do-write len non-block? enable-break?)))]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(if (pump-output-once mzssl #f #f) (if (pump-output-once mzssl #f #f)
(do-write len block-ok? enable-break?) (do-write len non-block? enable-break?)
(wrap-evt (mzssl-o mzssl) (lambda (x) #f)))] (wrap-evt (mzssl-o mzssl) (lambda (x) #f)))]
[else [else
(error 'read-bytes "SSL read failed ~a" (error 'read-bytes "SSL read failed ~a"
(get-error-message (ERR_get_error)))]))))))] (get-error-message (ERR_get_error)))])))))))]
[top-write [top-write
(lambda (buffer s e block-ok? enable-break?) (lambda (buffer s e non-block? enable-break?)
(bytes-copy! xfer-buffer 0 buffer s e) (if (mzssl-flushing? mzssl)
(do-write (- e s) block-ok? enable-break?))] ;; Oops -- wait until flush done
(if (= s e)
;; Ok, it's as good as flushed:
0
;; Try again later:
(wrap-evt always-evt (lambda (v) #f)))
;; Normal write (since no flush is active):
(let ([len (min (- e s) (bytes-length xfer-buffer))])
(bytes-copy! xfer-buffer 0 buffer s (+ s len))
(do-write len non-block? enable-break?))))]
[lock-unavailable [lock-unavailable
(lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) #f)))]) (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl))
(lambda (buffer s e block-ok? enable-break?) (lambda (x) #f)))])
(lambda (buffer s e non-block? enable-break?)
(call-with-semaphore (call-with-semaphore
(mzssl-lock mzssl) (mzssl-lock mzssl)
top-write top-write
lock-unavailable lock-unavailable
buffer s e block-ok? enable-break?))) buffer s e non-block? enable-break?)))
;; close proc: ;; close proc:
(lambda () (lambda ()
;; issue shutdown (i.e., EOF on read end) ;; issue shutdown (i.e., EOF on read end)
(let loop ([cnt 1]) (let loop ([cnt 0])
(pump-output mzssl) (let ([out-blocked? (pump-output mzssl)])
(let ([n (SSL_shutdown (mzssl-ssl mzssl))]) (let ([n (SSL_shutdown (mzssl-ssl mzssl))])
(if (= n 0)
;; 0 seems to be the result in many cases because the socket
;; is non-blocking, and then neither of the WANTs is returned.
;; We address this by simply trying 10 times and then giving
;; up. The two-step shutdown is optional, anyway.
(unless (cnt . >= . 10)
(loop (add1 cnt)))
(unless (= n 1) (unless (= n 1)
(let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) (let ([err (SSL_get_error (mzssl-ssl mzssl) n)])
(cond (cond
[(= err SSL_ERROR_WANT_READ) [(= err SSL_ERROR_WANT_READ)
(pump-input-once mzssl #t) (pump-input-once mzssl (if out-blocked? (mzssl-o mzssl) #t))
(loop)] (loop cnt)]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(pump-output-once mzssl #t #f) (pump-output-once mzssl #t #f)
(loop)] (loop cnt)]
[else [else
(if (= n 0)
;; When 0 is returned, the SSL object no longer correctly
;; reports what it wants (e.g., a write). If pumping blocked
;; or if this is our first time around, then wait on the
;; underlying output and try again.
(when (or (zero? cnt) out-blocked?)
(sync (mzssl-o mzssl))
(loop (add1 cnt)))
(error 'read-bytes "SSL shutdown failed ~a" (error 'read-bytes "SSL shutdown failed ~a"
(get-error-message (ERR_get_error)))])))))) (get-error-message (ERR_get_error))))]))))))
(mzssl-release mzssl))))) (mzssl-release mzssl))
;; Unimplemented port methods:
#f #f #f #f
void 1
;; Buffer mode proc:
(case-lambda
[() buffer-mode]
[(mode) (set! buffer-mode mode)]))))
(define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?) (define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?)
(wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?)) (wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?))
@ -526,17 +617,17 @@
;; connect/accept: ;; connect/accept:
(let-values ([(buffer) (make-bytes 512)] (let-values ([(buffer) (make-bytes 512)]
[(pipe-r pipe-w) (make-pipe)]) [(pipe-r pipe-w) (make-pipe)])
(let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) 2)]) (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2)])
(let loop () (let loop ()
(let ([status (if connect? (let ([status (if connect?
(SSL_connect ssl) (SSL_connect ssl)
(SSL_accept ssl))]) (SSL_accept ssl))])
(pump-output mzssl) (let ([out-blocked? (pump-output mzssl)])
(when (status . < . 1) (when (status . < . 1)
(let ([err (SSL_get_error ssl status)]) (let ([err (SSL_get_error ssl status)])
(cond (cond
[(= err SSL_ERROR_WANT_READ) [(= err SSL_ERROR_WANT_READ)
(let ([n (pump-input-once mzssl #t)]) (let ([n (pump-input-once mzssl (if out-blocked? o #t))])
(when (eof-object? n) (when (eof-object? n)
(error who "~a failed (input terminated prematurely)" (error who "~a failed (input terminated prematurely)"
(if connect? "connect" "accept")))) (if connect? "connect" "accept"))))
@ -547,7 +638,7 @@
[else [else
(error who "~a failed ~a" (error who "~a failed ~a"
(if connect? "connect" "accept") (if connect? "connect" "accept")
(get-error-message (ERR_get_error)))]))))) (get-error-message (ERR_get_error)))]))))))
;; Connection complete; make ports ;; Connection complete; make ports
(values (make-ssl-input-port mzssl) (values (make-ssl-input-port mzssl)
(make-ssl-output-port mzssl))))))))))) (make-ssl-output-port mzssl)))))))))))