racket/collects/db/private/mysql/connection.rkt
2012-01-08 23:25:53 -07:00

614 lines
23 KiB
Racket

#lang racket/base
(require racket/class
racket/match
openssl
openssl/sha1
"../generic/interfaces.rkt"
"../generic/prepared.rkt"
"../generic/sql-data.rkt"
"message.rkt"
"dbsystem.rkt")
(provide connection%
password-hash)
;; Packet-length value (#x1000000 = 16777216) that `authenticate` passes to
;; make-client-auth-packet when building the first authentication packet.
(define MAX-PACKET-LENGTH #x1000000)
;; ========================================
(define connection%
(class* transactions% (connection<%>)
    ;; notice-handler : procedure invoked (via add-delayed-call!) by
    ;; fetch-warnings with a warning's numeric code and message.
    (init-private notice-handler)
    ;; Ports used to talk to the server. Both start as #f, are installed by
    ;; attach-to-ports, and are reset to #f by disconnect*.
    (define inport #f)
    (define outport #f)
    ;; Locking, delayed-call and transaction-checking methods used below are
    ;; inherited from the transactions% superclass.
    (inherit call-with-lock
             call-with-lock*
             add-delayed-call!
             check-valid-tx-status
             check-statement/tx)
    ;; tx-status is updated in recv* from the status of server replies.
    (inherit-field tx-status)
    (super-new)
;; with-disconnect-on-error
(define-syntax-rule (with-disconnect-on-error . body)
(with-handlers ([exn:fail? (lambda (e) (disconnect* #f) (raise e))])
. body))
;; ========================================
;; == Debugging
    ;; DEBUG? : when true, buffer-message, recv* and disconnect* print a trace
    ;; line to the error port.
    (define DEBUG? #f)
    ;; debug : any -> void
    ;; Turns the message trace on or off by storing debug? in DEBUG?.
    (define/public (debug debug?)
      (set! DEBUG? debug?))
;; ========================================
;; == Communication
;; (Must be called with lock acquired.)
    ;; next-msg-num : sequence number handed to write-packet for the next
    ;; outgoing packet; recv* sets it to one past the number of each packet read.
    (define next-msg-num 0)
    ;; fresh-exchange : -> void
    ;; Resets the packet sequence number to 0 at the start of a new
    ;; request/response exchange.
    (define/private (fresh-exchange)
      (set! next-msg-num 0))
    ;; send-message : message -> void
    ;; Writes msg to the output port (buffer-message) and flushes immediately.
    (define/private (send-message msg)
      (buffer-message msg)
      (flush-message-buffer))
;; buffer-message : message -> void
(define/private (buffer-message msg)
(when DEBUG?
(fprintf (current-error-port) " >> ~s\n" msg))
(with-disconnect-on-error
(write-packet outport msg next-msg-num)
(set! next-msg-num (add1 next-msg-num))))
    ;; flush-message-buffer : -> void
    ;; Flushes outport; a failure while flushing disconnects and re-raises.
    (define/private (flush-message-buffer)
      (with-disconnect-on-error
        (flush-output outport)))
;; recv : symbol/#f [(list-of symbol)] -> message
;; Automatically handles asynchronous messages
(define/private (recv fsym expectation [field-dvecs #f])
(define r
(with-disconnect-on-error
(recv* fsym expectation field-dvecs)))
(when (error-packet? r)
(raise-backend-error fsym r))
r)
    ;; recv* : symbol symbol/#f field-dvecs/#f -> message
    ;; Parses one packet from inport, updates the sequence number and the
    ;; transaction status, checks that the packet kind is acceptable for
    ;; `expectation`, and returns the parsed packet. Error packets are
    ;; returned, not raised (recv does the raising).
    (define/private (recv* fsym expectation field-dvecs)
      ;; advance: raise a communication error unless there is no expectation,
      ;; no states were listed, or the expectation is one of the listed states.
      (define (advance . ss)
        (unless (or (not expectation)
                    (null? ss)
                    (memq expectation ss))
          (error/comm fsym)))
      ;; err: raise a communication error; the packet argument is ignored.
      (define (err packet)
        (error/comm fsym))
      (let-values ([(msg-num next) (parse-packet inport expectation field-dvecs)])
        ;; The next packet we send must carry the following sequence number.
        (set! next-msg-num (add1 msg-num))
        (when DEBUG?
          (eprintf " << ~s\n" next))
        ;; Update transaction status (see Transactions below)
        ;; Bit 0 of the server status in ok/eof packets becomes tx-status.
        (when (ok-packet? next)
          (set! tx-status
                (bitwise-bit-set? (ok-packet-server-status next) 0)))
        (when (eof-packet? next)
          (set! tx-status
                (bitwise-bit-set? (eof-packet-server-status next) 0)))
        ;; Errors 1213 and 1205 inside a transaction mark it 'invalid
        ;; (see the Transactions notes: deadlock and lock wait timeout).
        (when (error-packet? next)
          (when tx-status
            (when (member (error-packet-errno next) '(1213 1205))
              (set! tx-status 'invalid))))
        ;; Check the packet kind against the expectation.
        (match next
          [(? handshake-packet?)
           (advance 'handshake)]
          [(? ok-packet?)
           (advance)]
          [(? change-plugin-packet?)
           (advance 'auth)]
          [(? error-packet?)
           (advance)]
          [(struct result-set-header-packet (field-count _))
           (advance 'result)]
          [(? field-packet?)
           (advance 'field)]
          [(? row-data-packet?)
           (advance 'data)]
          [(? binary-row-data-packet?)
           (advance 'binary-data)]
          [(? ok-prepared-statement-packet? result)
           (advance 'prep-ok)]
          [(? parameter-packet? result)
           (advance 'prep-params)]
          [(? eof-packet?)
           (advance 'field 'data 'binary-data 'prep-params)]
          [(struct unknown-packet (expected contents))
           (error/comm fsym expected)]
          ;; `else` is an ordinary catch-all pattern variable here.
          [else
           (err next)])
        next))
;; ========================================
;; Connection management
    ;; disconnect : -> (void)
    ;; Public entry point; the caller does not hold the connection lock, so
    ;; disconnect* is told it must acquire it.
    (define/public (disconnect)
      (disconnect* #t))
(define/private (disconnect* lock-not-held?)
(define (go politely?)
(when DEBUG?
(eprintf " ** Disconnecting\n"))
(let ([outport* outport]
[inport* inport])
(when outport
(when politely?
(fresh-exchange)
(send-message (make-command-packet 'quit "")))
(close-output-port outport)
(set! outport #f))
(when inport
(close-input-port inport)
(set! inport #f))))
;; If we don't hold the lock, try to acquire it and disconnect politely.
;; Except, if already disconnected, no need to acquire lock.
(cond [(and lock-not-held? (connected?))
(call-with-lock* 'disconnect
(lambda () (go #t))
(lambda () (go #f))
#f)]
[else (go #f)]))
;; connected? : -> boolean
(define/override (connected?)
(let ([outport outport])
(and outport (not (port-closed? outport)))))
    ;; get-dbsystem : -> dbsystem
    ;; Returns the `dbsystem` value (not defined in this file; presumably
    ;; from "dbsystem.rkt").
    (define/public (get-dbsystem)
      dbsystem)
;; ========================================
;; == Connect
    ;; attach-to-ports : input-port output-port -> void
    ;; Installs the ports used for all further communication. Also called
    ;; by start-connection-protocol to swap in the SSL-wrapped ports.
    (define/public (attach-to-ports in out)
      (set! inport in)
      (set! outport out))
    ;; start-connection-protocol : dbname username password ssl ssl-context -> void
    ;; Runs the connection handshake on the attached ports: reads the server
    ;; handshake, checks capabilities and the authentication plugin, optionally
    ;; upgrades the ports to SSL, then authenticates. Any failure disconnects.
    ;;   ssl : the code distinguishes 'yes, 'optional and 'no
    ;;   dbname : when non-#f, 'connect-with-db is requested
    (define/public (start-connection-protocol dbname username password ssl ssl-context)
      (with-disconnect-on-error
        (fresh-exchange)
        (let ([r (recv 'mysql-connect 'handshake)])
          (match r
            [(struct handshake-packet (pver sver tid scramble capabilities charset status auth))
             (check-required-flags capabilities)
             ;; Only the native-password plugin (or no plugin name at all) is
             ;; accepted as the server's initial choice.
             (unless (member auth '("mysql_native_password" #f))
               (uerror 'mysql-connect "unsupported authentication plugin: ~s" auth))
             ;; Use SSL only if requested ('yes / 'optional) and offered by the server.
             (define do-ssl?
               (and (case ssl ((yes optional) #t) ((no) #f))
                    (memq 'ssl capabilities)))
             (when (and (eq? ssl 'yes) (not do-ssl?))
               (uerror 'mysql-connect "server refused SSL connection"))
             (define wanted-capabilities (desired-capabilities capabilities do-ssl? dbname))
             ;; For SSL, send the abbreviated auth packet first, then wrap the
             ;; ports; the original ports are closed along with the SSL ports.
             (when do-ssl?
               (send-message (make-abbrev-client-auth-packet wanted-capabilities))
               (let-values ([(sin sout)
                             (ports->ssl-ports inport outport
                                               #:mode 'connect
                                               #:context ssl-context
                                               #:close-original? #t)])
                 (attach-to-ports sin sout)))
             (authenticate wanted-capabilities username password dbname
                           (or auth "mysql_native_password") scramble)]
            [_ (error/comm 'mysql-connect "during authentication")]))))
(define/private (authenticate capabilities username password dbname auth-plugin scramble)
(let loop ([auth-plugin auth-plugin] [scramble scramble] [first? #t])
(define (auth data)
(if first?
(make-client-auth-packet capabilities MAX-PACKET-LENGTH 'utf8-general-ci
username data dbname auth-plugin)
(make-auth-followup-packet data)))
(cond [(equal? auth-plugin "mysql_native_password")
(send-message (auth (scramble-password scramble password)))]
[(equal? auth-plugin "mysql_old_password")
(send-message (auth (bytes-append (old-scramble-password scramble password)
(bytes 0))))]
[else (uerror 'mysql-connect
"server does not support authentication plugin: ~s"
auth-plugin)])
(match (recv 'mysql-connect 'auth)
[(struct ok-packet (_ _ status warnings message))
(after-connect)]
[(struct change-plugin-packet (plugin data))
;; if plugin = #f, means "mysql_old_password"
(loop (or plugin "mysql_old_password") (or data scramble) #f)])))
(define/private (check-required-flags capabilities)
(for-each (lambda (rf)
(unless (memq rf capabilities)
(uerror 'mysql-connect
"server does not support required capability: ~s"
rf)))
REQUIRED-CAPABILITIES))
(define/private (desired-capabilities capabilities ssl? dbname)
(append (if ssl? '(ssl) '())
(if dbname '(connect-with-db) '())
'(interactive)
(filter (lambda (c) (memq c DESIRED-CAPABILITIES)) capabilities)))
    ;; after-connect : -> void
    ;; Set connection to use utf8 encoding by issuing "set names 'utf8'"
    ;; through the public query method; the query result is discarded.
    (define/private (after-connect)
      (query 'mysql-connect "set names 'utf8'")
      (void))
;; ========================================
;; == Query
;; query : symbol Statement -> QueryResult
(define/public (query fsym stmt)
(check-valid-tx-status fsym)
(let*-values ([(stmt result)
(call-with-lock fsym
(lambda ()
(let* ([stmt (check-statement fsym stmt)]
[stmt-type
(cond [(statement-binding? stmt)
(send (statement-binding-pst stmt) get-stmt-type)]
[(string? stmt)
(classify-my-sql stmt)])])
(check-statement/tx fsym stmt-type)
(values stmt (query1 fsym stmt #t)))))])
(when #f ;; DISABLED---for some reason, *really* slow
(statement:after-exec stmt))
(query1:process-result fsym result)))
;; query1 : symbol Statement -> QueryResult
(define/private (query1 fsym stmt warnings?)
(let ([wbox (and warnings? (box 0))])
(fresh-exchange)
(query1:enqueue stmt)
(begin0 (query1:collect fsym (not (string? stmt)) wbox)
(when (and warnings? (not (zero? (unbox wbox))))
(fetch-warnings fsym)))))
;; check-statement : symbol any -> statement-binding
(define/private (check-statement fsym stmt)
(cond [(statement-binding? stmt)
(let ([pst (statement-binding-pst stmt)])
(send pst check-owner fsym this stmt)
(for ([typeid (in-list (send pst get-result-typeids))])
(unless (supported-result-typeid? typeid)
(error/unsupported-type fsym typeid)))
stmt)]
[(and (string? stmt) (force-prepare-sql? fsym stmt))
(let ([pst (prepare1 fsym stmt #t)])
(check-statement fsym (send pst bind fsym null)))]
[else stmt]))
;; query1:enqueue : statement -> void
(define/private (query1:enqueue stmt)
(cond [(statement-binding? stmt)
(let* ([pst (statement-binding-pst stmt)]
[id (send pst get-handle)]
[params (statement-binding-params stmt)]
[null-map (map sql-null? params)])
(send-message
(make-execute-packet id null null-map params)))]
[else ;; string
(send-message (make-command-packet 'query stmt))]))
;; query1:collect : symbol bool -> QueryResult stream
(define/private (query1:collect fsym binary? wbox)
(let ([r (recv fsym 'result)])
(match r
[(struct ok-packet (affected-rows insert-id status warnings message))
(when wbox (set-box! wbox warnings))
(vector 'command `((affected-rows . ,affected-rows)
(insert-id . ,insert-id)
(status . ,status)
(message . ,message)))]
[(struct result-set-header-packet (fields extra))
(let* ([field-dvecs (query1:get-fields fsym binary?)]
[rows (query1:get-rows fsym field-dvecs binary? wbox)])
(vector 'rows field-dvecs rows))])))
(define/private (query1:get-fields fsym binary?)
(let ([r (recv fsym 'field)])
(match r
[(? field-packet?)
(cons (parse-field-dvec r) (query1:get-fields fsym binary?))]
[(struct eof-packet (warning status))
null])))
(define/private (query1:get-rows fsym field-dvecs binary? wbox)
;; Note: binary? should always be #t, unless force-prepare-sql? misses something.
(let ([r (recv fsym (if binary? 'binary-data 'data) field-dvecs)])
(match r
[(struct row-data-packet (data))
(cons data (query1:get-rows fsym field-dvecs binary? wbox))]
[(struct binary-row-data-packet (data))
(cons data (query1:get-rows fsym field-dvecs binary? wbox))]
[(struct eof-packet (warnings status))
(when wbox (set-box! wbox warnings))
null])))
(define/private (query1:process-result fsym result)
(match result
[(vector 'rows field-dvecs rows)
(rows-result (map field-dvec->field-info field-dvecs) rows)]
[(vector 'command command-info)
(simple-result command-info)]))
;; == Prepare
;; prepare : symbol string boolean -> PreparedStatement
(define/public (prepare fsym stmt close-on-exec?)
(check-valid-tx-status fsym)
(call-with-lock fsym
(lambda ()
(prepare1 fsym stmt close-on-exec?))))
(define/private (prepare1 fsym stmt close-on-exec?)
(fresh-exchange)
(send-message (make-command-packet 'statement-prepare stmt))
(let ([r (recv fsym 'prep-ok)])
(match r
[(struct ok-prepared-statement-packet (id fields params))
(let ([param-dvecs
(if (zero? params) null (prepare1:get-field-descriptions fsym))]
[field-dvecs
(if (zero? fields) null (prepare1:get-field-descriptions fsym))])
(new prepared-statement%
(handle id)
(close-on-exec? close-on-exec?)
(param-typeids (map field-dvec->typeid param-dvecs))
(result-dvecs field-dvecs)
(stmt-type (classify-my-sql stmt))
(owner this)))])))
(define/private (prepare1:get-field-descriptions fsym)
(let ([r (recv fsym 'field)])
(match r
[(struct eof-packet (warning-count status))
null]
[(? field-packet?)
(cons (parse-field-dvec r) (prepare1:get-field-descriptions fsym))])))
    ;; get-base : -> connection
    ;; Returns this connection object itself.
    (define/public (get-base) this)
(define/public (free-statement pst)
(call-with-lock* 'free-statement
(lambda ()
(let ([id (send pst get-handle)])
(when (and id outport) ;; outport = connected?
(send pst set-handle #f)
(fresh-exchange)
(send-message (make-command:statement-packet 'statement-close id)))))
void
#f))
;; == Warnings
(define/private (fetch-warnings fsym)
(unless (eq? notice-handler void)
(let ([result (query1 fsym "SHOW WARNINGS" #f)])
(define (find-index name dvecs)
(for/or ([dvec (in-list dvecs)]
[i (in-naturals)])
(and (equal? (field-dvec->name dvec) name) i)))
(match result
[(vector 'rows field-dvecs rows)
(let ([code-index (find-index "Code" field-dvecs)]
[message-index (find-index "Message" field-dvecs)])
(for ([row (in-list rows)])
(let ([code (string->number (vector-ref row code-index))]
[message (vector-ref row message-index)])
(add-delayed-call! (lambda () (notice-handler code message))))))]))))
;; == Transactions
;; MySQL: what causes implicit commit, when is transaction rolled back
;; http://dev.mysql.com/doc/refman/5.1/en/implicit-commit.html
;; http://dev.mysql.com/doc/refman/5.1/en/innodb-error-handling.html
;; http://dev.mysql.com/doc/refman/5.1/en/innodb-error-codes.html
;;
;; Sounds like MySQL rolls back transaction (but may keep open!) on
;; - transaction deadlock = 1213 (ER_LOCK_DEADLOCK)
;; - lock wait timeout (depends on config) = 1205 (ER_LOCK_WAIT_TIMEOUT)
(define/override (start-transaction* fsym isolation)
(cond [(eq? isolation 'nested)
(let ([savepoint (generate-name)])
(query1 fsym (format "SAVEPOINT ~a" savepoint) #t)
savepoint)]
[else
(let ([isolation-level (isolation-symbol->string isolation)])
(when isolation-level
(query1 fsym (format "SET TRANSACTION ISOLATION LEVEL ~a" isolation-level) #t))
(query1 fsym "START TRANSACTION" #t)
#f)]))
(define/override (end-transaction* fsym mode savepoint)
(case mode
((commit)
(cond [savepoint
(query1 fsym (format "RELEASE SAVEPOINT ~a" savepoint) #t)]
[else
(query1 fsym "COMMIT" #t)]))
((rollback)
(cond [savepoint
(query1 fsym (format "ROLLBACK TO SAVEPOINT ~a" savepoint) #t)
(query1 fsym (format "RELEASE SAVEPOINT ~a" savepoint) #t)]
[else
(query1 fsym "ROLLBACK" #t)])))
(void))
;; name-counter : number
(define name-counter 0)
;; generate-name : -> string
(define/private (generate-name)
(let ([n name-counter])
(set! name-counter (add1 name-counter))
(format "λmz_~a" n)))
;; Reflection
(define/public (list-tables fsym schema)
(let* ([stmt
;; schema is ignored; search = current
(string-append "SELECT table_name FROM information_schema.tables "
"WHERE table_schema = schema()")]
[rows
(vector-ref (call-with-lock fsym (lambda () (query1 fsym stmt #t))) 2)])
(for/list ([row (in-list rows)])
(vector-ref row 0))))
))
;; ========================================
;; scramble-password : bytes/#f (or string (list any hex-string) #f) -> bytes/#f
;; Computes the reply to a scramble challenge:
;;   stage1 = SHA1(password), or the decoded hex digest when password is a pair
;;   stage2 = SHA1(stage1)
;;   stage3 = SHA1(scramble ++ stage2)
;;   reply  = stage1 XOR stage3
;; Returns #f when either argument is #f.
(define (scramble-password scramble password)
  (cond
    [(and scramble password)
     (let* ([stage1 (cond [(string? password)
                           (password-hash password)]
                          [(pair? password)
                           (hex-string->bytes (cadr password))])]
            [stage2 (sha1-bytes (open-input-bytes stage1))]
            [stage3 (sha1-bytes
                     (open-input-bytes (bytes-append scramble stage2)))])
       (bytes-xor stage1 stage3))]
    [else #f]))
;; password-hash : string -> bytes
;; SHA-1 digest (raw bytes) of the password encoded as Latin-1.
(define (password-hash password)
  (sha1-bytes
   (open-input-bytes
    (string->bytes/latin-1 password))))
;; bytes-xor : bytes bytes -> bytes
;; Returns a fresh byte string, as long as a, holding the bytewise XOR of
;; a and b. Assumes args are same length.
(define (bytes-xor a b)
  (define len (bytes-length a))
  (define out (make-bytes len))
  (for ([i (in-range len)])
    (bytes-set! out i (bitwise-xor (bytes-ref a i) (bytes-ref b i))))
  out)
;; =======================================
(provide old-scramble-password
hash323
hash323->string)
;; old-scramble-password : bytes/#f string/#f -> bytes/#f
;; Computes the 8-byte reply for the "mysql_old_password" plugin from the
;; first 8 bytes of the scramble and the password. Two hash323 results seed
;; a small pseudo-random generator; each output byte is a generated value in
;; 64..94, and all bytes are finally XORed with one extra generated value.
;; Returns #f when either argument is #f.
(define (old-scramble-password scramble password)
  (define RMAX #x3FFFFFFF)
  (cond
    [(and scramble password)
     (let* ([seed-bytes (subbytes scramble 0 8)]
            [pw-hash (hash323 (string->bytes/utf-8 password))]
            [seed-hash (hash323 seed-bytes)]
            [r1 (modulo (bitwise-xor (car pw-hash) (car seed-hash)) RMAX)]
            [r2 (modulo (bitwise-xor (cdr pw-hash) (cdr seed-hash)) RMAX)]
            [out (make-bytes 8 0)])
       ;; next-random : -> flonum, advancing the generator state r1/r2
       (define (next-random)
         (set! r1 (modulo (+ (* 3 r1) r2) RMAX))
         (set! r2 (modulo (+ r1 r2 33) RMAX))
         (/ (exact->inexact r1) (exact->inexact RMAX)))
       ;; next-offset : -> exact integer, floor of (31 * next random value)
       (define (next-offset)
         (inexact->exact (floor (* (next-random) 31))))
       (for ([i (in-range (bytes-length seed-bytes))])
         (bytes-set! out i (+ (next-offset) 64)))
       (let ([extra (next-offset)])
         (for ([i (in-range (bytes-length seed-bytes))])
           (bytes-set! out i (bitwise-xor (bytes-ref out i) extra))))
       out)]
    [else #f]))
;; hash323 : bytes -> (cons nat nat)
;; Old-style (pre-4.1) MySQL password hash. Returns the two 31-bit halves
;; of the hash as a pair. Space (32) and tab (9) bytes are skipped.
;;
;; The skip test compares byte values: bytes-ref returns an integer, so
;; testing against the characters #\space and #\tab would never match and
;; whitespace would wrongly be hashed.
(define (hash323 bs)
  (define MASK64 (sub1 (arithmetic-shift 1 64)))
  (define MASK31 (sub1 (arithmetic-shift 1 31)))
  ;; skip? : byte -> boolean, true for space and tab
  (define (skip? b)
    (or (= b 32) (= b 9)))
  (let-values
      ([(nr nr2 add)
        (for/fold ([nr 1345345333]
                   [nr2 #x12345671]
                   [add 7])
                  ([b (in-bytes bs)]
                   #:unless (skip? b))
          ;; Each accumulator is truncated to 64 bits after its update;
          ;; nr2 is mixed with the already-updated nr.
          (let* ([nr* (bitwise-and
                       MASK64
                       (bitwise-xor nr
                                    (+ (* (+ (bitwise-and nr 63) add) b)
                                       (arithmetic-shift nr 8))))]
                 [nr2* (bitwise-and
                        MASK64
                        (+ nr2
                           (bitwise-xor (arithmetic-shift nr2 8) nr*)))]
                 [add* (bitwise-and MASK64 (+ add b))])
            (values nr* nr2* add*)))])
    (cons (bitwise-and nr MASK31)
          (bitwise-and nr2 MASK31))))
;; hash323->string : bytes -> bytes
;; Packs the two halves of (hash323 bs) into 8 bytes: each half as a 4-byte
;; unsigned integer, with #f passed for integer->integer-bytes' big-endian
;; argument.
(define (hash323->string bs)
  (define halves (hash323 bs))
  (define (u32 n)
    (integer->integer-bytes n 4 #f #f))
  (bytes-append (u32 (car halves))
                (u32 (cdr halves))))
;; ========================================
;; Capabilities the server must offer; check-required-flags raises a
;; connection error for the first one that is missing.
(define REQUIRED-CAPABILITIES
  '(long-flag
    connect-with-db
    protocol-41
    secure-connection))
;; Capabilities requested when the server offers them: desired-capabilities
;; keeps only the server capabilities that appear in this list.
(define DESIRED-CAPABILITIES
  '(long-password
    long-flag
    transactions
    protocol-41
    secure-connection
    plugin-auth))
;; raise-backend-error : symbol ErrorPacket -> raises exn
;; Raises a SQL error for `who`, using the packet's SQLSTATE as the code and
;; attaching errno, code and message as the error's property alist.
(define (raise-backend-error who r)
  (let ([sqlstate (error-packet-sqlstate r)]
        [text (error-packet-message r)])
    (raise-sql-error who sqlstate text
                     (list (cons 'errno (error-packet-errno r))
                           (cons 'code sqlstate)
                           (cons 'message text)))))
;; ========================================
#|
MySQL allows only certain kinds of statements to be prepared; the rest
must go through the old execution path. See here:
http://dev.mysql.com/doc/refman/5.0/en/c-api-prepared-statements.html
According to that page, the following statements may be prepared:
CALL, CREATE TABLE, DELETE, DO, INSERT, REPLACE, SELECT, SET, UPDATE,
and most SHOW statements
On the other hand, we want to force all rows-returning statements
through the prepared-statement path to use the binary data
protocol. That would seem to be the following:
SELECT and SHOW
|#
;; force-prepare-sql? : symbol string -> list/#f
;; True (a non-#f memq result) when stmt classifies as a SELECT or SHOW
;; statement. fsym is accepted but unused.
(define (force-prepare-sql? fsym stmt)
  (let ([kind (classify-my-sql stmt)])
    (memq kind '(select show))))