db/sqlite3: support create-{function,aggregate}

This commit is contained in:
Ryan Culpepper 2018-04-03 04:29:38 +02:00
parent aadbe1a7d2
commit 93a899cf4c
3 changed files with 161 additions and 4 deletions

View File

@ -12,9 +12,13 @@
"dbsystem.rkt")
(provide connection%
handle-status*
(protect-out unsafe-load-extension))
(protect-out unsafe-load-extension
unsafe-create-function
unsafe-create-aggregate))
(define-local-member-name unsafe-load-extension)
(define-local-member-name unsafe-create-function)
(define-local-member-name unsafe-create-aggregate)
;; == Connection
@ -350,7 +354,7 @@
(set! name-counter (add1 name-counter))
(format "λmz_~a" n)))
;; Reflection
;; == Reflection
(define/public (list-tables fsym schema)
(let ([stmt
@ -361,7 +365,7 @@
(for/list ([row (in-list (rows-result-rows result))])
(vector-ref row 0)))))
;; ----
;; == Load Extension
(define/public (unsafe-load-extension who lib)
(define lib-path (cleanse-path (path->complete-path lib)))
@ -373,7 +377,29 @@
(HANDLE who (A (sqlite3_enable_load_extension -db 0)))
(void))))
;; ----
;; == Create Function
(define dont-gc null)
(define/public (unsafe-create-function who name arity proc)
(define wrapped (wrap-fun name proc))
(call-with-lock who
(lambda ()
(set! dont-gc (cons wrapped dont-gc))
(HANDLE who (A (sqlite3_create_function_v2 -db name (or arity -1) wrapped))))))
(define/public (unsafe-create-aggregate who name arity step final [init #f])
(define aggbox (box init))
(define wrapped-step (wrap-agg-step name step aggbox init))
(define wrapped-final (wrap-agg-final name final aggbox init))
(call-with-lock who
(lambda ()
(set! dont-gc (list* wrapped-step wrapped-final dont-gc))
(HANDLE who
(A (sqlite3_create_aggregate -db name (or arity -1)
wrapped-step wrapped-final))))))
;; == Error handling
(define-syntax HANDLE
(syntax-rules ()

View File

@ -57,3 +57,8 @@
(define SQLITE_OPEN_FULLMUTEX #x00010000)
(define SQLITE_OPEN_SHAREDCACHE #x00020000)
(define SQLITE_OPEN_PRIVATECACHE #x00040000)
;; Create function
(define SQLITE_UTF8 1)
(define SQLITE_DETERMINISTIC #x800)

View File

@ -238,3 +238,129 @@
;; FIXME: handle error string?
(_fun _sqlite3_database _path (_pointer = #f) (_pointer = #f)
-> _int))
;; ----------------------------------------
(define-cpointer-type _sqlite3_context)
(define-cpointer-type _sqlite3_value)
(define-sqlite sqlite3_value_type (_fun _sqlite3_value -> _int))
(define-sqlite sqlite3_value_double (_fun _sqlite3_value -> _double))
(define-sqlite sqlite3_value_int64 (_fun _sqlite3_value -> _int64))
(define-sqlite sqlite3_value_bytes (_fun _sqlite3_value -> _int))
(define-sqlite sqlite3_value_blob (_fun _sqlite3_value -> _pointer))
(define-sqlite sqlite3_value_text (_fun _sqlite3_value -> _pointer))
(define-ffi-definer define-rkt #f)
(define-rkt scheme_make_sized_utf8_string (_fun _pointer _intptr -> _racket))
(define-rkt scheme_make_sized_byte_string (_fun _pointer _intptr -> _racket))
(define _sqlite3_value*
(make-ctype _sqlite3_value
#f
(lambda (v)
(define type (sqlite3_value_type v))
(cond [(= type SQLITE_INTEGER) (sqlite3_value_int64 v)]
[(= type SQLITE_FLOAT) (sqlite3_value_double v)]
[(= type SQLITE_TEXT)
(scheme_make_sized_utf8_string (sqlite3_value_text v)
(sqlite3_value_bytes v))]
[(= type SQLITE_BLOB)
(scheme_make_sized_byte_string (sqlite3_value_blob v)
(sqlite3_value_bytes v))]
[else (error '_sqlite3_value* "cannot convert: ~e (type = ~s)" v type)]))))
(define-sqlite sqlite3_create_function_v2
(_fun _sqlite3_database
_string/utf-8
_int
(_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC))
(_pointer = #f)
(_fun _sqlite3_context _int _pointer -> _void)
(_fpointer = #f)
(_fpointer = #f)
(_fpointer = #f)
-> _int))
(define-sqlite sqlite3_create_aggregate
(_fun _sqlite3_database
_string/utf-8
_int
(_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC))
(_pointer = #f)
(_fpointer = #f)
(_fun _sqlite3_context _int _pointer -> _void)
(_fun _sqlite3_context -> _void)
(_fpointer = #f)
-> _int)
#:c-id sqlite3_create_function_v2)
(define-sqlite sqlite3_aggregate_context
(_fun _sqlite3_context _int -> _pointer))
(define-sqlite sqlite3_result_null (_fun _sqlite3_context -> _void))
(define-sqlite sqlite3_result_int64 (_fun _sqlite3_context _int64 -> _void))
(define-sqlite sqlite3_result_double (_fun _sqlite3_context _double* -> _void))
(define-sqlite sqlite3_result_blob
(_fun _sqlite3_context
(buf : _bytes)
(_int = (bytes-length buf))
(_intptr = SQLITE_TRANSIENT)
-> _void))
(define-sqlite sqlite3_result_text
(_fun _sqlite3_context
(buf : _string/utf-8)
(_int = (string-utf-8-length buf))
(_intptr = SQLITE_TRANSIENT)
-> _void))
(define-sqlite sqlite3_result_error
(_fun _sqlite3_context (s : _string/utf-8) (_int = (string-utf-8-length s)) -> _void))
(define ((wrap-fun who proc) ctx argc argp)
(define args (get-args argc argp))
(call/wrap who ctx (lambda () (sqlite3_result* ctx (apply proc args)))))
;; sqlite3 supports an "aggregate context" for storing aggregate
;; state, but it's hidden from Racket's GC. So instead we make a
;; closure with Racket-visible state and use sqlite's aggregate
;; context just to tell us whether we need to reset the Racket-level
;; state. The connection object is responsible for preventing the
;; closure from being prematurely collected.
(define ((wrap-agg-step who proc aggbox agginit) ctx argc argp)
(define args (get-args argc argp))
(define aggctx (sqlite3_aggregate_context ctx 1))
(when (zero? (ptr-ref aggctx _byte))
(set-box! aggbox agginit)
(ptr-set! aggctx _byte 1))
(set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args))))
(sqlite3_result* ctx 0))
(define ((wrap-agg-final who proc aggbox agginit) ctx)
(define aggctx (sqlite3_aggregate_context ctx 1))
(define r (call/wrap who ctx (lambda () (proc (unbox aggbox)))))
(set-box! aggbox agginit)
(sqlite3_result* ctx r))
(define (call/wrap who ctx proc)
(with-handlers
([(lambda (e) #t)
(lambda (e)
(define err
(format "[racket:~a] ~a"
who
(cond [(exn? e) (exn-message e)]
[else (format "caught non-exception\n caught: ~e" e)])))
(sqlite3_result_error ctx err))])
(call-with-continuation-barrier proc)))
(define (get-args argc argp)
(for/list ([i (in-range argc)])
(ptr-ref argp _sqlite3_value* i)))
(define (sqlite3_result* ctx r)
(cond [(fixnum? r) (sqlite3_result_int64 ctx r)] ;; FIXME: fixnum -> int64
[(real? r) (sqlite3_result_double ctx r)]
[(string? r) (sqlite3_result_text ctx r)]
[(bytes? r) (sqlite3_result_blob ctx r)]
[else (sqlite3_result_error ctx (format "bad result: ~e" r))]))