db: fix custom functions for blocking sqlite3_step change

Also fix error reporting for aggregate functions.
This commit is contained in:
Ryan Culpepper 2021-03-27 23:27:47 +01:00
parent d65e648d6a
commit 9a4f9688ad
3 changed files with 52 additions and 29 deletions

View File

@ -440,14 +440,17 @@
(define dont-gc null)
(define/public (unsafe-create-function who name arity proc)
(define/public (unsafe-create-function who name arity proc
#:flags [flags null])
(define wrapped (wrap-fun name proc))
(call-with-lock who
(lambda ()
(set! dont-gc (cons wrapped dont-gc))
(HANDLE who (A (sqlite3_create_function_v2 -db name (or arity -1) wrapped))))))
(HANDLE who (A (sqlite3_create_function_v2/scalar
-db name (or arity -1) flags wrapped))))))
(define/public (unsafe-create-aggregate who name arity step final [init #f])
(define/public (unsafe-create-aggregate who name arity step final [init #f]
#:flags [flags null])
(define aggbox (box init))
(define wrapped-step (wrap-agg-step name step aggbox init))
(define wrapped-final (wrap-agg-final name final aggbox init))
@ -455,8 +458,8 @@
(lambda ()
(set! dont-gc (list* wrapped-step wrapped-final dont-gc))
(HANDLE who
(A (sqlite3_create_aggregate -db name (or arity -1)
wrapped-step wrapped-final))))))
(A (sqlite3_create_function_v2/aggregate
-db name (or arity -1) flags wrapped-step wrapped-final))))))
;; == Error handling

View File

@ -62,3 +62,4 @@
(define SQLITE_UTF8 1)
(define SQLITE_DETERMINISTIC #x800)
(define SQLITE_DIRECTONLY #x80000)

View File

@ -277,27 +277,39 @@
(sqlite3_value_bytes v))]
[else (error '_sqlite3_value* "cannot convert: ~e (type = ~s)" v type)]))))
(define-sqlite sqlite3_create_function_v2
(_fun _sqlite3_database
_string/utf-8
_int
(_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC))
(_pointer = #f)
(_fun _sqlite3_context _int _pointer -> _void)
(_fpointer = #f)
(_fpointer = #f)
(_fpointer = #f)
-> _int))
(define default-async-apply (lambda (p) (p)))
(define-sqlite sqlite3_create_aggregate
(_fun _sqlite3_database
_string/utf-8
_int
(_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC))
(define-sqlite sqlite3_create_function_v2/scalar
(_fun (db name arity flags proc) ::
(db : _sqlite3_database)
(name : _string/utf-8)
(arity : _int)
(_int = (bitwise-ior SQLITE_UTF8
(if (memq 'direct-only flags) SQLITE_DIRECTONLY 0)
(if (memq 'deterministic flags) SQLITE_DETERMINISTIC 0)))
(_pointer = #f)
(proc : (_fun #:async-apply default-async-apply
_sqlite3_context _int _pointer -> _void))
(_fpointer = #f)
(_fpointer = #f)
(_fpointer = #f)
-> _int)
#:c-id sqlite3_create_function_v2)
(define-sqlite sqlite3_create_function_v2/aggregate
(_fun (db name arity flags step final) ::
(db : _sqlite3_database)
(name : _string/utf-8)
(arity : _int)
(_int = (bitwise-ior SQLITE_UTF8
(if (memq 'direct-only flags) SQLITE_DIRECTONLY 0)
(if (memq 'deterministic flags) SQLITE_DETERMINISTIC 0)))
(_pointer = #f)
(_fpointer = #f)
(_fun _sqlite3_context _int _pointer -> _void)
(_fun _sqlite3_context -> _void)
(step : (_fun #:async-apply default-async-apply
_sqlite3_context _int _pointer -> _void))
(final : (_fun #:async-apply default-async-apply
_sqlite3_context -> _void))
(_fpointer = #f)
-> _int)
#:c-id sqlite3_create_function_v2)
@ -334,20 +346,26 @@
;; state. The connection object is responsible for preventing the
;; closure from being prematurely collected.
;; An aggbox is (box (U aggerror Any)); aggerror indicates that
;; sqlite3_result_error has already been called to report an error.
(define aggerror (gensym 'error))
(define ((wrap-agg-step who proc aggbox agginit) ctx argc argp)
(define args (get-args argc argp))
(define aggctx (sqlite3_aggregate_context ctx 1))
(when (zero? (ptr-ref aggctx _byte))
(set-box! aggbox agginit)
(ptr-set! aggctx _byte 1))
(set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args))))
(sqlite3_result* ctx 0))
(unless (eq? (unbox aggbox) aggerror)
(set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args))))))
(define ((wrap-agg-final who proc aggbox agginit) ctx)
(define aggctx (sqlite3_aggregate_context ctx 1))
(define r (call/wrap who ctx (lambda () (proc (unbox aggbox)))))
(set-box! aggbox agginit)
(sqlite3_result* ctx r))
(unless (eq? (unbox aggbox) aggerror)
(define r (call/wrap who ctx (lambda () (proc (unbox aggbox)))))
(set-box! aggbox #f)
(unless (eq? r aggerror)
(sqlite3_result* ctx r))))
(define (call/wrap who ctx proc)
(with-handlers
@ -358,7 +376,8 @@
who
(cond [(exn? e) (exn-message e)]
[else (format "caught non-exception\n caught: ~e" e)])))
(sqlite3_result_error ctx err))])
(sqlite3_result_error ctx err)
aggerror)])
(call-with-continuation-barrier proc)))
(define (get-args argc argp)