diff --git a/collects/openssl/mzssl2.ss b/collects/openssl/mzssl2.ss index dcb67bd5da..f6cca895e1 100644 --- a/collects/openssl/mzssl2.ss +++ b/collects/openssl/mzssl2.ss @@ -12,6 +12,10 @@ ;; (This is due to the fact that unbuffered data cannot be written ;; without blocking.) +;; One last warning: a write/read must block because a previous +;; read/write (the opposite direction) didn't finish, and so that +;; opposite must be completed, first. + (module mzssl2 mzscheme (require (lib "foreign.ss") (lib "port.ss") @@ -212,7 +216,21 @@ (define-struct ssl-listener (l mzctx)) ;; internal: - (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount close? finalizer-cancel)) + (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w + buffer lock + flushing? must-write must-read + refcount close? finalizer-cancel)) + + (define (make-immobile-bytes n) + (if (regexp-match #rx#"3m" (path->bytes (system-library-subpath))) + ;; Allocate the byte string via malloc: + (atomically + (let* ([p (malloc 'raw n)] + [s (make-sized-byte-string p n)]) + (register-finalizer s (lambda (v) (free p))) + s)) + ;; Normal byte string is immobile: + (make-bytes n))) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Contexts, certificates, etc. @@ -403,61 +421,104 @@ (pump-output mzssl))))) (define (make-ssl-input-port mzssl) - (make-input-port/read-to-peek - (format "SSL ~a" (object-name (mzssl-i mzssl))) - ;; read proc: - (letrec ([do-read - (lambda (buffer) - (let ([out-blocked? (pump-output mzssl)]) - (let ([n (SSL_read (mzssl-ssl mzssl) buffer (bytes-length buffer))]) - (if (n . >= . 1) - n - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) - (cond - [(or (= err SSL_ERROR_ZERO_RETURN) - (and (= err SSL_ERROR_SYSCALL) (zero? n))) - ;; We hit the end-of-file - eof] - [(= err SSL_ERROR_WANT_READ) - (let ([n (pump-input-once mzssl #f)]) - (if (eq? n 0) - (wrap-evt (choice-evt - (mzssl-i mzssl) - (if out-blocked? - (mzssl-o mzssl) - never-evt)) - (lambda (x) 0)) - (do-read buffer)))] - [(= err SSL_ERROR_WANT_WRITE) - (if (pump-output-once mzssl #f #f) - (do-read buffer) - (wrap-evt (mzssl-o mzssl) (lambda (x) 0)))] - [else - (error 'read-bytes "SSL read failed ~a" - (get-error-message (ERR_get_error)))]))))))] - [top-read - (lambda (buffer) - (if (mzssl-flushing? mzssl) + ;; If SSL_read produces NEED_READ or NEED_WRITE, then the next + ;; call to SSL_read must use the same arguments. + ;; Use xfer-buffer so we have a consistent buffer to use with + ;; SSL_read across calls to the port's write function. + (let-values ([(xfer-buffer) (make-immobile-bytes 4096)] + [(got-r got-w) (make-pipe)] + [(must-read-len) #f]) + (make-input-port/read-to-peek + (format "SSL ~a" (object-name (mzssl-i mzssl))) + ;; read proc: + (letrec ([do-read + (lambda (buffer) + (let ([out-blocked? (pump-output mzssl)] + [len (or must-read-len (min (bytes-length xfer-buffer) + (bytes-length buffer)))]) + (let ([n (SSL_read (mzssl-ssl mzssl) xfer-buffer len)]) + (if (n . >= . 1) + (begin + (set! must-read-len #f) + (if must-read-len + ;; If we were forced to try to read a certain amount, + ;; then we may have reda too much for the immediate + ;; request. + (let ([orig-n (bytes-length buffer)]) + (bytes-copy! buffer 0 xfer-buffer 0 (min n orig-n)) + (when (n . > . orig-n) + (write-bytes buffer got-w orig-n n))) + (bytes-copy! buffer 0 xfer-buffer 0 n)) + n) + (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (cond + [(or (= err SSL_ERROR_ZERO_RETURN) + (and (= err SSL_ERROR_SYSCALL) (zero? n))) + ;; We hit the end-of-file + (set! must-read-len #f) + eof] + [(= err SSL_ERROR_WANT_READ) + (set! must-read-len len) + (let ([n (pump-input-once mzssl #f)]) + (if (eq? n 0) + (begin + (set-mzssl-must-read! mzssl (make-semaphore)) + (wrap-evt (choice-evt + (mzssl-i mzssl) + (if out-blocked? + (mzssl-o mzssl) + never-evt)) + (lambda (x) 0))) + (do-read buffer)))] + [(= err SSL_ERROR_WANT_WRITE) + (set! must-read-len len) + (if (pump-output-once mzssl #f #f) + (do-read buffer) + (begin + (set-mzssl-must-read! mzssl (make-semaphore)) + (wrap-evt (mzssl-o mzssl) (lambda (x) 0))))] + [else + (set! must-read-len #f) + (error 'read-bytes "SSL read failed ~a" + (get-error-message (ERR_get_error)))]))))))] + [top-read + (lambda (buffer) + (cond + [(mzssl-flushing? mzssl) ;; Flush in progress; try again later: - 0 - (do-read buffer)))] - [lock-unavailable - (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl)) - (lambda (x) 0)))]) - (lambda (buffer) + 0] + [(mzssl-must-write mzssl) + => (lambda (sema) + (wrap-evt (semaphore-peek-evt sema) (lambda (x) 0)))] + [else + (let ([sema (mzssl-must-read mzssl)]) + (when sema + (set-mzssl-must-read! mzssl #f) + (semaphore-post sema))) + ;; First, try pipe for previously read data: + (let ([n (read-bytes-avail!* buffer got-r)]) + (if (zero? n) + ;; Nothing already read, so use SSL_read: + (do-read buffer) + ;; Got previously read data: + n))]))] + [lock-unavailable + (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl)) + (lambda (x) 0)))]) + (lambda (buffer) + (call-with-semaphore + (mzssl-lock mzssl) + top-read + lock-unavailable + buffer))) + ;; fast peek: + #f + ;; close proc: + (lambda () (call-with-semaphore (mzssl-lock mzssl) - top-read - lock-unavailable - buffer))) - ;; fast peek: - #f - ;; close proc: - (lambda () - (call-with-semaphore - (mzssl-lock mzssl) - (lambda () - (mzssl-release mzssl)))))) + (lambda () + (mzssl-release mzssl))))))) (define (flush-ssl mzssl enable-break?) ;; Make sure that this SSL connection has said everything that it @@ -487,11 +548,14 @@ (thread thunk))) (define (make-ssl-output-port mzssl) - ;; Need a consistent buffer to use with SSL_write - ;; across calls to the port's write function. - (let ([xfer-buffer (make-bytes 512)] + ;; If SSL_write produces NEED_READ or NEED_WRITE, then the next + ;; call to SSL_write must use the same arguments. + ;; Use xfer-buffer so we have a consistent buffer to use with + ;; SSL_write across calls to the port's write function. + (let ([xfer-buffer (make-immobile-bytes 4096)] [buffer-mode (or (file-stream-buffer-mode (mzssl-o mzssl)) 'bloack)] - [flush-ch (make-channel)]) + [flush-ch (make-channel)] + [must-write-len #f]) ;; This thread mkoves data from the SLL stream to the underlying ;; output port, because this port's write prodcue claims that the ;; data is flushed if it gets into the SSL stream. In other words, @@ -527,6 +591,7 @@ (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) (if (n . > . 0) (begin + (set! must-write-len #f) ;; Start flush as necessary: (cond [non-block? @@ -546,35 +611,65 @@ (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) (cond [(= err SSL_ERROR_WANT_READ) + (set! must-write-len len) (let ([n (pump-input-once mzssl #f)]) (if (eq? n 0) - (wrap-evt (choice-evt - (mzssl-i mzssl) - (if out-blocked? - (mzssl-o mzssl) - never-evt)) - (lambda (x) #f)) + (begin + (set-mzssl-must-write! mzssl (make-semaphore)) + (wrap-evt (choice-evt + (mzssl-i mzssl) + (if out-blocked? + (mzssl-o mzssl) + never-evt)) + (lambda (x) #f))) (do-write len non-block? enable-break?)))] [(= err SSL_ERROR_WANT_WRITE) + (set! must-write-len len) (if (pump-output-once mzssl #f #f) (do-write len non-block? enable-break?) - (wrap-evt (mzssl-o mzssl) (lambda (x) #f)))] + (begin + (set-mzssl-must-write! mzssl (make-semaphore)) + (wrap-evt (mzssl-o mzssl) (lambda (x) #f))))] [else - (error 'read-bytes "SSL read failed ~a" + (set! must-write-len #f) + (error 'write-bytes "SSL write failed ~a" (get-error-message (ERR_get_error)))])))))))] [top-write (lambda (buffer s e non-block? enable-break?) - (if (mzssl-flushing? mzssl) - ;; Need to wait until flush done - (if (= s e) - ;; Let the background flush finish: - (list (semaphore-peek-evt (mzssl-flushing? mzssl))) - ;; Try again later: - (wrap-evt always-evt (lambda (v) #f))) - ;; Normal write (since no flush is active): - (let ([len (min (- e s) (bytes-length xfer-buffer))]) - (bytes-copy! xfer-buffer 0 buffer s (+ s len)) - (do-write len non-block? enable-break?))))] + (cond + [(mzssl-flushing? mzssl) + ;; Need to wait until flush done + (if (= s e) + ;; Let the background flush finish: + (list (semaphore-peek-evt (mzssl-flushing? mzssl))) + ;; Try again later: + (wrap-evt always-evt (lambda (v) #f)))] + [(mzssl-must-read mzssl) + ;; Read pending, so wait until it's done: + => (lambda (sema) + (wrap-evt (semaphore-peek-evt sema) (lambda (x) #f)))] + [else + ;; Normal write (since no flush is active or read pending): + (let ([sema (mzssl-must-write mzssl)]) + (when sema + (set-mzssl-must-write! mzssl #f) + (semaphore-post sema))) + (let ([len (min (- e s) (bytes-length xfer-buffer))]) + (if must-write-len + ;; Previous SSL_write result obligates certain output: + (begin + (unless (and (len . >= . must-write-len) + (bytes=? (subbytes xfer-buffer 0 must-write-len) + (subbytes buffer s (+ s must-write-len)))) + (error 'write-bytes + "SSL output request: ~e different from previous unsatisfied request: ~e" + (subbytes buffer s e) + (subbytes xfer-buffer 0 must-write-len))) + (do-write must-write-len non-block? enable-break?)) + ;; No previous write obligation: + (begin + (bytes-copy! xfer-buffer 0 buffer s (+ s len)) + (do-write len non-block? enable-break?))))]))] [lock-unavailable (lambda () (wrap-evt (semaphore-peek-evt (mzssl-lock mzssl)) (lambda (x) #f)))]) @@ -710,10 +805,10 @@ (let-values ([(ssl cancel r-bio w-bio connect?) (create-ssl who context-or-encrypt-method connect/accept)]) ;; connect/accept: - (let-values ([(buffer) (make-bytes 512)] + (let-values ([(buffer) (make-bytes 4096)] [(pipe-r pipe-w) (make-pipe)] [(cancel) (box #t)]) - (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2 close? cancel)]) + (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f #f #f 2 close? cancel)]) (let loop () (let ([status (if connect? (SSL_connect ssl)