clean up atomic regions, port closing

svn: r2733
This commit is contained in:
Matthew Flatt 2006-04-21 15:27:44 +00:00
parent 63f30b1df1
commit 662af63c05

View File

@ -50,8 +50,9 @@
[(_ id type) [(_ id type)
(with-syntax ([str (symbol->string (syntax-e #'id))]) (with-syntax ([str (symbol->string (syntax-e #'id))])
#'(define id #'(define id
(and chk (if chk
(get-ffi-obj str lib (_fun . type)))))]))])) (get-ffi-obj str lib (_fun . type))
(lambda args (raise-not-available)))))]))]))
(define-define-X define-ssl libssl libssl) (define-define-X define-ssl libssl libssl)
(define-define-X define-mzscheme #t #f) (define-define-X define-mzscheme #t #f)
@ -136,6 +137,9 @@
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Error handling ;; Error handling
(define (raise-not-available)
(error 'openssl "OpenSSL shared library not found"))
(define-syntax with-failure (define-syntax with-failure
(syntax-rules () (syntax-rules ()
[(_ thunk body ...) [(_ thunk body ...)
@ -152,20 +156,45 @@
(define (check-valid v who what) (define (check-valid v who what)
(when (ptr-equal? v #f) (when (ptr-equal? v #f)
(let ([id (ERR_get_error)]) (let ([id (ERR_get_error)])
(error who "~a failed ~a" (escape-atomic
what (lambda ()
(get-error-message id))))) (error who "~a failed ~a"
what
(get-error-message id)))))))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Atomic blocks ;; Atomic blocks
;; Obviously, be careful in an atomic block. In particular,
;; DO NOT CONSTRUCT AN ERROR DIRECTLY IN AN ATOMIC BLOCK,
;; because the error message almost certainly involves things
;; like a ~a or ~e format, which can trigger all sorts of
;; printing extensions. Instead, send a thunk that
;; constructs and raises the exception to `escape-atomic'.
(define in-atomic? (make-parameter #f))
(define-struct (exn:atomic exn) (thunk))
(define-syntax atomically (define-syntax atomically
(syntax-rules () (syntax-rules ()
[(_ body ...) [(_ body ...)
(dynamic-wind (parameterize-break
(lambda () (scheme_start_atomic)) #f
(lambda () body ...) (with-handlers ([exn:atomic? (lambda (exn)
(lambda () (scheme_end_atomic)))])) ((exn:atomic-thunk exn)))])
(parameterize ([in-atomic? #t])
(dynamic-wind
(lambda () (scheme_start_atomic))
(lambda () body ...)
(lambda () (scheme_end_atomic))))))]))
(define (escape-atomic thunk)
(if (in-atomic?)
(raise (make-exn:atomic
"error during atomic..."
(current-continuation-marks)
thunk))
(thunk)))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Structs ;; Structs
@ -177,7 +206,7 @@
(define-struct ssl-listener (l mzctx)) (define-struct ssl-listener (l mzctx))
;; internal: ;; internal:
(define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount)) (define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock flushing? refcount close? finalizer-cancel))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Contexts, certificates, etc. ;; Contexts, certificates, etc.
@ -198,10 +227,12 @@
[(tls) (if client? [(tls) (if client?
TLSv1_client_method TLSv1_client_method
TLSv1_server_method)] TLSv1_server_method)]
[else (raise-type-error [else (escape-atomic
who (lambda ()
(string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls") (raise-type-error
e)]))) who
(string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls")
e)))])))
(define make-context (define make-context
(opt-lambda (who protocol-symbol also-expected client?) (opt-lambda (who protocol-symbol also-expected client?)
@ -305,7 +336,12 @@
(lambda () (lambda ()
(set-mzssl-refcount! mzssl (sub1 (mzssl-refcount mzssl))) (set-mzssl-refcount! mzssl (sub1 (mzssl-refcount mzssl)))
(when (zero? (mzssl-refcount mzssl)) (when (zero? (mzssl-refcount mzssl))
(SSL_free (mzssl-ssl mzssl)))))) (atomically
(set-box! (mzssl-finalizer-cancel mzssl) #f)
(SSL_free (mzssl-ssl mzssl)))
(when (mzssl-close? mzssl)
(close-input-port (mzssl-i mzssl))
(close-output-port (mzssl-o mzssl)))))))
(define (pump-input-once mzssl need-progress?/out) (define (pump-input-once mzssl need-progress?/out)
(let ([buffer (mzssl-buffer mzssl)] (let ([buffer (mzssl-buffer mzssl)]
@ -416,7 +452,7 @@
(lambda () (lambda ()
(mzssl-release mzssl)))) (mzssl-release mzssl))))
(define (flush-ssl mzssl) (define (flush-ssl mzssl enable-break?)
;; Make sure that this SSL connection has said everything that it ;; Make sure that this SSL connection has said everything that it
;; wants to say --- that is, move data from the SLL output to the ;; wants to say --- that is, move data from the SLL output to the
;; underlying output port. Depending on the transport, the other end ;; underlying output port. Depending on the transport, the other end
@ -434,7 +470,7 @@
;; wait until either input or output is ready: ;; wait until either input or output is ready:
(when v (when v
(when (eq? v 'blocked) (when (eq? v 'blocked)
(sync (mzssl-o mzssl) (mzssl-i mzssl))) ((if enable-break? sync/enable-break sync) (mzssl-o mzssl) (mzssl-i mzssl)))
(loop))))))) (loop)))))))
(define (kernel-thread thunk) (define (kernel-thread thunk)
@ -461,7 +497,7 @@
(let flush-loop () (let flush-loop ()
(sync flush-ch) (sync flush-ch)
(semaphore-wait (mzssl-lock mzssl)) (semaphore-wait (mzssl-lock mzssl))
(flush-ssl mzssl) (flush-ssl mzssl #f)
(set-mzssl-flushing?! mzssl #f) (set-mzssl-flushing?! mzssl #f)
(semaphore-post (mzssl-lock mzssl)) (semaphore-post (mzssl-lock mzssl))
(loop))))) (loop)))))
@ -477,10 +513,9 @@
;; Flush request; all data is in the the SSL ;; Flush request; all data is in the the SSL
;; stream, but make sure it's gone ;; stream, but make sure it's gone
;; through the ports: ;; through the ports:
(parameterize-break (begin
enable-break? (flush-ssl mzssl enable-break?)
(flush-ssl mzssl) 0)
0)
;; Write request; even if blocking is ok, we treat ;; Write request; even if blocking is ok, we treat
;; it as non-blocking and let MzScheme handle blocking ;; it as non-blocking and let MzScheme handle blocking
(let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)]) (let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)])
@ -570,78 +605,98 @@
(define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?) (define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?)
(wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?)) (wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?))
(define (create-ssl who context-or-encrypt-method connect/accept)
(atomically ; so we register the finalizer (and it's ok since everything is non-blocking)
(let ([ctx (get-context who context-or-encrypt-method (eq? connect/accept 'connect))])
(check-valid ctx who "context creation")
(with-failure
(lambda () (when (and ctx
(symbol? context-or-encrypt-method))
(SSL_CTX_free ctx)))
(let ([connect? (case connect/accept
[(connect) #t]
[(accept) #f]
[else
(escape-atomic
(lambda ()
(raise-type-error who "'connect or 'accept"
connect/accept)))])]
[r-bio (BIO_new (BIO_s_mem))]
[w-bio (BIO_new (BIO_s_mem))]
[free-bio? #t])
(with-failure
(lambda () (when free-bio?
(BIO_free r-bio)
(BIO_free w-bio)))
(unless (or (symbol? context-or-encrypt-method)
(if connect?
(ssl-client-context? context-or-encrypt-method)
(ssl-server-context? context-or-encrypt-method)))
(escape-atomic
(lambda ()
(error who
"'~a mode requires a ~a context, given: ~e"
(if connect? 'connect 'accept)
(if connect? "client" "server")
context-or-encrypt-method))))
(let ([ssl (SSL_new ctx)]
[cancel (box #t)])
(check-valid ssl who "ssl setup")
;; ssl has a ref count on ctx, so release:
(when (symbol? context-or-encrypt-method)
(SSL_CTX_free ctx)
(set! ctx #f))
(with-failure
(lambda () (SSL_free ssl))
(SSL_set_bio ssl r-bio w-bio)
;; ssl has r-bio & w-bio (no ref count?), so drop it:
(set! free-bio? #f)
;; Register a finalizer for ssl:
(register-finalizer ssl
(lambda (v)
(when (unbox cancel)
(SSL_free ssl))))
;; Return SSL and the cancel boxL:
(values ssl cancel r-bio w-bio connect?)))))))))
(define (wrap-ports who i o context-or-encrypt-method connect/accept close?) (define (wrap-ports who i o context-or-encrypt-method connect/accept close?)
(unless (input-port? i) (unless (input-port? i)
(raise-type-error who "input port" i)) (raise-type-error who "input port" i))
(unless (output-port? o) (unless (output-port? o)
(raise-type-error who "output port" o)) (raise-type-error who "output port" o))
(let ([ctx (get-context who context-or-encrypt-method (eq? connect/accept 'connect))]) ;; Create the SSL connection:
(check-valid ctx who "context creation") (let-values ([(ssl cancel r-bio w-bio connect?)
(with-failure (create-ssl who context-or-encrypt-method connect/accept)])
(lambda () (when (and ctx ;; connect/accept:
(symbol? context-or-encrypt-method)) (let-values ([(buffer) (make-bytes 512)]
(SSL_CTX_free ctx))) [(pipe-r pipe-w) (make-pipe)]
(let ([connect? (case connect/accept [(cancel) (box #t)])
[(connect) #t] (let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2 close? cancel)])
[(accept) #f] (let loop ()
[else (let ([status (if connect?
(raise-type-error who "'connect or 'accept" (SSL_connect ssl)
connect/accept)])] (SSL_accept ssl))])
[r-bio (BIO_new (BIO_s_mem))] (let ([out-blocked? (pump-output mzssl)])
[w-bio (BIO_new (BIO_s_mem))] (when (status . < . 1)
[free-bio? #t]) (let ([err (SSL_get_error ssl status)])
(with-failure (cond
(lambda () (when free-bio? [(= err SSL_ERROR_WANT_READ)
(BIO_free r-bio) (let ([n (pump-input-once mzssl (if out-blocked? o #t))])
(BIO_free w-bio))) (when (eof-object? n)
(unless (or (symbol? context-or-encrypt-method) (error who "~a failed (input terminated prematurely)"
(if connect? (if connect? "connect" "accept"))))
(ssl-client-context? context-or-encrypt-method) (loop)]
(ssl-server-context? context-or-encrypt-method))) [(= err SSL_ERROR_WANT_WRITE)
(error who (pump-output-once mzssl #t #f)
"'~a mode requires a ~a context, given: ~e" (loop)]
(if connect? 'connect 'accept) [else
(if connect? "client" "server") (error who "~a failed ~a"
context-or-encrypt-method)) (if connect? "connect" "accept")
(let ([ssl (SSL_new ctx)]) (get-error-message (ERR_get_error)))]))))))
(check-valid ssl who "ssl setup") ;; Connection complete; make ports
;; ssl has a ref count on ctx, so release: (values (make-ssl-input-port mzssl)
(when (symbol? context-or-encrypt-method) (make-ssl-output-port mzssl))))))
(SSL_CTX_free ctx)
(set! ctx #f))
(with-failure
(lambda () (SSL_free ssl))
(SSL_set_bio ssl r-bio w-bio)
;; ssl has r-bio & w-bio (no ref count?), so drop it:
(set! free-bio? #f)
;; connect/accept:
(let-values ([(buffer) (make-bytes 512)]
[(pipe-r pipe-w) (make-pipe)])
(let ([mzssl (make-mzssl ssl i o r-bio w-bio pipe-r pipe-w buffer (make-semaphore 1) #f 2)])
(let loop ()
(let ([status (if connect?
(SSL_connect ssl)
(SSL_accept ssl))])
(let ([out-blocked? (pump-output mzssl)])
(when (status . < . 1)
(let ([err (SSL_get_error ssl status)])
(cond
[(= err SSL_ERROR_WANT_READ)
(let ([n (pump-input-once mzssl (if out-blocked? o #t))])
(when (eof-object? n)
(error who "~a failed (input terminated prematurely)"
(if connect? "connect" "accept"))))
(loop)]
[(= err SSL_ERROR_WANT_WRITE)
(pump-output-once mzssl #t #f)
(loop)]
[else
(error who "~a failed ~a"
(if connect? "connect" "accept")
(get-error-message (ERR_get_error)))]))))))
;; Connection complete; make ports
(values (make-ssl-input-port mzssl)
(make-ssl-output-port mzssl)))))))))))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL listen ;; SSL listen
@ -664,7 +719,17 @@
(define (do-ssl-accept who tcp-accept ssl-listener) (define (do-ssl-accept who tcp-accept ssl-listener)
(let-values ([(i o) (tcp-accept (ssl-listener-l ssl-listener))]) (let-values ([(i o) (tcp-accept (ssl-listener-l ssl-listener))])
(wrap-ports who i o (ssl-listener-mzctx ssl-listener) 'accept #t))) ;; Obviously, there's a race condition between accepting the
;; connections and installing the exception handler below. However,
;; if breaks are enabled, then i and o could get lost between
;; the time that tcp-accept returns and `i' and `o' are bound,
;; anyway. So we can assume that breaks are enabled without loss
;; of (additional) resources.
(with-handlers ([void (lambda (exn)
(close-input-port i)
(close-output-port o)
(raise exn))])
(wrap-ports who i o (ssl-listener-mzctx ssl-listener) 'accept #t))))
(define (ssl-accept ssl-listener) (define (ssl-accept ssl-listener)
(do-ssl-accept 'ssl-accept tcp-accept ssl-listener)) (do-ssl-accept 'ssl-accept tcp-accept ssl-listener))
@ -677,7 +742,12 @@
(define (do-ssl-connect who tcp-connect hostname port-k client-context-or-protocol-symbol) (define (do-ssl-connect who tcp-connect hostname port-k client-context-or-protocol-symbol)
(let-values ([(i o) (tcp-connect hostname port-k)]) (let-values ([(i o) (tcp-connect hostname port-k)])
(wrap-ports who i o client-context-or-protocol-symbol 'connect #t))) ;; See do-ssl-accept for note on race condition here:
(with-handlers ([void (lambda (exn)
(close-input-port i)
(close-output-port o)
(raise exn))])
(wrap-ports who i o client-context-or-protocol-symbol 'connect #t))))
(define ssl-connect (define ssl-connect
(opt-lambda (hostname port-k [client-context-or-protocol-symbol default-encrypt]) (opt-lambda (hostname port-k [client-context-or-protocol-symbol default-encrypt])