diff --git a/collects/openssl/mzssl2.ss b/collects/openssl/mzssl2.ss index bc00374dfb..aa10566fdc 100644 --- a/collects/openssl/mzssl2.ss +++ b/collects/openssl/mzssl2.ss @@ -50,8 +50,9 @@ [(_ id type) (with-syntax ([str (symbol->string (syntax-e #'id))]) #'(define id - (and chk - (get-ffi-obj str lib (_fun . type)))))]))])) + (if chk + (get-ffi-obj str lib (_fun . type)) + (lambda args (raise-not-available)))))]))])) (define-define-X define-ssl libssl libssl) (define-define-X define-mzscheme #t #f) @@ -136,6 +137,9 @@ ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Error handling + (define (raise-not-available) + (error 'openssl "OpenSSL shared library not found")) + (define-syntax with-failure (syntax-rules () [(_ thunk body ...) @@ -152,20 +156,45 @@ (define (check-valid v who what) (when (ptr-equal? v #f) (let ([id (ERR_get_error)]) - (error who "~a failed ~a" - what - (get-error-message id))))) + (escape-atomic + (lambda () + (error who "~a failed ~a" + what + (get-error-message id))))))) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Atomic blocks + ;; Obviously, be careful in an atomic block. In particular, + ;; DO NOT CONSTRUCT AN ERROR DIRECTLY IN AN ATOMIC BLOCK, + ;; because the error message almost certainly involves things + ;; like a ~a or ~e format, which can trigger all sorts of + ;; printing extensions. Instead, send a thunk that + ;; constructs and raises the exception to `escape-atomic'. + + (define in-atomic? (make-parameter #f)) + (define-struct (exn:atomic exn) (thunk)) + (define-syntax atomically (syntax-rules () [(_ body ...) - (dynamic-wind - (lambda () (scheme_start_atomic)) - (lambda () body ...) - (lambda () (scheme_end_atomic)))])) + (parameterize-break + #f + (with-handlers ([exn:atomic? (lambda (exn) + ((exn:atomic-thunk exn)))]) + (parameterize ([in-atomic? #t]) + (dynamic-wind + (lambda () (scheme_start_atomic)) + (lambda () body ...) + (lambda () (scheme_end_atomic))))))])) + + (define (escape-atomic thunk) + (if (in-atomic?) + (raise (make-exn:atomic + "error during atomic..." + (current-continuation-marks) + thunk)) + (thunk))) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Structs @@ -177,7 +206,7 @@ (define-struct ssl-listener (l mzctx)) ;; internal: - (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount)) + (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount close? finalizer-cancel)) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Contexts, certificates, etc. @@ -198,10 +227,12 @@ [(tls) (if client? TLSv1_client_method TLSv1_server_method)] - [else (raise-type-error - who - (string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls") - e)]))) + [else (escape-atomic + (lambda () + (raise-type-error + who + (string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls") + e)))]))) (define make-context (opt-lambda (who protocol-symbol also-expected client?) @@ -305,7 +336,12 @@ (lambda () (set-mzssl-refcount! mzssl (sub1 (mzssl-refcount mzssl))) (when (zero? (mzssl-refcount mzssl)) - (SSL_free (mzssl-ssl mzssl)))))) + (atomically + (set-box! (mzssl-finalizer-cancel mzssl) #f) + (SSL_free (mzssl-ssl mzssl))) + (when (mzssl-close? mzssl) + (close-input-port (mzssl-i mzssl)) + (close-output-port (mzssl-o mzssl))))))) (define (pump-input-once mzssl need-progress?/out) (let ([buffer (mzssl-buffer mzssl)] @@ -416,7 +452,7 @@ (lambda () (mzssl-release mzssl)))) - (define (flush-ssl mzssl) + (define (flush-ssl mzssl enable-break?) ;; Make sure that this SSL connection has said everything that it ;; wants to say --- that is, move data from the SLL output to the ;; underlying output port. Depending on the transport, the other end @@ -434,7 +470,7 @@ ;; wait until either input or output is ready: (when v (when (eq? v 'blocked) - (sync (mzssl-o mzssl) (mzssl-i mzssl))) + ((if enable-break? sync/enable-break sync) (mzssl-o mzssl) (mzssl-i mzssl))) (loop))))))) (define (kernel-thread thunk) @@ -461,7 +497,7 @@ (let flush-loop () (sync flush-ch) (semaphore-wait (mzssl-lock mzssl)) - (flush-ssl mzssl) + (flush-ssl mzssl #f) (set-mzssl-flushing?! mzssl #f) (semaphore-post (mzssl-lock mzssl)) (loop))))) @@ -477,10 +513,9 @@ ;; Flush request; all data is in the the SSL ;; stream, but make sure it's gone ;; through the ports: - (parameterize-break - enable-break? - (flush-ssl mzssl) - 0) + (begin + (flush-ssl mzssl enable-break?) + 0) ;; Write request; even if blocking is ok, we treat ;; it as non-blocking and let MzScheme handle blocking (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) @@ -570,78 +605,98 @@ (define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?) (wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?)) + (define (create-ssl who context-or-encrypt-method connect/accept) + (atomically ; so we register the finalizer (and it's ok since everything is non-blocking) + (let ([ctx (get-context who context-or-encrypt-method (eq? connect/accept 'connect))]) + (check-valid ctx who "context creation") + (with-failure + (lambda () (when (and ctx + (symbol? context-or-encrypt-method)) + (SSL_CTX_free ctx))) + (let ([connect? (case connect/accept + [(connect) #t] + [(accept) #f] + [else + (escape-atomic + (lambda () + (raise-type-error who "'connect or 'accept" + connect/accept)))])] + [r-bio (BIO_new (BIO_s_mem))] + [w-bio (BIO_new (BIO_s_mem))] + [free-bio? #t]) + (with-failure + (lambda () (when free-bio? + (BIO_free r-bio) + (BIO_free w-bio))) + (unless (or (symbol? context-or-encrypt-method) + (if connect? + (ssl-client-context? context-or-encrypt-method) + (ssl-server-context? context-or-encrypt-method))) + (escape-atomic + (lambda () + (error who + "'~a mode requires a ~a context, given: ~e" + (if connect? 'connect 'accept) + (if connect? "client" "server") + context-or-encrypt-method)))) + (let ([ssl (SSL_new ctx)] + [cancel (box #t)]) + (check-valid ssl who "ssl setup") + ;; ssl has a ref count on ctx, so release: + (when (symbol? context-or-encrypt-method) + (SSL_CTX_free ctx) + (set! ctx #f)) + (with-failure + (lambda () (SSL_free ssl)) + (SSL_set_bio ssl r-bio w-bio) + ;; ssl has r-bio & w-bio (no ref count?), so drop it: + (set! free-bio? #f) + + ;; Register a finalizer for ssl: + (register-finalizer ssl + (lambda (v) + (when (unbox cancel) + (SSL_free ssl)))) + ;; Return SSL and the cancel boxL: + (values ssl cancel r-bio w-bio connect?))))))))) + (define (wrap-ports who i o context-or-encrypt-method connect/accept close?) (unless (input-port? i) (raise-type-error who "input port" i)) (unless (output-port? o) (raise-type-error who "output port" o)) - (let ([ctx (get-context who context-or-encrypt-method (eq? connect/accept 'connect))]) - (check-valid ctx who "context creation") - (with-failure - (lambda () (when (and ctx - (symbol? context-or-encrypt-method)) - (SSL_CTX_free ctx))) - (let ([connect? (case connect/accept - [(connect) #t] - [(accept) #f] - [else - (raise-type-error who "'connect or 'accept" - connect/accept)])] - [r-bio (BIO_new (BIO_s_mem))] - [w-bio (BIO_new (BIO_s_mem))] - [free-bio? #t]) - (with-failure - (lambda () (when free-bio? - (BIO_free r-bio) - (BIO_free w-bio))) - (unless (or (symbol? context-or-encrypt-method) - (if connect? - (ssl-client-context? context-or-encrypt-method) - (ssl-server-context? context-or-encrypt-method))) - (error who - "'~a mode requires a ~a context, given: ~e" - (if connect? 'connect 'accept) - (if connect? "client" "server") - context-or-encrypt-method)) - (let ([ssl (SSL_new ctx)]) - (check-valid ssl who "ssl setup") - ;; ssl has a ref count on ctx, so release: - (when (symbol? context-or-encrypt-method) - (SSL_CTX_free ctx) - (set! ctx #f)) - (with-failure - (lambda () (SSL_free ssl)) - (SSL_set_bio ssl r-bio w-bio) - ;; ssl has r-bio & w-bio (no ref count?), so drop it: - (set! free-bio? #f) - ;; connect/accept: - (let-values ([(buffer) (make-bytes 512)] - [(pipe-r pipe-w) (make-pipe)]) - (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2)]) - (let loop () - (let ([status (if connect? - (SSL_connect ssl) - (SSL_accept ssl))]) - (let ([out-blocked? (pump-output mzssl)]) - (when (status . < . 1) - (let ([err (SSL_get_error ssl status)]) - (cond - [(= err SSL_ERROR_WANT_READ) - (let ([n (pump-input-once mzssl (if out-blocked? o #t))]) - (when (eof-object? n) - (error who "~a failed (input terminated prematurely)" - (if connect? "connect" "accept")))) - (loop)] - [(= err SSL_ERROR_WANT_WRITE) - (pump-output-once mzssl #t #f) - (loop)] - [else - (error who "~a failed ~a" - (if connect? "connect" "accept") - (get-error-message (ERR_get_error)))])))))) - ;; Connection complete; make ports - (values (make-ssl-input-port mzssl) - (make-ssl-output-port mzssl))))))))))) + ;; Create the SSL connection: + (let-values ([(ssl cancel r-bio w-bio connect?) + (create-ssl who context-or-encrypt-method connect/accept)]) + ;; connect/accept: + (let-values ([(buffer) (make-bytes 512)] + [(pipe-r pipe-w) (make-pipe)] + [(cancel) (box #t)]) + (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2 close? cancel)]) + (let loop () + (let ([status (if connect? + (SSL_connect ssl) + (SSL_accept ssl))]) + (let ([out-blocked? (pump-output mzssl)]) + (when (status . < . 1) + (let ([err (SSL_get_error ssl status)]) + (cond + [(= err SSL_ERROR_WANT_READ) + (let ([n (pump-input-once mzssl (if out-blocked? o #t))]) + (when (eof-object? n) + (error who "~a failed (input terminated prematurely)" + (if connect? "connect" "accept")))) + (loop)] + [(= err SSL_ERROR_WANT_WRITE) + (pump-output-once mzssl #t #f) + (loop)] + [else + (error who "~a failed ~a" + (if connect? "connect" "accept") + (get-error-message (ERR_get_error)))])))))) + ;; Connection complete; make ports + (values (make-ssl-input-port mzssl) + (make-ssl-output-port mzssl)))))) ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; SSL listen @@ -664,7 +719,17 @@ (define (do-ssl-accept who tcp-accept ssl-listener) (let-values ([(i o) (tcp-accept (ssl-listener-l ssl-listener))]) - (wrap-ports who i o (ssl-listener-mzctx ssl-listener) 'accept #t))) + ;; Obviously, there's a race condition between accepting the + ;; connections and installing the exception handler below. However, + ;; if breaks are enabled, then i and o could get lost between + ;; the time that tcp-accept returns and `i' and `o' are bound, + ;; anyway. So we can assume that breaks are enabled without loss + ;; of (additional) resources. + (with-handlers ([void (lambda (exn) + (close-input-port i) + (close-output-port o) + (raise exn))]) + (wrap-ports who i o (ssl-listener-mzctx ssl-listener) 'accept #t)))) (define (ssl-accept ssl-listener) (do-ssl-accept 'ssl-accept tcp-accept ssl-listener)) @@ -677,7 +742,12 @@ (define (do-ssl-connect who tcp-connect hostname port-k client-context-or-protocol-symbol) (let-values ([(i o) (tcp-connect hostname port-k)]) - (wrap-ports who i o client-context-or-protocol-symbol 'connect #t))) + ;; See do-ssl-accept for note on race condition here: + (with-handlers ([void (lambda (exn) + (close-input-port i) + (close-output-port o) + (raise exn))]) + (wrap-ports who i o client-context-or-protocol-symbol 'connect #t)))) (define ssl-connect (opt-lambda (hostname port-k [client-context-or-protocol-symbol default-encrypt])