getting close to a working replacement for mzssl

svn: r2730
This commit is contained in:
Matthew Flatt 2006-04-21 03:34:38 +00:00
parent 489519f5ed
commit d186dab805

View File

@ -2,39 +2,71 @@
;; This is a re-implementation of "mzssl.c" using `(lib "foreign.ss")'. ;; This is a re-implementation of "mzssl.c" using `(lib "foreign.ss")'.
;; It will soon replace "mzssl.c". ;; It will soon replace "mzssl.c".
;; Warn clients: even when a (non-blocking) write fails to write all
;; the data, the stream is committed to writing the given data in
;; the future. (This requirement comes from the SSL library.)
(module mzssl2 mzscheme (module mzssl2 mzscheme
(require (lib "foreign.ss") (require (lib "foreign.ss")
(lib "port.ss") (lib "port.ss")
(lib "etc.ss")) (lib "etc.ss"))
(provide ssl-make-client-context (provide ssl-available?
ports->ssl-ports)
ssl-make-client-context
ssl-make-server-context
ssl-client-context?
ssl-server-context?
ssl-context?
ssl-load-certificate-chain!
ssl-load-private-key!
ssl-load-verify-root-certificates!
ssl-load-suggested-certificate-authorities!
ssl-set-verify!
ports->ssl-ports
ssl-listen
ssl-close
ssl-accept
ssl-accept/enable-break
ssl-connect
ssl-connect/enable-break)
(unsafe!) (unsafe!)
(define libssl (ffi-lib "libssl")) (define libssl (with-handlers ([exn:fail? (lambda (x) #f)])
(ffi-lib "libssl")))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL bindings and constants ;; SSL bindings and constants
(define-syntax define-define-X (define-syntax define-define-X
(syntax-rules () (syntax-rules ()
[(_ id lib) [(_ id chk lib)
(define-syntax (id stx) (define-syntax (id stx)
(syntax-case stx () (syntax-case stx ()
[(_ id type) [(_ id type)
(with-syntax ([str (symbol->string (syntax-e #'id))]) (with-syntax ([str (symbol->string (syntax-e #'id))])
#'(define id #'(define id
(get-ffi-obj str lib (_fun . type))))]))])) (and chk
(get-ffi-obj str lib (_fun . type)))))]))]))
(define-define-X define-ssl libssl) (define-define-X define-ssl libssl libssl)
(define-define-X define-mzscheme #f) (define-define-X define-mzscheme #t #f)
(define-fun-syntax _BIO_METHOD* (syntax-id-rules () [_ _pointer])) (define-syntax typedef
(define-fun-syntax _BIO* (syntax-id-rules () [_ _pointer])) (syntax-rules ()
(define-fun-syntax _SSL_METHOD* (syntax-id-rules () [_ _pointer])) [(_ id t)
(define-fun-syntax _SSL_CTX* (syntax-id-rules () [_ _pointer])) (define-fun-syntax id (syntax-id-rules () [_ t]))]))
(define-fun-syntax _SSL* (syntax-id-rules () [_ _pointer]))
(typedef _BIO_METHOD* _pointer)
(typedef _BIO* _pointer)
(typedef _SSL_METHOD* _pointer)
(typedef _SSL_CTX* _pointer)
(typedef _SSL* _pointer)
(typedef _X509_NAME* _pointer)
(define-ssl SSLv2_client_method (-> _SSL_METHOD*)) (define-ssl SSLv2_client_method (-> _SSL_METHOD*))
(define-ssl SSLv2_server_method (-> _SSL_METHOD*)) (define-ssl SSLv2_server_method (-> _SSL_METHOD*))
@ -58,6 +90,14 @@
(define-ssl SSL_CTX_new (_SSL_METHOD* -> _SSL_CTX*)) (define-ssl SSL_CTX_new (_SSL_METHOD* -> _SSL_CTX*))
(define-ssl SSL_CTX_free (_SSL_CTX* -> _void)) (define-ssl SSL_CTX_free (_SSL_CTX* -> _void))
(define-ssl SSL_CTX_set_verify (_SSL_CTX* _int _pointer -> _void))
(define-ssl SSL_CTX_use_certificate_chain_file (_SSL_CTX* _bytes -> _int))
(define-ssl SSL_CTX_load_verify_locations (_SSL_CTX* _bytes -> _int))
(define-ssl SSL_CTX_set_client_CA_list (_SSL_CTX* _X509_NAME* -> _int))
(define-ssl SSL_CTX_use_RSAPrivateKey_file (_SSL_CTX* _bytes _int -> _int))
(define-ssl SSL_CTX_use_PrivateKey_file (_SSL_CTX* _bytes _int -> _int))
(define-ssl SSL_load_client_CA_file (_bytes -> _X509_NAME*))
(define-ssl SSL_new (_SSL_CTX* -> _SSL*)) (define-ssl SSL_new (_SSL_CTX* -> _SSL*))
(define-ssl SSL_set_bio (_SSL* _BIO* _BIO* -> _void)) (define-ssl SSL_set_bio (_SSL* _BIO* _BIO* -> _void))
(define-ssl SSL_connect (_SSL* -> _int)) (define-ssl SSL_connect (_SSL* -> _int))
@ -82,6 +122,13 @@
(define BIO_C_SET_BUF_MEM_EOF_RETURN 130) (define BIO_C_SET_BUF_MEM_EOF_RETURN 130)
(define SSL_FILETYPE_PEM 1)
(define SSL_FILETYPE_ASN1 2)
(define SSL_VERIFY_NONE #x00)
(define SSL_VERIFY_PEER #x01)
(define SSL_VERIFY_FAIL_IF_NO_PEER_CERT #x02)
(define-mzscheme scheme_start_atomic (-> _void)) (define-mzscheme scheme_start_atomic (-> _void))
(define-mzscheme scheme_end_atomic (-> _void)) (define-mzscheme scheme_end_atomic (-> _void))
@ -120,8 +167,19 @@
(lambda () (scheme_end_atomic)))])) (lambda () (scheme_end_atomic)))]))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Structs
(define-struct ssl-client-context (ctx)) (define-struct ssl-context (ctx))
(define-struct (ssl-client-context ssl-context) ())
(define-struct (ssl-server-context ssl-context) ())
(define-struct ssl-listener (l mzctx))
;; internal:
(define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock refcount))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Contexts, certificates, etc.
(define default-encrypt 'sslv2-or-v3) (define default-encrypt 'sslv2-or-v3)
@ -144,21 +202,101 @@
(string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls") (string-append also-expect "'sslv2-or-v3, 'sslv2, 'sslv3, or 'tls")
e)]))) e)])))
(define make-context
(opt-lambda (who protocol-symbol also-expected client?)
(let ([meth (encrypt->method who also-expected protocol-symbol client?)])
(atomically ; so we reliably register the finalizer
(let ([ctx (SSL_CTX_new meth)])
(check-valid ctx who "context creation")
(register-finalizer ctx (lambda (v) (SSL_CTX_free v)))
((if client? make-ssl-client-context make-ssl-server-context) ctx))))))
(define ssl-make-client-context (define ssl-make-client-context
(opt-lambda ([protocol-symbol default-encrypt]) (opt-lambda ([protocol-symbol default-encrypt])
(let ([meth (encrypt->method 'ssl-make-client-context "" protocol-symbol #t)]) (make-context 'ssl-make-client-context protocol-symbol "" #t)))
(atomically ; so we reliably regsiter the finalizer
(let ([ctx (SSL_CTX_new meth)])
(check-valid ctx 'ssl-make-client-context "context creation")
(register-finalizer ctx (lambda (v) (SSL_CTX_free v)))
(make-ssl-client-context ctx))))))
(define (get-context who context-or-encrypt-method) (define ssl-make-server-context
(if (ssl-client-context? context-or-encrypt-method) (opt-lambda ([protocol-symbol default-encrypt])
(ssl-client-context-ctx context-or-encrypt-method) (make-context 'ssl-make-server-context protocol-symbol "" #f)))
(SSL_CTX_new (encrypt->method who "client context, " context-or-encrypt-method #t))))
(define-struct mzssl (ssl i o r-bio w-bio pipe-r pipe-w buffer lock refcount)) (define (get-context who context-or-encrypt-method client?)
(if (ssl-context? context-or-encrypt-method)
(ssl-context-ctx context-or-encrypt-method)
(SSL_CTX_new (encrypt->method who "context" context-or-encrypt-method client?))))
(define (get-context/listener who ssl-context-or-listener)
(cond
[(ssl-context? ssl-context-or-listener)
(ssl-context-ctx ssl-context-or-listener)]
[(ssl-listener? ssl-context-or-listener)
(ssl-context-ctx (ssl-listener-mzctx ssl-context-or-listener))]
[else
(raise-type-error who
"SSL context or listener"
ssl-context-or-listener)]))
(define (ssl-load-... who load-it ssl-context-or-listener pathname)
(let ([ctx (get-context/listener 'ssl-load-certificate-chain!
ssl-context-or-listener)])
(unless (path-string? pathname)
(raise-type-error 'ssl-load-certificate-chain!
"path or string"
pathname))
(let ([path (path->bytes
(path->complete-path (expand-path pathname)
(current-directory)))])
(let ([n (load-it ctx path)])
(unless (= n 1)
(error who "load failed from: ~e ~a"
pathname
(get-error-message (ERR_get_error))))))))
(define (ssl-load-certificate-chain! ssl-context-or-listener pathname)
(ssl-load-... 'ssl-load-certificate-chain!
SSL_CTX_use_certificate_chain_file
ssl-context-or-listener pathname))
(define (ssl-load-verify-root-certificates! ssl-context-or-listener pathname)
(ssl-load-... 'ssl-load-verify-root-certificates!
SSL_CTX_load_verify_locations
ssl-context-or-listener pathname))
(define (ssl-load-suggested-certificate-authorities! ssl-listener pathname)
(ssl-load-... 'ssl-load-suggested-certificate-authorities!
(lambda (ctx path)
(let ([stk (SSL_load_client_CA_file path)])
(if (ptr-equal? stk #f)
0
(begin
(SSL_CTX_set_client_CA_list ctx stk)
1))))
ssl-listener pathname))
(define ssl-load-private-key!
(opt-lambda (ssl-context-or-listener pathname [rsa? #t] [asn1? #f])
(ssl-load-... 'ssl-load-private-key!
(lambda (ctx path)
((if rsa?
SSL_CTX_use_RSAPrivateKey_file
SSL_CTX_use_PrivateKey_file)
ctx path
(if asn1?
SSL_FILETYPE_ASN1
SSL_FILETYPE_PEM)))
ssl-context-or-listener pathname)))
(define (ssl-set-verify! ssl-context-or-listener on?)
(let ([ctx (get-context/listener 'ssl-set-verify!
ssl-context-or-listener)])
(SSL_CTX_set_verify ctx
(if on?
(bitwise-ior SSL_VERIFY_PEER
SSL_VERIFY_FAIL_IF_NO_PEER_CERT)
SSL_VERIFY_NONE)
#f)))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL ports
(define (mzssl-release mzssl) (define (mzssl-release mzssl)
(call-with-semaphore (call-with-semaphore
@ -183,7 +321,7 @@
(error 'pump-input-once "couldn't write all bytes to BIO!")) (error 'pump-input-once "couldn't write all bytes to BIO!"))
m)])))) m)]))))
(define (pump-output-once mzssl need-progress?) (define (pump-output-once mzssl need-progress? output-blocked-result)
(let ([buffer (mzssl-buffer mzssl)] (let ([buffer (mzssl-buffer mzssl)]
[pipe-r (mzssl-pipe-r mzssl)] [pipe-r (mzssl-pipe-r mzssl)]
[pipe-w (mzssl-pipe-w mzssl)] [pipe-w (mzssl-pipe-w mzssl)]
@ -199,16 +337,17 @@
#f) #f)
(begin (begin
(write-bytes buffer pipe-w 0 n) (write-bytes buffer pipe-w 0 n)
(pump-output-once mzssl need-progress?)))) (pump-output-once mzssl need-progress?
output-blocked-result))))
(let ([n ((if need-progress? write-bytes-avail write-bytes-avail*) buffer o 0 n)]) (let ([n ((if need-progress? write-bytes-avail write-bytes-avail*) buffer o 0 n)])
(if (zero? n) (if (zero? n)
#f output-blocked-result
(begin (begin
(port-commit-peeked n (port-progress-evt pipe-r) always-evt pipe-r) (port-commit-peeked n (port-progress-evt pipe-r) always-evt pipe-r)
#t))))))) #t)))))))
(define (pump-output mzssl) (define (pump-output mzssl)
(when (pump-output-once mzssl #f) (when (pump-output-once mzssl #f #f)
(pump-output mzssl))) (pump-output mzssl)))
(define (make-ssl-input-port mzssl) (define (make-ssl-input-port mzssl)
@ -233,7 +372,7 @@
(wrap-evt (mzssl-i mzssl) (lambda (x) 0)) (wrap-evt (mzssl-i mzssl) (lambda (x) 0))
(do-read buffer)))] (do-read buffer)))]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(if (pump-output-once mzssl #f) (if (pump-output-once mzssl #f #f)
(do-read buffer) (do-read buffer)
(wrap-evt (mzssl-o mzssl) (lambda (x) 0)))] (wrap-evt (mzssl-o mzssl) (lambda (x) 0)))]
[else [else
@ -254,20 +393,34 @@
(mzssl-release mzssl)))) (mzssl-release mzssl))))
(define (make-ssl-output-port mzssl) (define (make-ssl-output-port mzssl)
;; Need a consistent buffer to use with SSL_write
;; across calls to the port's write function.
(let ([xfer-buffer (make-bytes 512)])
(make-output-port (make-output-port
(format "SSL ~a" (object-name (mzssl-o mzssl))) (format "SSL ~a" (object-name (mzssl-o mzssl)))
(mzssl-o mzssl) (mzssl-o mzssl)
;; write proc: ;; write proc:
(letrec ([do-write (letrec ([do-write
(lambda (buffer s e block-ok? enable-break?) (lambda (len block-ok? enable-break?)
(pump-output mzssl) (pump-output mzssl)
(if (= s e) (if (zero? len)
0 ;; Flush request; all data is in the the SSL
(let ([n (SSL_write (mzssl-ssl mzssl) ;; stream, but how do we know that it's gone
(if (zero? s) ;; through the ports (which may involve both
buffer ;; output and input)? It seems that making
(subbytes buffer s e)) ;; sure all output is gone is sufficient.
(- e s))]) ;; We've already pumped output, but maybe some
;; is stuck in the bio...
(parameterize-break
enable-break?
(let loop ()
(flush-output (mzssl-o mzssl))
(when (pump-output-once mzssl #f #t)
(loop)))
0)
;; Write request; even if blocking is ok, we treat
;; it as non-blocking and let MzScheme handle blocking
(let ([n (SSL_write (mzssl-ssl mzssl) xfer-buffer len)])
(if (n . > . 0) (if (n . > . 0)
n n
(let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) (let ([err (SSL_get_error (mzssl-ssl mzssl) n)])
@ -276,28 +429,39 @@
(let ([n (pump-input-once mzssl #f)]) (let ([n (pump-input-once mzssl #f)])
(if (eq? n 0) (if (eq? n 0)
(wrap-evt (mzssl-i mzssl) (lambda (x) #f)) (wrap-evt (mzssl-i mzssl) (lambda (x) #f))
(do-write buffer s e block-ok? enable-break?)))] (do-write len block-ok? enable-break?)))]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(if (pump-output-once mzssl #f) (if (pump-output-once mzssl #f #f)
(do-write buffer s e block-ok? enable-break?) (do-write len block-ok? enable-break?)
(wrap-evt (mzssl-o mzssl) (lambda (x) #f)))] (wrap-evt (mzssl-o mzssl) (lambda (x) #f)))]
[else [else
(error 'read-bytes "SSL read failed ~a" (error 'read-bytes "SSL read failed ~a"
(get-error-message (ERR_get_error)))]))))))] (get-error-message (ERR_get_error)))]))))))]
[top-write
(lambda (buffer s e block-ok? enable-break?)
(bytes-copy! xfer-buffer 0 buffer s e)
(do-write (- e s) block-ok? enable-break?))]
[lock-unavailable [lock-unavailable
(lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) #f)))]) (lambda () (wrap-evt (mzssl-lock mzssl) (lambda (x) #f)))])
(lambda (buffer s e block-ok? enable-break?) (lambda (buffer s e block-ok? enable-break?)
(call-with-semaphore (call-with-semaphore
(mzssl-lock mzssl) (mzssl-lock mzssl)
do-write top-write
lock-unavailable lock-unavailable
buffer s e block-ok? enable-break?))) buffer s e block-ok? enable-break?)))
;; close proc: ;; close proc:
(lambda () (lambda ()
;; issue shutdown (i.e., EOF on read end) ;; issue shutdown (i.e., EOF on read end)
(let loop () (let loop ([cnt 1])
(pump-output mzssl) (pump-output mzssl)
(let ([n (SSL_shutdown (mzssl-ssl mzssl))]) (let ([n (SSL_shutdown (mzssl-ssl mzssl))])
(if (= n 0)
;; 0 seems to be the result in many cases because the socket
;; is non-blocking, and then neither of the WANTs is returned.
;; We address this by simply trying 10 times and then giving
;; up. The two-step shutdown is optional, anyway.
(unless (cnt . >= . 10)
(loop (add1 cnt)))
(unless (= n 1) (unless (= n 1)
(let ([err (SSL_get_error (mzssl-ssl mzssl) n)]) (let ([err (SSL_get_error (mzssl-ssl mzssl) n)])
(cond (cond
@ -305,20 +469,22 @@
(pump-input-once mzssl #t) (pump-input-once mzssl #t)
(loop)] (loop)]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(pump-output-once mzssl #t) (pump-output-once mzssl #t #f)
(loop)] (loop)]
[else [else
(error 'read-bytes "SSL shutdown failed ~a" (error 'read-bytes "SSL shutdown failed ~a"
(get-error-message (ERR_get_error)))]))))) (get-error-message (ERR_get_error)))]))))))
(mzssl-release mzssl)))) (mzssl-release mzssl)))))
(define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?) (define (ports->ssl-ports i o context-or-encrypt-method connect/accept close?)
(let ([who 'input-port->ssl-input-port]) (wrap-ports 'port->ssl-ports i o context-or-encrypt-method connect/accept close?))
(define (wrap-ports who i o context-or-encrypt-method connect/accept close?)
(unless (input-port? i) (unless (input-port? i)
(raise-type-error who "input port" i)) (raise-type-error who "input port" i))
(unless (output-port? o) (unless (output-port? o)
(raise-type-error who "output port" o)) (raise-type-error who "output port" o))
(let ([ctx (get-context who context-or-encrypt-method)]) (let ([ctx (get-context who context-or-encrypt-method (eq? connect/accept 'connect))])
(check-valid ctx who "context creation") (check-valid ctx who "context creation")
(with-failure (with-failure
(lambda () (when (and ctx (lambda () (when (and ctx
@ -337,6 +503,15 @@
(lambda () (when free-bio? (lambda () (when free-bio?
(BIO_free r-bio) (BIO_free r-bio)
(BIO_free w-bio))) (BIO_free w-bio)))
(unless (or (symbol? context-or-encrypt-method)
(if connect?
(ssl-client-context? context-or-encrypt-method)
(ssl-server-context? context-or-encrypt-method)))
(error who
"'~a mode requires a ~a context, given: ~e"
(if connect? 'connect 'accept)
(if connect? "client" "server")
context-or-encrypt-method))
(let ([ssl (SSL_new ctx)]) (let ([ssl (SSL_new ctx)])
(check-valid ssl who "ssl setup") (check-valid ssl who "ssl setup")
;; ssl has a ref count on ctx, so release: ;; ssl has a ref count on ctx, so release:
@ -367,7 +542,7 @@
(if connect? "connect" "accept")))) (if connect? "connect" "accept"))))
(loop)] (loop)]
[(= err SSL_ERROR_WANT_WRITE) [(= err SSL_ERROR_WANT_WRITE)
(pump-output-once mzssl #t) (pump-output-once mzssl #t #f)
(loop)] (loop)]
[else [else
(error who "~a failed ~a" (error who "~a failed ~a"
@ -375,16 +550,61 @@
(get-error-message (ERR_get_error)))]))))) (get-error-message (ERR_get_error)))])))))
;; Connection complete; make ports ;; Connection complete; make ports
(values (make-ssl-input-port mzssl) (values (make-ssl-input-port mzssl)
(make-ssl-output-port mzssl)))))))))))) (make-ssl-output-port mzssl)))))))))))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL listen
(define ssl-listen
(opt-lambda (port-k [queue-k 5] [reuse? #f] [hostname-or-#f #f] [protocol-symbol-or-context default-encrypt])
(let ([ctx (cond
[(ssl-server-context? protocol-symbol-or-context) protocol-symbol-or-context]
[else (make-context 'ssl-listen protocol-symbol-or-context "server context, " #f)])])
(let ([l (tcp-listen port-k queue-k reuse? hostname-or-#f)])
(make-ssl-listener l ctx)))))
(define (ssl-close l)
(unless (ssl-listener? l)
(raise-type-error 'ssl-close "SSL listener" l))
(tcp-close (ssl-listener-l l)))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL accept
(define (do-ssl-accept who tcp-accept ssl-listener)
(let-values ([(i o) (tcp-accept (ssl-listener-l ssl-listener))])
(wrap-ports who i o (ssl-listener-mzctx ssl-listener) 'accept #t)))
(define (ssl-accept ssl-listener)
(do-ssl-accept 'ssl-accept tcp-accept ssl-listener))
(define (ssl-accept/enable-break ssl-listener)
(do-ssl-accept 'ssl-accept/enable-break tcp-accept/enable-break ssl-listener))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; SSL connect
(define (do-ssl-connect who tcp-connect hostname port-k client-context-or-protocol-symbol)
(let-values ([(i o) (tcp-connect hostname port-k)])
(wrap-ports who i o client-context-or-protocol-symbol 'connect #t)))
(define ssl-connect
(opt-lambda (hostname port-k [client-context-or-protocol-symbol default-encrypt])
(do-ssl-connect 'ssl-connect tcp-connect hostname port-k
client-context-or-protocol-symbol)))
(define ssl-connect/enable-break
(opt-lambda (hostname port-k [client-context-or-protocol-symbol default-encrypt])
(do-ssl-connect 'ssl-connect/enable-break tcp-connect/enable-break hostname port-k
client-context-or-protocol-symbol)))
;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Initialization ;; Initialization
(define ssl-available? (and libssl #t))
(when ssl-available?
(SSL_library_init) (SSL_library_init)
(SSL_load_error_strings) (SSL_load_error_strings))
) )