Skip to content

Commit 359e768

Browse files
committed
Added xdata to dbmem_provider_t
1 parent a308dd3 commit 359e768

3 files changed

Lines changed: 21 additions & 13 deletions

File tree

src/sqlite-memory.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ static void dbmem_context_free (void *ptr) {
571571
if (ctx->cache_buffer) dbmem_free(ctx->cache_buffer);
572572

573573
// custom provider
574-
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
574+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine, ctx->custom_provider.xdata);
575575
if (ctx->custom_provider_name) dbmem_free(ctx->custom_provider_name);
576576

577577
#ifndef DBMEM_OMIT_LOCAL_ENGINE
@@ -608,7 +608,7 @@ bool dbmem_context_is_custom (dbmem_context *ctx) {
608608

609609
int dbmem_context_custom_compute (dbmem_context *ctx, const char *text, int text_len, embedding_result_t *result) {
610610
dbmem_embedding_result_t cr = {0};
611-
int rc = ctx->custom_provider.compute(ctx->custom_engine, text, text_len, &cr);
611+
int rc = ctx->custom_provider.compute(ctx->custom_engine, text, text_len, ctx->custom_provider.xdata, &cr);
612612
if (rc != 0) return rc;
613613
result->n_tokens = cr.n_tokens;
614614
result->n_tokens_truncated = cr.n_tokens_truncated;
@@ -950,10 +950,10 @@ static void dbmem_set_model (sqlite3_context *context, int argc, sqlite3_value *
950950
// custom provider path
951951
if (is_custom_provider) {
952952
// free previous custom engine if any
953-
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
953+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine, ctx->custom_provider.xdata);
954954
ctx->custom_engine = NULL;
955955

956-
ctx->custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->error_msg);
956+
ctx->custom_engine = ctx->custom_provider.init(model, ctx->api_key, ctx->custom_provider.xdata, ctx->error_msg);
957957
if (ctx->custom_engine == NULL) {
958958
sqlite3_result_error(context, ctx->error_msg, -1);
959959
return;
@@ -1562,7 +1562,7 @@ SQLITE_DBMEMORY_API int sqlite3_memory_register_provider (sqlite3 *db, const cha
15621562
if (!ctx) return SQLITE_ERROR;
15631563

15641564
// free previous custom provider if any
1565-
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine);
1565+
if (ctx->custom_engine && ctx->custom_provider.free) ctx->custom_provider.free(ctx->custom_engine, ctx->custom_provider.xdata);
15661566
ctx->custom_engine = NULL;
15671567
if (ctx->custom_provider_name) dbmem_free(ctx->custom_provider_name);
15681568

src/sqlite-memory.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
extern "C" {
2727
#endif
2828

29-
#define SQLITE_DBMEMORY_VERSION "0.8.0"
29+
#define SQLITE_DBMEMORY_VERSION "0.8.1"
3030

3131
// public API
3232
SQLITE_DBMEMORY_API int sqlite3_memory_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
@@ -47,14 +47,18 @@ typedef struct {
4747
typedef struct {
4848
// Called when memory_set_model(provider, model) matches this provider.
4949
// api_key is the value set via memory_set_apikey() (may be NULL).
50+
// xdata is the user-supplied generic pointer from the struct.
5051
// Return opaque engine pointer, or NULL on error (fill err_msg).
51-
void *(*init)(const char *model, const char *api_key, char err_msg[1024]);
52+
void *(*init)(const char *model, const char *api_key, void *xdata, char err_msg[1024]);
5253

5354
// Compute embedding for text. Return 0 on success, non-zero on error.
54-
int (*compute)(void *engine, const char *text, int text_len, dbmem_embedding_result_t *result);
55+
int (*compute)(void *engine, const char *text, int text_len, void *xdata, dbmem_embedding_result_t *result);
5556

5657
// Free the engine. Called on context teardown or model change. May be NULL.
57-
void (*free)(void *engine);
58+
void (*free)(void *engine, void *xdata);
59+
60+
// User-supplied generic data pointer, passed to all callbacks.
61+
void *xdata;
5862
} dbmem_provider_t;
5963

6064
// Register a custom embedding provider.

test/unittest.c

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,8 +2231,9 @@ typedef struct {
22312231
char api_key[256];
22322232
} dummy_engine_t;
22332233

2234-
static void *dummy_init(const char *model, const char *api_key, char err_msg[1024]) {
2234+
static void *dummy_init(const char *model, const char *api_key, void *xdata, char err_msg[1024]) {
22352235
UNUSED_PARAM(model);
2236+
UNUSED_PARAM(xdata);
22362237
dummy_engine_t *e = (dummy_engine_t *)calloc(1, sizeof(dummy_engine_t));
22372238
if (!e) { snprintf(err_msg, 1024, "alloc failed"); return NULL; }
22382239
e->dimension = 4;
@@ -2244,9 +2245,10 @@ static void *dummy_init(const char *model, const char *api_key, char err_msg[102
22442245
return e;
22452246
}
22462247

2247-
static int dummy_compute(void *engine, const char *text, int text_len, dbmem_embedding_result_t *result) {
2248+
static int dummy_compute(void *engine, const char *text, int text_len, void *xdata, dbmem_embedding_result_t *result) {
22482249
UNUSED_PARAM(text);
22492250
UNUSED_PARAM(text_len);
2251+
UNUSED_PARAM(xdata);
22502252
dummy_engine_t *e = (dummy_engine_t *)engine;
22512253
e->compute_count++;
22522254
result->n_tokens = text_len / 4;
@@ -2256,13 +2258,15 @@ static int dummy_compute(void *engine, const char *text, int text_len, dbmem_emb
22562258
return 0;
22572259
}
22582260

2259-
static void dummy_free(void *engine) {
2261+
static void dummy_free(void *engine, void *xdata) {
2262+
UNUSED_PARAM(xdata);
22602263
free(engine);
22612264
}
22622265

2263-
static void *dummy_init_fail(const char *model, const char *api_key, char err_msg[1024]) {
2266+
static void *dummy_init_fail(const char *model, const char *api_key, void *xdata, char err_msg[1024]) {
22642267
UNUSED_PARAM(model);
22652268
UNUSED_PARAM(api_key);
2269+
UNUSED_PARAM(xdata);
22662270
snprintf(err_msg, 1024, "intentional init failure");
22672271
return NULL;
22682272
}

0 commit comments

Comments
 (0)