From 14f03bcf5ba32288da45393219cc6b6066244f2b Mon Sep 17 00:00:00 2001 From: Matthew Flatt Date: Thu, 1 Mar 2012 11:00:33 -0700 Subject: [PATCH] openssl: thread safety There are many SSL_() functions that produce return codes with more information from SLL_get_error() and/or ERR_get_error(). Those need to be grouped in an atomic section to ensure thread safety at the level of Racket threads. --- collects/openssl/mzssl.rkt | 76 +++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/collects/openssl/mzssl.rkt b/collects/openssl/mzssl.rkt index a67834b392..53e93761ba 100644 --- a/collects/openssl/mzssl.rkt +++ b/collects/openssl/mzssl.rkt @@ -153,6 +153,7 @@ (define X509_V_OK 0) + (define SSL_ERROR_SSL 1) (define SSL_ERROR_WANT_READ 2) (define SSL_ERROR_WANT_WRITE 3) (define SSL_ERROR_SYSCALL 5) @@ -303,6 +304,32 @@ ;; Normal byte string is immobile: (make-bytes n))) + ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + ;; Errors + + (define (do-save-errors thunk ssl) + ;; Atomically run a function and get error results + ;; so that this library is thread-safe (at the level of Racket threads) + (atomically + (define v (thunk)) + (define e (if (negative? v) + (SSL_get_error ssl v) + 0)) + (define estr + (cond + [(= e SSL_ERROR_SSL) + (get-error-message (ERR_get_error))] + [(= e SSL_ERROR_SYSCALL) + (define v (ERR_get_error)) + (if (zero? v) + (get-error-message v) + #f)] + [else #f])) + (values v e estr))) + + (define-syntax-rule (save-errors e ssl) + (do-save-errors (lambda () e) ssl)) + ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Contexts, certificates, etc. @@ -375,11 +402,12 @@ (let ([path (path->bytes (path->complete-path (cleanse-path pathname) (current-directory)))]) - (let ([n (load-it ctx path)]) - (unless (= n 1) - (error who "load failed from: ~e ~a" - pathname - (get-error-message (ERR_get_error)))))))) + (atomically ;; for to connect ERR_get_error to `load-it' + (let ([n (load-it ctx path)]) + (unless (= n 1) + (error who "load failed from: ~e ~a" + pathname + (get-error-message (ERR_get_error))))))))) (define (ssl-load-certificate-chain! ssl-context-or-listener pathname) (ssl-load-... 'ssl-load-certificate-chain! @@ -455,9 +483,8 @@ (define (renegotiate who mzssl) (define (check-err thunk) (let loop () - (define v (thunk)) + (define-values (v err estr) (save-errors (thunk) (mzssl-ssl mzssl))) (when (negative? v) - (define err (SSL_get_error (mzssl-ssl mzssl) v)) (cond [(= err SSL_ERROR_WANT_READ) (let ([n (pump-input-once mzssl #f)]) @@ -476,7 +503,7 @@ (sync (mzssl-o mzssl)) (loop)))] [else - (error who "failed: ~a" (get-error-message (ERR_get_error)))])))) + (error who "failed: ~a" estr)])))) (check-err (lambda () (SSL_renegotiate (mzssl-ssl mzssl)))) (check-err (lambda () (SSL_do_handshake (mzssl-ssl mzssl)))) ;; Really demanding a negotiation from the server side @@ -572,7 +599,9 @@ (lambda (buffer) (let ([len (or must-read-len (min (bytes-length xfer-buffer) (bytes-length buffer)))]) - (let ([n (SSL_read (mzssl-ssl mzssl) xfer-buffer len)]) + (let-values ([(n err estr) (save-errors + (SSL_read (mzssl-ssl mzssl) xfer-buffer len) + (mzssl-ssl mzssl))]) (if (n . >= . 1) (begin (set! must-read-len #f) @@ -586,7 +615,7 @@ (write-bytes buffer got-w orig-n n))) (bytes-copy! buffer 0 xfer-buffer 0 n)) n) - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (let () (cond [(or (= err SSL_ERROR_ZERO_RETURN) (and (= err SSL_ERROR_SYSCALL) (zero? n))) @@ -621,7 +650,7 @@ (set! must-read-len #f) ((mzssl-error mzssl) 'read-bytes "SSL read failed ~a" - (get-error-message (ERR_get_error)))]))))))] + estr)]))))))] [top-read (lambda (buffer) (cond @@ -733,7 +762,8 @@ 0) ;; Write request; even if blocking is ok, we treat ;; it as non-blocking and let Racket handle blocking - (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) + (let-values ([(n err estr) (save-errors (SSL_write (mzssl-ssl mzssl) xfer-buffer len) + (mzssl-ssl mzssl))]) (if (n . > . 0) (begin (set! must-write-len #f) @@ -750,7 +780,7 @@ ;; through (even though we're allowed to buffer): (flush-ssl mzssl enable-break?)]) n) - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (let () (cond [(= err SSL_ERROR_WANT_READ) (when enforce-retry? @@ -783,7 +813,7 @@ (set! must-write-len #f) ((mzssl-error mzssl) 'write-bytes "SSL write failed ~a" - (get-error-message (ERR_get_error)))])))))))] + estr)])))))))] [top-write (lambda (buffer s e non-block? enable-break?) (cond @@ -850,9 +880,10 @@ (when (mzssl-shutdown-on-close? mzssl) (let loop ([cnt 0]) (let ([out-blocked? (flush-ssl mzssl #f)]) - (let ([n (SSL_shutdown (mzssl-ssl mzssl))]) + (let-values ([(n err estr) (save-errors (SSL_shutdown (mzssl-ssl mzssl)) + (mzssl-ssl mzssl))]) (unless (= n 1) - (let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) + (let () (cond [(= err SSL_ERROR_WANT_READ) (pump-input-once mzssl (if out-blocked? (mzssl-o mzssl) #t)) @@ -869,7 +900,7 @@ (loop (add1 cnt))) ((mzssl-error mzssl) 'read-bytes "SSL shutdown failed ~a" - (get-error-message (ERR_get_error))))]))))))) + estr))]))))))) (set-mzssl-w-closed?! mzssl #t) (mzssl-release mzssl) #f]))] @@ -978,12 +1009,13 @@ cancel error/ssl)]) (let loop () - (let ([status (if connect? - (SSL_connect ssl) - (SSL_accept ssl))]) + (let-values ([(status err estr) (save-errors (if connect? + (SSL_connect ssl) + (SSL_accept ssl)) + ssl)]) (let ([out-blocked? (pump-output mzssl)]) (when (status . < . 1) - (let ([err (SSL_get_error ssl status)]) + (let () (cond [(= err SSL_ERROR_WANT_READ) (let ([n (pump-input-once mzssl (if out-blocked? o #t))]) @@ -997,7 +1029,7 @@ [else (error/ssl who "~a failed ~a" (if connect? "connect" "accept") - (get-error-message (ERR_get_error)))])))))) + estr)])))))) ;; Connection complete; make ports (values (register (make-ssl-input-port mzssl) mzssl #t) (register (make-ssl-output-port mzssl) mzssl #f))))))