diff --git a/collects/db/private/mysql/connection.rkt b/collects/db/private/mysql/connection.rkt index 5f5d44b3d9..d6760f6793 100644 --- a/collects/db/private/mysql/connection.rkt +++ b/collects/db/private/mysql/connection.rkt @@ -111,6 +111,8 @@ (advance 'handshake)] [(? ok-packet?) (advance)] + [(? change-plugin-packet?) + (advance 'auth)] [(? error-packet?) (advance)] [(struct result-set-header-packet (field-count _)) @@ -190,34 +192,50 @@ (match r [(struct handshake-packet (pver sver tid scramble capabilities charset status auth)) (check-required-flags capabilities) - (unless (equal? auth "mysql_native_password") + (unless (member auth '("mysql_native_password" #f)) (uerror 'mysql-connect "unsupported authentication plugin: ~s" auth)) (define do-ssl? (and (case ssl ((yes optional) #t) ((no) #f)) (memq 'ssl capabilities))) (when (and (eq? ssl 'yes) (not do-ssl?)) (uerror 'mysql-connect "server refused SSL connection")) + (define wanted-capabilities (desired-capabilities capabilities do-ssl?)) (when do-ssl? (send-message - (make-abbrev-client-authentication-packet - (desired-capabilities capabilities #t))) + (make-abbrev-client-auth-packet + wanted-capabilities)) (let-values ([(sin sout) (ports->ssl-ports inport outport #:mode 'connect #:context ssl-context #:close-original? #t)]) (attach-to-ports sin sout))) - (send-message - (make-client-authentication-packet - (desired-capabilities capabilities do-ssl?) - MAX-PACKET-LENGTH - 'utf8-general-ci ;; charset - username - (scramble-password scramble password) - dbname)) - (expect-auth-confirmation)] + (authenticate wanted-capabilities username password dbname + (or auth "mysql_native_password") scramble)] [_ (error/comm 'mysql-connect "during authentication")])))) + (define/private (authenticate capabilities username password dbname auth-plugin scramble) + (let loop ([auth-plugin auth-plugin] [scramble scramble] [first? #t]) + (define (auth data) + (if first? + (make-client-auth-packet capabilities MAX-PACKET-LENGTH 'utf8-general-ci + username data dbname auth-plugin) + (make-auth-followup-packet data))) + (cond [(equal? auth-plugin "mysql_native_password") + (send-message (auth (scramble-password scramble password)))] + [(equal? auth-plugin "mysql_old_password") + (send-message (auth (bytes-append (old-scramble-password scramble password) + (bytes 0))))] + [else (uerror 'mysql-connect + "server does not support authentication plugin: ~s" + auth-plugin)]) + (match (recv 'mysql-connect 'auth) + [(struct ok-packet (_ _ status warnings message)) + (after-connect)] + [(struct change-plugin-packet (plugin data)) + ;; if plugin = #f, means "mysql_old_password" + (loop (or plugin "mysql_old_password") (or data scramble) #f)]))) + (define/private (check-required-flags capabilities) (for-each (lambda (rf) (unless (memq rf capabilities) @@ -234,14 +252,6 @@ (cond [ssl? (cons 'ssl base)] [else base]))) - ;; expect-auth-confirmation : -> void - (define/private (expect-auth-confirmation) - (let ([r (recv 'mysql-connect 'auth)]) - (match r - [(struct ok-packet (_ _ status warnings message)) - (after-connect)] - [_ (error/comm 'mysql-connect "after authentication")]))) - ;; Set connection to use utf8 encoding (define/private (after-connect) (query 'mysql-connect "set names 'utf8'") @@ -495,6 +505,65 @@ (loop (add1 i)))) c)) +;; ======================================= + +(provide old-scramble-password + hash323 + hash323->string) + +(define (old-scramble-password scramble password) + (define (xor a b) (bitwise-xor a b)) + (define RMAX #x3FFFFFFF) + (and scramble password + (let* ([scramble (subbytes scramble 0 8)] + [password (string->bytes/utf-8 password)] + [hp (hash323 password)] + [hm (hash323 scramble)] + [r1 (modulo (xor (car hp) (car hm)) RMAX)] + [r2 (modulo (xor (cdr hp) (cdr hm)) RMAX)] + [out (make-bytes 8 0)]) + (define (rnd) + (set! r1 (modulo (+ (* 3 r1) r2) RMAX)) + (set! r2 (modulo (+ r1 r2 33) RMAX)) + (/ (exact->inexact r1) (exact->inexact RMAX))) + (for ([i (in-range (bytes-length scramble))]) + (let ([b (+ (inexact->exact (floor (* (rnd) 31))) 64)]) + (bytes-set! out i b) + (values r1 r2))) + (let ([extra (inexact->exact (floor (* (rnd) 31)))]) + (for ([i (in-range (bytes-length scramble))]) + (bytes-set! out i (xor (bytes-ref out i) extra)))) + out))) + +(define (hash323 bs) + (define (xor a b) (bitwise-xor a b)) + (define-syntax-rule (normalize! var) + (set! var (bitwise-and var (sub1 (arithmetic-shift 1 64))))) + (let ([nr 1345345333] + [add 7] + [nr2 #x12345671]) + (for ([i (in-range (bytes-length bs))] + #:when (not (memv (bytes-ref bs i) '(#\space #\tab)))) + (let ([tmp (bytes-ref bs i)]) + (set! nr (xor nr + (+ (* (+ (bitwise-and nr 63) add) tmp) + (arithmetic-shift nr 8)))) + (normalize! nr) + (set! nr2 (+ nr2 + (xor (arithmetic-shift nr2 8) nr))) + (normalize! nr2) + (set! add (+ add tmp)) + (normalize! add))) + (cons (bitwise-and nr (sub1 (arithmetic-shift 1 31))) + (bitwise-and nr2 (sub1 (arithmetic-shift 1 31)))))) + +(define (hash323->string bs) + (let ([p (hash323 bs)]) + (bytes-append (integer->integer-bytes (car p) 4 #f #f) + (integer->integer-bytes (cdr p) 4 #f #f)))) + +;; ======================================== + (define REQUIRED-CAPABILITIES '(long-flag connect-with-db @@ -507,7 +576,8 @@ transactions protocol-41 secure-connection - connect-with-db)) + connect-with-db + plugin-auth)) ;; raise-backend-error : symbol ErrorPacket -> raises exn (define (raise-backend-error who r) diff --git a/collects/db/private/mysql/message.rkt b/collects/db/private/mysql/message.rkt index 87ed5838cb..347654dc6a 100644 --- a/collects/db/private/mysql/message.rkt +++ b/collects/db/private/mysql/message.rkt @@ -14,8 +14,10 @@ Based on protocol documentation here: packet? (struct-out handshake-packet) - (struct-out client-authentication-packet) - (struct-out abbrev-client-authentication-packet) + (struct-out change-plugin-packet) + (struct-out client-auth-packet) + (struct-out abbrev-client-auth-packet) + (struct-out auth-followup-packet) (struct-out command-packet) (struct-out command:statement-packet) (struct-out command:change-user-packet) @@ -202,19 +204,24 @@ Based on protocol documentation here: auth) #:transparent) -(define-struct (client-authentication-packet packet) +(define-struct (client-auth-packet packet) (client-flags max-packet-length charset user scramble - database) + database + plugin) #:transparent) -(define-struct (abbrev-client-authentication-packet packet) +(define-struct (abbrev-client-auth-packet packet) (client-flags) #:transparent) +(define-struct (auth-followup-packet packet) + (data) + #:transparent) + (define-struct (command-packet packet) (command argument) @@ -306,6 +313,11 @@ Based on protocol documentation here: params) #:transparent) +(define-struct (change-plugin-packet packet) + (plugin + data) + #:transparent) + (define-struct (unknown-packet packet) (expected contents) @@ -322,19 +334,24 @@ Based on protocol documentation here: (define (write-packet* out p) (match p - [(struct abbrev-client-authentication-packet (client-flags)) + [(struct abbrev-client-auth-packet (client-flags)) (io:write-le-int32 out (encode-server-flags client-flags))] - [(struct client-authentication-packet - (client-flags max-length charset user scramble database)) + [(struct client-auth-packet (client-flags max-length charset user scramble database plugin)) (io:write-le-int32 out (encode-server-flags client-flags)) (io:write-le-int32 out max-length) (io:write-byte out (encode-charset charset)) (io:write-bytes out (make-bytes 23 0)) (io:write-null-terminated-string out user) - (if scramble - (io:write-length-coded-bytes out scramble) - (io:write-byte out 0)) - (io:write-null-terminated-string out database)] + (cond [(memq 'secure-connection client-flags) + (io:write-length-coded-bytes out scramble)] + [else ;; old-style scramble is *not* length-coded, but \0-terminated + (io:write-bytes out scramble)]) + (when (memq 'connect-with-db client-flags) + (io:write-null-terminated-string out database)) + (when (memq 'plugin-auth client-flags) + (io:write-null-terminated-string out plugin))] + [(struct auth-followup-packet (data)) + (io:write-bytes out data)] [(struct command-packet (command arg)) (io:write-byte out (encode-command command)) (io:write-null-terminated-bytes out (string->bytes/utf-8 arg))] @@ -388,19 +405,18 @@ Based on protocol documentation here: ((handshake) (parse-handshake-packet in len)) ((auth) - (cond [(eq? (peek-byte in) #x00) - (parse-ok-packet in len)] - [else - (parse-unknown-packet in len "(expected authentication ok packet)")])) + (case (peek-byte in) + ((#x00) (parse-ok-packet in len)) + ((#xFE) (parse-change-plugin-packet in len)) + (else (parse-unknown-packet in len "(expected authentication ok packet)")))) ((ok) - (cond [(eq? (peek-byte in) #x00) - (parse-ok-packet in len)] - [else - (parse-unknown-packet in len "(expected ok packet)")])) + (case (peek-byte in) + ((#x00) (parse-ok-packet in len)) + (else (parse-unknown-packet in len "(expected ok packet)")))) ((result) - (if (eq? (peek-byte in) #x00) - (parse-ok-packet in len) - (parse-result-set-header-packet in len))) + (case (peek-byte in) + ((#x00) (parse-ok-packet in len)) + (else (parse-result-set-header-packet in len)))) ((field) (if (and (eq? (peek-byte in) #xFE) (< len 9)) (parse-eof-packet in len) @@ -414,10 +430,9 @@ Based on protocol documentation here: (parse-eof-packet in len) (parse-binary-row-data-packet in len field-dvecs))) ((prep-ok) - (cond [(eq? (peek-byte in) #x00) - (parse-ok-prepared-statement-packet in len)] - [else - (parse-unknown-packet in len "(expected ok for prepared statement packet)")])) + (case (peek-byte in) + ((#x00) (parse-ok-prepared-statement-packet in len)) + (else (parse-unknown-packet in len "(expected ok for prepared statement packet)")))) ((prep-params) (if (and (eq? (peek-byte in) #xFE) (< len 9)) (parse-eof-packet in len) @@ -467,7 +482,7 @@ Based on protocol documentation here: ;; - in 5.5.12, a null-terminated auth string (cond [(memq 'plugin-auth server-capabilities) (io:read-null-terminated-string in)] - [else "mysql_native_password"])]) + [else #f])]) ;; implicit "mysql_native_password" (make-handshake-packet protocol-version server-version thread-id @@ -490,6 +505,15 @@ Based on protocol documentation here: warning-count (bytes->string/utf-8 message)))) +(define (parse-change-plugin-packet in len) + (let* ([_ (io:read-byte in)] + [plugin (and (port-has-bytes? in) + (io:read-null-terminated-string in))] + [data (and (port-has-bytes? in) + (io:read-bytes-to-eof in))]) + ;; If plugin = #f, then changing to old password plugin. + (make-change-plugin-packet plugin data))) + (define (parse-error-packet in len) (let* ([_ (io:read-byte in)] [errno (io:read-le-int16 in)]