diff --git a/racket/collects/db/private/sqlite3/connection.rkt b/racket/collects/db/private/sqlite3/connection.rkt index 96ed9cd549..df59613431 100644 --- a/racket/collects/db/private/sqlite3/connection.rkt +++ b/racket/collects/db/private/sqlite3/connection.rkt @@ -440,14 +440,17 @@ (define dont-gc null) - (define/public (unsafe-create-function who name arity proc) + (define/public (unsafe-create-function who name arity proc + #:flags [flags null]) (define wrapped (wrap-fun name proc)) (call-with-lock who (lambda () (set! dont-gc (cons wrapped dont-gc)) - (HANDLE who (A (sqlite3_create_function_v2 -db name (or arity -1) wrapped)))))) + (HANDLE who (A (sqlite3_create_function_v2/scalar + -db name (or arity -1) flags wrapped)))))) - (define/public (unsafe-create-aggregate who name arity step final [init #f]) + (define/public (unsafe-create-aggregate who name arity step final [init #f] + #:flags [flags null]) (define aggbox (box init)) (define wrapped-step (wrap-agg-step name step aggbox init)) (define wrapped-final (wrap-agg-final name final aggbox init)) @@ -455,8 +458,8 @@ (lambda () (set! dont-gc (list* wrapped-step wrapped-final dont-gc)) (HANDLE who - (A (sqlite3_create_aggregate -db name (or arity -1) - wrapped-step wrapped-final)))))) + (A (sqlite3_create_function_v2/aggregate + -db name (or arity -1) flags wrapped-step wrapped-final)))))) ;; == Error handling diff --git a/racket/collects/db/private/sqlite3/ffi-constants.rkt b/racket/collects/db/private/sqlite3/ffi-constants.rkt index 2639de140f..c47989d09f 100644 --- a/racket/collects/db/private/sqlite3/ffi-constants.rkt +++ b/racket/collects/db/private/sqlite3/ffi-constants.rkt @@ -62,3 +62,4 @@ (define SQLITE_UTF8 1) (define SQLITE_DETERMINISTIC #x800) +(define SQLITE_DIRECTONLY #x80000) diff --git a/racket/collects/db/private/sqlite3/ffi.rkt b/racket/collects/db/private/sqlite3/ffi.rkt index cb0ac1d351..c5b0c4760b 100644 --- a/racket/collects/db/private/sqlite3/ffi.rkt +++ b/racket/collects/db/private/sqlite3/ffi.rkt @@ -277,27 +277,39 @@ (sqlite3_value_bytes v))] [else (error '_sqlite3_value* "cannot convert: ~e (type = ~s)" v type)])))) -(define-sqlite sqlite3_create_function_v2 - (_fun _sqlite3_database - _string/utf-8 - _int - (_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC)) - (_pointer = #f) - (_fun _sqlite3_context _int _pointer -> _void) - (_fpointer = #f) - (_fpointer = #f) - (_fpointer = #f) - -> _int)) +(define default-async-apply (lambda (p) (p))) -(define-sqlite sqlite3_create_aggregate - (_fun _sqlite3_database - _string/utf-8 - _int - (_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC)) +(define-sqlite sqlite3_create_function_v2/scalar + (_fun (db name arity flags proc) :: + (db : _sqlite3_database) + (name : _string/utf-8) + (arity : _int) + (_int = (bitwise-ior SQLITE_UTF8 + (if (memq 'direct-only flags) SQLITE_DIRECTONLY 0) + (if (memq 'deterministic flags) SQLITE_DETERMINISTIC 0))) + (_pointer = #f) + (proc : (_fun #:async-apply default-async-apply + _sqlite3_context _int _pointer -> _void)) + (_fpointer = #f) + (_fpointer = #f) + (_fpointer = #f) + -> _int) + #:c-id sqlite3_create_function_v2) + +(define-sqlite sqlite3_create_function_v2/aggregate + (_fun (db name arity flags step final) :: + (db : _sqlite3_database) + (name : _string/utf-8) + (arity : _int) + (_int = (bitwise-ior SQLITE_UTF8 + (if (memq 'direct-only flags) SQLITE_DIRECTONLY 0) + (if (memq 'deterministic flags) SQLITE_DETERMINISTIC 0))) (_pointer = #f) (_fpointer = #f) - (_fun _sqlite3_context _int _pointer -> _void) - (_fun _sqlite3_context -> _void) + (step : (_fun #:async-apply default-async-apply + _sqlite3_context _int _pointer -> _void)) + (final : (_fun #:async-apply default-async-apply + _sqlite3_context -> _void)) (_fpointer = #f) -> _int) #:c-id sqlite3_create_function_v2) @@ -334,20 +346,26 @@ ;; state. The connection object is responsible for preventing the ;; closure from being prematurely collected. +;; An aggbox is (box (U aggerror Any)); aggerror indicates that +;; sqlite3_result_error has already been called to report an error. +(define aggerror (gensym 'error)) + (define ((wrap-agg-step who proc aggbox agginit) ctx argc argp) (define args (get-args argc argp)) (define aggctx (sqlite3_aggregate_context ctx 1)) (when (zero? (ptr-ref aggctx _byte)) (set-box! aggbox agginit) (ptr-set! aggctx _byte 1)) - (set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args)))) - (sqlite3_result* ctx 0)) + (unless (eq? (unbox aggbox) aggerror) + (set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args)))))) (define ((wrap-agg-final who proc aggbox agginit) ctx) (define aggctx (sqlite3_aggregate_context ctx 1)) - (define r (call/wrap who ctx (lambda () (proc (unbox aggbox))))) - (set-box! aggbox agginit) - (sqlite3_result* ctx r)) + (unless (eq? (unbox aggbox) aggerror) + (define r (call/wrap who ctx (lambda () (proc (unbox aggbox))))) + (set-box! aggbox #f) + (unless (eq? r aggerror) + (sqlite3_result* ctx r)))) (define (call/wrap who ctx proc) (with-handlers @@ -358,7 +376,8 @@ who (cond [(exn? e) (exn-message e)] [else (format "caught non-exception\n caught: ~e" e)]))) - (sqlite3_result_error ctx err))]) + (sqlite3_result_error ctx err) + aggerror)]) (call-with-continuation-barrier proc))) (define (get-args argc argp)