Skip to content

Commit 286afc1

Browse files
committed
Add exact distance tests
1 parent 01dd1ee commit 286afc1

1 file changed

Lines changed: 328 additions & 0 deletions

File tree

test/test_vector.c

Lines changed: 328 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include <stdlib.h>
1111
#include <string.h>
1212
#include <strings.h>
13+
#include <math.h>
1314
#include "sqlite3.h"
1415
#include "sqlite-vector.h"
1516

@@ -73,13 +74,15 @@ static int setup_table(sqlite3 *db, const char *tbl, const char *type,
7374

7475
/* Accumulator handed to scan_cb as its context while a result set is walked. */
typedef struct {
    int count;             /* number of rows reported to the callback */
    int ids[64];           /* first result column, parsed as an int, per stored row */
    double distances[64];  /* second result column, parsed as a double, per stored row */
} scan_result;
7880

7981
static int scan_cb(void *ctx, int ncols, char **vals, char **names) {
8082
(void)names;
8183
scan_result *r = (scan_result *)ctx;
8284
if (r->count < 64 && ncols >= 2 && vals[1]) {
85+
r->ids[r->count] = vals[0] ? atoi(vals[0]) : 0;
8386
r->distances[r->count] = atof(vals[1]);
8487
}
8588
r->count++;
@@ -311,6 +314,321 @@ static const char *bit_vecs[] = {
311314
static const int bit_nvecs = 10;
312315
static const char *bit_query = "[1, 0, 1, 0, 1, 0, 1, 0]";
313316

317+
/* ---------- Test: distance function values ---------- */
318+
319+
/* Expected results for one distance metric over the float test vectors. */
typedef struct {
    const char *distance_name;  /* metric name passed to setup_table, e.g. "L2" */
    double eps_f32;             /* absolute tolerance used when vtype is f32 (the fallback) */
    double eps_f16;             /* absolute tolerance used when vtype is f16 */
    double eps_bf16;            /* absolute tolerance used when vtype is bf16 */
    double expected[10];        /* expected distance for each row, indexed by id - 1 */
} expected_distance_case;
326+
327+
/*
 * Fixtures for the float distance tests: ten 4-element vectors plus one query,
 * all written as text literals. Row i (0-based) is compared against
 * expected[i] of each test case.
 */
static const char *distance_vecs[] = {
    "[1.0, 2.0, 0.0, -1.0]",
    "[0.5, -1.5, 2.0, 1.0]",
    "[-2.0, 0.0, 1.0, 0.5]",
    "[3.0, 1.0, -1.0, 2.0]",
    "[-0.5, 2.5, 1.5, -2.0]",
    "[1.5, 1.5, 1.5, 1.5]",
    "[-1.0, -2.0, 0.5, 3.0]",
    "[2.0, -0.5, -2.5, 0.0]",
    "[0.0, 3.0, -1.0, -1.5]",
    "[-1.5, 0.5, 2.5, -0.5]"
};
static const int distance_nvecs = 10;  /* keep in sync with the table above */
static const char *distance_query = "[0.75, -0.25, 1.25, -0.75]";

/* Fixtures for the integer distance tests; same layout as the float set. */
static const char *distance_int_vecs[] = {
    "[10, 2, 0, 7]",
    "[3, 14, 9, 1]",
    "[20, 5, 4, 12]",
    "[8, 8, 8, 8]",
    "[1, 0, 15, 6]",
    "[12, 18, 2, 4]",
    "[6, 3, 11, 19]",
    "[16, 7, 13, 5]",
    "[4, 20, 1, 10]",
    "[9, 11, 6, 14]"
};
static const int distance_int_nvecs = 10;  /* keep in sync with the table above */
static const char *distance_int_query = "[7, 9, 5, 11]";
355+
356+
static double eps_for_type(const expected_distance_case *tc, const char *vtype) {
357+
if (strcasecmp(vtype, "f16") == 0) return tc->eps_f16;
358+
if (strcasecmp(vtype, "bf16") == 0) return tc->eps_bf16;
359+
return tc->eps_f32;
360+
}
361+
362+
/*
 * Run one (storage type, metric) combination of the float distance tests.
 *
 * Creates a table from distance_vecs via setup_table, runs vector_full_scan
 * against distance_query ordered by id, then checks that row i has id i + 1
 * and a distance within the type-specific tolerance of tc->expected[i].
 *
 * db    - database handle handed to setup_table and sqlite3_exec
 * vtype - storage type name; also spliced into the table name and into the
 *         vector_as_<vtype>() call of the query
 * tc    - metric name, tolerances and expected distances
 *
 * Results are reported through ASSERT. The function returns early when setup
 * fails, when the query fails, or when the row count is not distance_nvecs.
 */
static void test_one_distance_case(sqlite3 *db, const char *vtype, const expected_distance_case *tc) {
    char tbl[64];
    char sql[1024];
    char msg[256];
    double eps = eps_for_type(tc, vtype);

    /* Table name is "tdist_<metric>_<type>" with ASCII A-Z folded to lower case. */
    snprintf(tbl, sizeof(tbl), "tdist_%s_%s", tc->distance_name, vtype);
    for (char *p = tbl; *p; p++) {
        if (*p >= 'A' && *p <= 'Z') {
            *p = (char)(*p + ('a' - 'A'));
        }
    }

    /* The literal 4 matches the element count of each fixture vector. */
    if (setup_table(db, tbl, vtype, tc->distance_name, 4, distance_vecs, distance_nvecs) != 0) {
        snprintf(msg, sizeof(msg), "%s/%s distance setup", vtype, tc->distance_name);
        ASSERT(0, msg);
        return;
    }

    scan_result r = {0};
    snprintf(sql, sizeof(sql),
             "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_%s('%s')) ORDER BY id;",
             tbl, vtype, distance_query);
    char *err = NULL;
    int rc = sqlite3_exec(db, sql, scan_cb, &r, &err);
    snprintf(msg, sizeof(msg), "%s/%s distance query executes", vtype, tc->distance_name);
    ASSERT(rc == SQLITE_OK, msg);
    if (err) {
        printf(" err: %s\n", err);
        sqlite3_free(err);
    }
    if (rc != SQLITE_OK) {
        return;
    }

    snprintf(msg, sizeof(msg), "%s/%s distance query returns all rows", vtype, tc->distance_name);
    ASSERT(r.count == distance_nvecs, msg);
    if (r.count != distance_nvecs) {
        return;
    }

    for (int i = 0; i < distance_nvecs; i++) {
        /* Rows come back ordered by id, so position i must carry id i + 1. */
        int id_ok = (r.ids[i] == (i + 1));
        snprintf(msg, sizeof(msg), "%s/%s row id matches expected (row %d)",
                 vtype, tc->distance_name, i + 1);
        ASSERT(id_ok, msg);

        double diff = fabs(r.distances[i] - tc->expected[i]);
        int within_eps = diff <= eps;
        snprintf(msg, sizeof(msg), "%s/%s distance within epsilon (id=%d, diff=%.8g, eps=%.3g)",
                 vtype, tc->distance_name, i + 1, diff, eps);
        ASSERT(within_eps, msg);
    }
}
405+
406+
/*
 * Check the L2, SQUARED_L2, COSINE, DOT and L1 distances reported for the
 * float fixtures, once for each of the f32, f16 and bf16 storage types.
 *
 * Each case carries its own per-type tolerance and the ten expected values,
 * one per fixture row.
 */
static void test_distance_functions_float(sqlite3 *db) {
    const expected_distance_case cases[] = {
        {
            .distance_name = "L2",
            .eps_f32 = 1e-6,
            .eps_f16 = 1e-2,
            .eps_bf16 = 5e-2,
            .expected = {
                2.598076211353316,
                2.291287847477920,
                3.041381265149110,
                4.387482193696061,
                3.278719262151000,
                2.958039891549808,
                4.555216789572150,
                4.031128874149275,
                4.092676385936225,
                2.692582403567252
            }
        },
        {
            .distance_name = "SQUARED_L2",
            .eps_f32 = 1e-6,
            .eps_f16 = 5e-2,
            .eps_bf16 = 2e-1,
            .expected = {6.75, 5.25, 9.25, 19.25, 10.75, 8.75, 20.75, 16.25, 16.75, 7.25}
        },
        {
            .distance_name = "COSINE",
            .eps_f32 = 1e-5,
            .eps_f16 = 1e-2,
            .eps_bf16 = 5e-2,
            .expected = {
                0.753817018041334,
                0.449518117436820,
                1.164487923739942,
                1.116774841624228,
                0.598909685625288,
                0.698488655422236,
                1.299521148936577,
                1.279145263119541,
                1.150755672288882,
                0.547732983133355
            }
        },
        {
            .distance_name = "DOT",
            .eps_f32 = 1e-6,
            .eps_f16 = 1e-2,
            .eps_bf16 = 5e-2,
            .expected = {-1.0, -2.5, 0.625, 0.75, -2.375, -1.5, 1.875, 1.5, 0.875, -2.25}
        },
        {
            .distance_name = "L1",
            .eps_f32 = 1e-6,
            .eps_f16 = 1e-2,
            .eps_bf16 = 5e-2,
            .expected = {4.0, 4.0, 4.5, 8.5, 5.5, 5.0, 8.0, 6.0, 7.0, 4.5}
        }
    };
    const int ncases = (int)(sizeof(cases) / sizeof(cases[0]));
    const char *types[] = {"f32", "f16", "bf16"};
    /* Derived from the array so adding a type cannot leave the loop bound stale. */
    const int ntypes = (int)(sizeof(types) / sizeof(types[0]));

    for (int t = 0; t < ntypes; t++) {
        for (int i = 0; i < ncases; i++) {
            test_one_distance_case(db, types[t], &cases[i]);
        }
    }
}
475+
476+
/* Expected results for one distance metric over the integer test vectors. */
typedef struct {
    const char *distance_name;  /* metric name passed to setup_table, e.g. "L2" */
    double eps_i8;              /* absolute tolerance used when vtype is i8 */
    double eps_u8;              /* absolute tolerance used for any other vtype (u8) */
    double expected[10];        /* expected distance for each row, indexed by id - 1 */
} expected_int_distance_case;
482+
483+
static double eps_for_int_type(const expected_int_distance_case *tc, const char *vtype) {
484+
if (strcasecmp(vtype, "i8") == 0) return tc->eps_i8;
485+
return tc->eps_u8;
486+
}
487+
488+
/*
 * Run one (storage type, metric) combination of the integer distance tests.
 *
 * Creates a table from distance_int_vecs via setup_table, runs
 * vector_full_scan against distance_int_query ordered by id, then checks that
 * row i has id i + 1 and a distance within the type-specific tolerance of
 * tc->expected[i].
 *
 * db    - database handle handed to setup_table and sqlite3_exec
 * vtype - storage type name; also spliced into the table name and into the
 *         vector_as_<vtype>() call of the query
 * tc    - metric name, tolerances and expected distances
 *
 * Results are reported through ASSERT. The function returns early when setup
 * fails, when the query fails, or when the row count is not
 * distance_int_nvecs.
 */
static void test_one_int_distance_case(sqlite3 *db, const char *vtype, const expected_int_distance_case *tc) {
    char tbl[64];
    char sql[1024];
    char msg[256];
    double eps = eps_for_int_type(tc, vtype);

    /* Table name is "tdist_<metric>_<type>" with ASCII A-Z folded to lower case. */
    snprintf(tbl, sizeof(tbl), "tdist_%s_%s", tc->distance_name, vtype);
    for (char *p = tbl; *p; p++) {
        if (*p >= 'A' && *p <= 'Z') {
            *p = (char)(*p + ('a' - 'A'));
        }
    }

    /* The literal 4 matches the element count of each fixture vector. */
    if (setup_table(db, tbl, vtype, tc->distance_name, 4, distance_int_vecs, distance_int_nvecs) != 0) {
        snprintf(msg, sizeof(msg), "%s/%s int distance setup", vtype, tc->distance_name);
        ASSERT(0, msg);
        return;
    }

    scan_result r = {0};
    snprintf(sql, sizeof(sql),
             "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_%s('%s')) ORDER BY id;",
             tbl, vtype, distance_int_query);
    char *err = NULL;
    int rc = sqlite3_exec(db, sql, scan_cb, &r, &err);
    snprintf(msg, sizeof(msg), "%s/%s int distance query executes", vtype, tc->distance_name);
    ASSERT(rc == SQLITE_OK, msg);
    if (err) {
        printf(" err: %s\n", err);
        sqlite3_free(err);
    }
    if (rc != SQLITE_OK) {
        return;
    }

    snprintf(msg, sizeof(msg), "%s/%s int distance query returns all rows", vtype, tc->distance_name);
    ASSERT(r.count == distance_int_nvecs, msg);
    if (r.count != distance_int_nvecs) {
        return;
    }

    for (int i = 0; i < distance_int_nvecs; i++) {
        /* Rows come back ordered by id, so position i must carry id i + 1. */
        int id_ok = (r.ids[i] == (i + 1));
        snprintf(msg, sizeof(msg), "%s/%s int row id matches expected (row %d)",
                 vtype, tc->distance_name, i + 1);
        ASSERT(id_ok, msg);

        double diff = fabs(r.distances[i] - tc->expected[i]);
        int within_eps = diff <= eps;
        snprintf(msg, sizeof(msg), "%s/%s int distance within epsilon (id=%d, diff=%.8g, eps=%.3g)",
                 vtype, tc->distance_name, i + 1, diff, eps);
        ASSERT(within_eps, msg);
    }
}
531+
532+
/*
 * Check the L2, SQUARED_L2, COSINE, DOT and L1 distances reported for the
 * integer fixtures, once for each of the i8 and u8 storage types.
 *
 * Each case carries its own per-type tolerance and the ten expected values,
 * one per fixture row.
 */
static void test_distance_functions_int(sqlite3 *db) {
    const expected_int_distance_case cases[] = {
        {
            .distance_name = "L2",
            .eps_i8 = 1e-6,
            .eps_u8 = 1e-6,
            .expected = {
                9.949874371066199,
                12.529964086141668,
                13.674794331177344,
                4.472135954999580,
                15.556349186104045,
                12.806248474865697,
                11.704699910719626,
                13.601470508735444,
                12.124355652982141,
                4.242640687119285
            }
        },
        {
            .distance_name = "SQUARED_L2",
            .eps_i8 = 1e-6,
            .eps_u8 = 1e-6,
            .expected = {99.0, 157.0, 187.0, 20.0, 242.0, 164.0, 137.0, 185.0, 147.0, 18.0}
        },
        {
            .distance_name = "COSINE",
            .eps_i8 = 1e-6,
            .eps_u8 = 1e-6,
            .expected = {
                0.197058901598547,
                0.278725549597720,
                0.161317797973194,
                0.036913175313846,
                0.449627749704491,
                0.182558273343614,
                0.126858993881120,
                0.205091387999948,
                0.144927951966812,
                0.000283884548207
            }
        },
        {
            .distance_name = "DOT",
            .eps_i8 = 1e-6,
            .eps_u8 = 1e-6,
            .expected = {-165.0, -203.0, -337.0, -256.0, -148.0, -300.0, -333.0, -295.0, -323.0, -346.0}
        },
        {
            .distance_name = "L1",
            .eps_i8 = 1e-6,
            .eps_u8 = 1e-6,
            .expected = {19.0, 23.0, 19.0, 8.0, 30.0, 24.0, 21.0, 25.0, 19.0, 8.0}
        }
    };
    const int ncases = (int)(sizeof(cases) / sizeof(cases[0]));
    const char *types[] = {"i8", "u8"};
    /* Derived from the array so adding a type cannot leave the loop bound stale. */
    const int ntypes = (int)(sizeof(types) / sizeof(types[0]));

    for (int t = 0; t < ntypes; t++) {
        for (int i = 0; i < ncases; i++) {
            test_one_int_distance_case(db, types[t], &cases[i]);
        }
    }
}
596+
597+
/*
 * Check the HAMMING distances reported for the bit fixtures.
 *
 * Creates table "tdist_hamming_bit" from bit_vecs via setup_table, runs
 * vector_full_scan against bit_query ordered by id, then checks that row i
 * has id i + 1 and a distance exactly equal to expected[i].
 *
 * Results are reported through ASSERT. The function returns early when setup
 * fails, when the query fails, or when the row count is not bit_nvecs.
 */
static void test_distance_functions_hamming(sqlite3 *db) {
    const char *tbl = "tdist_hamming_bit";
    const double expected[] = {3.0, 5.0, 3.0, 5.0, 4.0, 4.0, 4.0, 3.0, 5.0, 4.0};
    char sql[1024];
    char msg[256];

    if (setup_table(db, tbl, "bit", "HAMMING", 8, bit_vecs, bit_nvecs) != 0) {
        ASSERT(0, "bit/HAMMING distance setup");
        return;
    }

    scan_result r = {0};
    snprintf(sql, sizeof(sql),
             "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_bit('%s')) ORDER BY id;",
             tbl, bit_query);
    char *err = NULL;
    int rc = sqlite3_exec(db, sql, scan_cb, &r, &err);
    ASSERT(rc == SQLITE_OK, "bit/HAMMING distance query executes");
    if (err) {
        printf(" err: %s\n", err);
        sqlite3_free(err);
    }
    if (rc != SQLITE_OK) {
        return;
    }

    ASSERT(r.count == bit_nvecs, "bit/HAMMING distance query returns all rows");
    if (r.count != bit_nvecs) {
        return;
    }

    for (int i = 0; i < bit_nvecs; i++) {
        /* Rows come back ordered by id, so position i must carry id i + 1. */
        int id_ok = (r.ids[i] == (i + 1));
        snprintf(msg, sizeof(msg), "bit/HAMMING row id matches expected (row %d)", i + 1);
        ASSERT(id_ok, msg);

        /* The expected values are whole numbers, so no tolerance is applied. */
        int exact_match = (r.distances[i] == expected[i]);
        snprintf(msg, sizeof(msg), "bit/HAMMING distance matches exactly (id=%d)", i + 1);
        ASSERT(exact_match, msg);
    }
}
631+
314632
/* ---------- Main ---------- */
315633

316634
int main(void) {
@@ -426,6 +744,16 @@ int main(void) {
426744
}
427745
}
428746

747+
748+
/* 5. distance functions */
749+
printf("\n=== distance_functions ===\n");
750+
{
751+
test_distance_functions_float(db);
752+
test_distance_functions_int(db);
753+
test_distance_functions_hamming(db);
754+
}
755+
756+
429757
sqlite3_close(db);
430758

431759
/* Summary */

0 commit comments

Comments
 (0)