diff --git a/racket/collects/db/private/sqlite3/connection.rkt b/racket/collects/db/private/sqlite3/connection.rkt index 9ecd7a4632..3caca40f19 100644 --- a/racket/collects/db/private/sqlite3/connection.rkt +++ b/racket/collects/db/private/sqlite3/connection.rkt @@ -12,9 +12,13 @@ "dbsystem.rkt") (provide connection% handle-status* - (protect-out unsafe-load-extension)) + (protect-out unsafe-load-extension + unsafe-create-function + unsafe-create-aggregate)) (define-local-member-name unsafe-load-extension) +(define-local-member-name unsafe-create-function) +(define-local-member-name unsafe-create-aggregate) ;; == Connection @@ -350,7 +354,7 @@ (set! name-counter (add1 name-counter)) (format "λmz_~a" n))) - ;; Reflection + ;; == Reflection (define/public (list-tables fsym schema) (let ([stmt @@ -361,7 +365,7 @@ (for/list ([row (in-list (rows-result-rows result))]) (vector-ref row 0))))) - ;; ---- + ;; == Load Extension (define/public (unsafe-load-extension who lib) (define lib-path (cleanse-path (path->complete-path lib))) @@ -373,7 +377,29 @@ (HANDLE who (A (sqlite3_enable_load_extension -db 0))) (void)))) - ;; ---- + ;; == Create Function + + (define dont-gc null) + + (define/public (unsafe-create-function who name arity proc) + (define wrapped (wrap-fun name proc)) + (call-with-lock who + (lambda () + (set! dont-gc (cons wrapped dont-gc)) + (HANDLE who (A (sqlite3_create_function_v2 -db name (or arity -1) wrapped)))))) + + (define/public (unsafe-create-aggregate who name arity step final [init #f]) + (define aggbox (box init)) + (define wrapped-step (wrap-agg-step name step aggbox init)) + (define wrapped-final (wrap-agg-final name final aggbox init)) + (call-with-lock who + (lambda () + (set! dont-gc (list* wrapped-step wrapped-final dont-gc)) + (HANDLE who + (A (sqlite3_create_aggregate -db name (or arity -1) + wrapped-step wrapped-final)))))) + + ;; == Error handling (define-syntax HANDLE (syntax-rules () diff --git a/racket/collects/db/private/sqlite3/ffi-constants.rkt b/racket/collects/db/private/sqlite3/ffi-constants.rkt index 59f0465476..2639de140f 100644 --- a/racket/collects/db/private/sqlite3/ffi-constants.rkt +++ b/racket/collects/db/private/sqlite3/ffi-constants.rkt @@ -57,3 +57,8 @@ (define SQLITE_OPEN_FULLMUTEX #x00010000) (define SQLITE_OPEN_SHAREDCACHE #x00020000) (define SQLITE_OPEN_PRIVATECACHE #x00040000) + +;; Create function + +(define SQLITE_UTF8 1) +(define SQLITE_DETERMINISTIC #x800) diff --git a/racket/collects/db/private/sqlite3/ffi.rkt b/racket/collects/db/private/sqlite3/ffi.rkt index 29831d59e0..8c29d2c887 100644 --- a/racket/collects/db/private/sqlite3/ffi.rkt +++ b/racket/collects/db/private/sqlite3/ffi.rkt @@ -238,3 +238,129 @@ ;; FIXME: handle error string? (_fun _sqlite3_database _path (_pointer = #f) (_pointer = #f) -> _int)) + +;; ---------------------------------------- + +(define-cpointer-type _sqlite3_context) +(define-cpointer-type _sqlite3_value) + +(define-sqlite sqlite3_value_type (_fun _sqlite3_value -> _int)) +(define-sqlite sqlite3_value_double (_fun _sqlite3_value -> _double)) +(define-sqlite sqlite3_value_int64 (_fun _sqlite3_value -> _int64)) +(define-sqlite sqlite3_value_bytes (_fun _sqlite3_value -> _int)) +(define-sqlite sqlite3_value_blob (_fun _sqlite3_value -> _pointer)) +(define-sqlite sqlite3_value_text (_fun _sqlite3_value -> _pointer)) + +(define-ffi-definer define-rkt #f) +(define-rkt scheme_make_sized_utf8_string (_fun _pointer _intptr -> _racket)) +(define-rkt scheme_make_sized_byte_string (_fun _pointer _intptr -> _racket)) + +(define _sqlite3_value* + (make-ctype _sqlite3_value + #f + (lambda (v) + (define type (sqlite3_value_type v)) + (cond [(= type SQLITE_INTEGER) (sqlite3_value_int64 v)] + [(= type SQLITE_FLOAT) (sqlite3_value_double v)] + [(= type SQLITE_TEXT) + (scheme_make_sized_utf8_string (sqlite3_value_text v) + (sqlite3_value_bytes v))] + [(= type SQLITE_BLOB) + (scheme_make_sized_byte_string (sqlite3_value_blob v) + (sqlite3_value_bytes v))] + [else (error '_sqlite3_value* "cannot convert: ~e (type = ~s)" v type)])))) + +(define-sqlite sqlite3_create_function_v2 + (_fun _sqlite3_database + _string/utf-8 + _int + (_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC)) + (_pointer = #f) + (_fun _sqlite3_context _int _pointer -> _void) + (_fpointer = #f) + (_fpointer = #f) + (_fpointer = #f) + -> _int)) + +(define-sqlite sqlite3_create_aggregate + (_fun _sqlite3_database + _string/utf-8 + _int + (_int = (+ SQLITE_UTF8 SQLITE_DETERMINISTIC)) + (_pointer = #f) + (_fpointer = #f) + (_fun _sqlite3_context _int _pointer -> _void) + (_fun _sqlite3_context -> _void) + (_fpointer = #f) + -> _int) + #:c-id sqlite3_create_function_v2) + +(define-sqlite sqlite3_aggregate_context + (_fun _sqlite3_context _int -> _pointer)) + +(define-sqlite sqlite3_result_null (_fun _sqlite3_context -> _void)) +(define-sqlite sqlite3_result_int64 (_fun _sqlite3_context _int64 -> _void)) +(define-sqlite sqlite3_result_double (_fun _sqlite3_context _double* -> _void)) +(define-sqlite sqlite3_result_blob + (_fun _sqlite3_context + (buf : _bytes) + (_int = (bytes-length buf)) + (_intptr = SQLITE_TRANSIENT) + -> _void)) +(define-sqlite sqlite3_result_text + (_fun _sqlite3_context + (buf : _string/utf-8) + (_int = (string-utf-8-length buf)) + (_intptr = SQLITE_TRANSIENT) + -> _void)) +(define-sqlite sqlite3_result_error + (_fun _sqlite3_context (s : _string/utf-8) (_int = (string-utf-8-length s)) -> _void)) + +(define ((wrap-fun who proc) ctx argc argp) + (define args (get-args argc argp)) + (call/wrap who ctx (lambda () (sqlite3_result* ctx (apply proc args))))) + +;; sqlite3 supports an "aggregate context" for storing aggregate +;; state, but it's hidden from Racket's GC. So instead we make a +;; closure with Racket-visible state and use sqlite's aggregate +;; context just to tell us whether we need to reset the Racket-level +;; state. The connection object is responsible for preventing the +;; closure from being prematurely collected. + +(define ((wrap-agg-step who proc aggbox agginit) ctx argc argp) + (define args (get-args argc argp)) + (define aggctx (sqlite3_aggregate_context ctx 1)) + (when (zero? (ptr-ref aggctx _byte)) + (set-box! aggbox agginit) + (ptr-set! aggctx _byte 1)) + (set-box! aggbox (call/wrap who ctx (lambda () (apply proc (unbox aggbox) args)))) + (sqlite3_result* ctx 0)) + +(define ((wrap-agg-final who proc aggbox agginit) ctx) + (define aggctx (sqlite3_aggregate_context ctx 1)) + (define r (call/wrap who ctx (lambda () (proc (unbox aggbox))))) + (set-box! aggbox agginit) + (sqlite3_result* ctx r)) + +(define (call/wrap who ctx proc) + (with-handlers + ([(lambda (e) #t) + (lambda (e) + (define err + (format "[racket:~a] ~a" + who + (cond [(exn? e) (exn-message e)] + [else (format "caught non-exception\n caught: ~e" e)]))) + (sqlite3_result_error ctx err))]) + (call-with-continuation-barrier proc))) + +(define (get-args argc argp) + (for/list ([i (in-range argc)]) + (ptr-ref argp _sqlite3_value* i))) + +(define (sqlite3_result* ctx r) + (cond [(fixnum? r) (sqlite3_result_int64 ctx r)] ;; FIXME: fixnum -> int64 + [(real? r) (sqlite3_result_double ctx r)] + [(string? r) (sqlite3_result_text ctx r)] + [(bytes? r) (sqlite3_result_blob ctx r)] + [else (sqlite3_result_error ctx (format "bad result: ~e" r))]))