|
10 | 10 | #include <stdlib.h> |
11 | 11 | #include <string.h> |
12 | 12 | #include <strings.h> |
| 13 | +#include <math.h> |
13 | 14 | #include "sqlite3.h" |
14 | 15 | #include "sqlite-vector.h" |
15 | 16 |
|
@@ -73,13 +74,15 @@ static int setup_table(sqlite3 *db, const char *tbl, const char *type, |
73 | 74 |
|
74 | 75 | typedef struct { |
75 | 76 | int count; |
| 77 | + int ids[64]; |
76 | 78 | double distances[64]; |
77 | 79 | } scan_result; |
78 | 80 |
|
79 | 81 | static int scan_cb(void *ctx, int ncols, char **vals, char **names) { |
80 | 82 | (void)names; |
81 | 83 | scan_result *r = (scan_result *)ctx; |
82 | 84 | if (r->count < 64 && ncols >= 2 && vals[1]) { |
| 85 | + r->ids[r->count] = vals[0] ? atoi(vals[0]) : 0; |
83 | 86 | r->distances[r->count] = atof(vals[1]); |
84 | 87 | } |
85 | 88 | r->count++; |
@@ -311,6 +314,321 @@ static const char *bit_vecs[] = { |
311 | 314 | static const int bit_nvecs = 10; |
312 | 315 | static const char *bit_query = "[1, 0, 1, 0, 1, 0, 1, 0]"; |
313 | 316 |
|
| 317 | +/* ---------- Test: distance function values ---------- */ |
| 318 | + |
| 319 | +typedef struct { |
| 320 | + const char *distance_name; |
| 321 | + double eps_f32; |
| 322 | + double eps_f16; |
| 323 | + double eps_bf16; |
| 324 | + double expected[10]; |
| 325 | +} expected_distance_case; |
| 326 | + |
| 327 | +static const char *distance_vecs[] = { |
| 328 | + "[1.0, 2.0, 0.0, -1.0]", |
| 329 | + "[0.5, -1.5, 2.0, 1.0]", |
| 330 | + "[-2.0, 0.0, 1.0, 0.5]", |
| 331 | + "[3.0, 1.0, -1.0, 2.0]", |
| 332 | + "[-0.5, 2.5, 1.5, -2.0]", |
| 333 | + "[1.5, 1.5, 1.5, 1.5]", |
| 334 | + "[-1.0, -2.0, 0.5, 3.0]", |
| 335 | + "[2.0, -0.5, -2.5, 0.0]", |
| 336 | + "[0.0, 3.0, -1.0, -1.5]", |
| 337 | + "[-1.5, 0.5, 2.5, -0.5]" |
| 338 | +}; |
| 339 | +static const int distance_nvecs = 10; |
| 340 | +static const char *distance_query = "[0.75, -0.25, 1.25, -0.75]"; |
| 341 | +static const char *distance_int_vecs[] = { |
| 342 | + "[10, 2, 0, 7]", |
| 343 | + "[3, 14, 9, 1]", |
| 344 | + "[20, 5, 4, 12]", |
| 345 | + "[8, 8, 8, 8]", |
| 346 | + "[1, 0, 15, 6]", |
| 347 | + "[12, 18, 2, 4]", |
| 348 | + "[6, 3, 11, 19]", |
| 349 | + "[16, 7, 13, 5]", |
| 350 | + "[4, 20, 1, 10]", |
| 351 | + "[9, 11, 6, 14]" |
| 352 | +}; |
| 353 | +static const int distance_int_nvecs = 10; |
| 354 | +static const char *distance_int_query = "[7, 9, 5, 11]"; |
| 355 | + |
| 356 | +static double eps_for_type(const expected_distance_case *tc, const char *vtype) { |
| 357 | + if (strcasecmp(vtype, "f16") == 0) return tc->eps_f16; |
| 358 | + if (strcasecmp(vtype, "bf16") == 0) return tc->eps_bf16; |
| 359 | + return tc->eps_f32; |
| 360 | +} |
| 361 | + |
| 362 | +static void test_one_distance_case(sqlite3 *db, const char *vtype, const expected_distance_case *tc) { |
| 363 | + char tbl[64]; |
| 364 | + char sql[1024]; |
| 365 | + char msg[256]; |
| 366 | + double eps = eps_for_type(tc, vtype); |
| 367 | + |
| 368 | + snprintf(tbl, sizeof(tbl), "tdist_%s_%s", tc->distance_name, vtype); |
| 369 | + for (char *p = tbl; *p; p++) if (*p >= 'A' && *p <= 'Z') *p += 32; |
| 370 | + |
| 371 | + if (setup_table(db, tbl, vtype, tc->distance_name, 4, distance_vecs, distance_nvecs) != 0) { |
| 372 | + snprintf(msg, sizeof(msg), "%s/%s distance setup", vtype, tc->distance_name); |
| 373 | + ASSERT(0, msg); |
| 374 | + return; |
| 375 | + } |
| 376 | + |
| 377 | + scan_result r = {0}; |
| 378 | + snprintf(sql, sizeof(sql), |
| 379 | + "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_%s('%s')) ORDER BY id;", |
| 380 | + tbl, vtype, distance_query); |
| 381 | + char *err = NULL; |
| 382 | + int rc = sqlite3_exec(db, sql, scan_cb, &r, &err); |
| 383 | + snprintf(msg, sizeof(msg), "%s/%s distance query executes", vtype, tc->distance_name); |
| 384 | + ASSERT(rc == SQLITE_OK, msg); |
| 385 | + if (err) { printf(" err: %s\n", err); sqlite3_free(err); } |
| 386 | + if (rc != SQLITE_OK) return; |
| 387 | + |
| 388 | + snprintf(msg, sizeof(msg), "%s/%s distance query returns all rows", vtype, tc->distance_name); |
| 389 | + ASSERT(r.count == distance_nvecs, msg); |
| 390 | + if (r.count != distance_nvecs) return; |
| 391 | + |
| 392 | + for (int i = 0; i < distance_nvecs; i++) { |
| 393 | + int id_ok = (r.ids[i] == (i + 1)); |
| 394 | + snprintf(msg, sizeof(msg), "%s/%s row id matches expected (row %d)", |
| 395 | + vtype, tc->distance_name, i + 1); |
| 396 | + ASSERT(id_ok, msg); |
| 397 | + |
| 398 | + double diff = fabs(r.distances[i] - tc->expected[i]); |
| 399 | + int within_eps = diff <= eps; |
| 400 | + snprintf(msg, sizeof(msg), "%s/%s distance within epsilon (id=%d, diff=%.8g, eps=%.3g)", |
| 401 | + vtype, tc->distance_name, i + 1, diff, eps); |
| 402 | + ASSERT(within_eps, msg); |
| 403 | + } |
| 404 | +} |
| 405 | + |
| 406 | +static void test_distance_functions_float(sqlite3 *db) { |
| 407 | + const expected_distance_case cases[] = { |
| 408 | + { |
| 409 | + .distance_name = "L2", |
| 410 | + .eps_f32 = 1e-6, |
| 411 | + .eps_f16 = 1e-2, |
| 412 | + .eps_bf16 = 5e-2, |
| 413 | + .expected = { |
| 414 | + 2.598076211353316, |
| 415 | + 2.291287847477920, |
| 416 | + 3.041381265149110, |
| 417 | + 4.387482193696061, |
| 418 | + 3.278719262151000, |
| 419 | + 2.958039891549808, |
| 420 | + 4.555216789572150, |
| 421 | + 4.031128874149275, |
| 422 | + 4.092676385936225, |
| 423 | + 2.692582403567252 |
| 424 | + } |
| 425 | + }, |
| 426 | + { |
| 427 | + .distance_name = "SQUARED_L2", |
| 428 | + .eps_f32 = 1e-6, |
| 429 | + .eps_f16 = 5e-2, |
| 430 | + .eps_bf16 = 2e-1, |
| 431 | + .expected = {6.75, 5.25, 9.25, 19.25, 10.75, 8.75, 20.75, 16.25, 16.75, 7.25} |
| 432 | + }, |
| 433 | + { |
| 434 | + .distance_name = "COSINE", |
| 435 | + .eps_f32 = 1e-5, |
| 436 | + .eps_f16 = 1e-2, |
| 437 | + .eps_bf16 = 5e-2, |
| 438 | + .expected = { |
| 439 | + 0.753817018041334, |
| 440 | + 0.449518117436820, |
| 441 | + 1.164487923739942, |
| 442 | + 1.116774841624228, |
| 443 | + 0.598909685625288, |
| 444 | + 0.698488655422236, |
| 445 | + 1.299521148936577, |
| 446 | + 1.279145263119541, |
| 447 | + 1.150755672288882, |
| 448 | + 0.547732983133355 |
| 449 | + } |
| 450 | + }, |
| 451 | + { |
| 452 | + .distance_name = "DOT", |
| 453 | + .eps_f32 = 1e-6, |
| 454 | + .eps_f16 = 1e-2, |
| 455 | + .eps_bf16 = 5e-2, |
| 456 | + .expected = {-1.0, -2.5, 0.625, 0.75, -2.375, -1.5, 1.875, 1.5, 0.875, -2.25} |
| 457 | + }, |
| 458 | + { |
| 459 | + .distance_name = "L1", |
| 460 | + .eps_f32 = 1e-6, |
| 461 | + .eps_f16 = 1e-2, |
| 462 | + .eps_bf16 = 5e-2, |
| 463 | + .expected = {4.0, 4.0, 4.5, 8.5, 5.5, 5.0, 8.0, 6.0, 7.0, 4.5} |
| 464 | + } |
| 465 | + }; |
| 466 | + const int ncases = (int)(sizeof(cases) / sizeof(cases[0])); |
| 467 | + const char *types[] = {"f32", "f16", "bf16"}; |
| 468 | + |
| 469 | + for (int t = 0; t < 3; t++) { |
| 470 | + for (int i = 0; i < ncases; i++) { |
| 471 | + test_one_distance_case(db, types[t], &cases[i]); |
| 472 | + } |
| 473 | + } |
| 474 | +} |
| 475 | + |
| 476 | +typedef struct { |
| 477 | + const char *distance_name; |
| 478 | + double eps_i8; |
| 479 | + double eps_u8; |
| 480 | + double expected[10]; |
| 481 | +} expected_int_distance_case; |
| 482 | + |
| 483 | +static double eps_for_int_type(const expected_int_distance_case *tc, const char *vtype) { |
| 484 | + if (strcasecmp(vtype, "i8") == 0) return tc->eps_i8; |
| 485 | + return tc->eps_u8; |
| 486 | +} |
| 487 | + |
| 488 | +static void test_one_int_distance_case(sqlite3 *db, const char *vtype, const expected_int_distance_case *tc) { |
| 489 | + char tbl[64]; |
| 490 | + char sql[1024]; |
| 491 | + char msg[256]; |
| 492 | + double eps = eps_for_int_type(tc, vtype); |
| 493 | + |
| 494 | + snprintf(tbl, sizeof(tbl), "tdist_%s_%s", tc->distance_name, vtype); |
| 495 | + for (char *p = tbl; *p; p++) if (*p >= 'A' && *p <= 'Z') *p += 32; |
| 496 | + |
| 497 | + if (setup_table(db, tbl, vtype, tc->distance_name, 4, distance_int_vecs, distance_int_nvecs) != 0) { |
| 498 | + snprintf(msg, sizeof(msg), "%s/%s int distance setup", vtype, tc->distance_name); |
| 499 | + ASSERT(0, msg); |
| 500 | + return; |
| 501 | + } |
| 502 | + |
| 503 | + scan_result r = {0}; |
| 504 | + snprintf(sql, sizeof(sql), |
| 505 | + "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_%s('%s')) ORDER BY id;", |
| 506 | + tbl, vtype, distance_int_query); |
| 507 | + char *err = NULL; |
| 508 | + int rc = sqlite3_exec(db, sql, scan_cb, &r, &err); |
| 509 | + snprintf(msg, sizeof(msg), "%s/%s int distance query executes", vtype, tc->distance_name); |
| 510 | + ASSERT(rc == SQLITE_OK, msg); |
| 511 | + if (err) { printf(" err: %s\n", err); sqlite3_free(err); } |
| 512 | + if (rc != SQLITE_OK) return; |
| 513 | + |
| 514 | + snprintf(msg, sizeof(msg), "%s/%s int distance query returns all rows", vtype, tc->distance_name); |
| 515 | + ASSERT(r.count == distance_int_nvecs, msg); |
| 516 | + if (r.count != distance_int_nvecs) return; |
| 517 | + |
| 518 | + for (int i = 0; i < distance_int_nvecs; i++) { |
| 519 | + int id_ok = (r.ids[i] == (i + 1)); |
| 520 | + snprintf(msg, sizeof(msg), "%s/%s int row id matches expected (row %d)", |
| 521 | + vtype, tc->distance_name, i + 1); |
| 522 | + ASSERT(id_ok, msg); |
| 523 | + |
| 524 | + double diff = fabs(r.distances[i] - tc->expected[i]); |
| 525 | + int within_eps = diff <= eps; |
| 526 | + snprintf(msg, sizeof(msg), "%s/%s int distance within epsilon (id=%d, diff=%.8g, eps=%.3g)", |
| 527 | + vtype, tc->distance_name, i + 1, diff, eps); |
| 528 | + ASSERT(within_eps, msg); |
| 529 | + } |
| 530 | +} |
| 531 | + |
| 532 | +static void test_distance_functions_int(sqlite3 *db) { |
| 533 | + const expected_int_distance_case cases[] = { |
| 534 | + { |
| 535 | + .distance_name = "L2", |
| 536 | + .eps_i8 = 1e-6, |
| 537 | + .eps_u8 = 1e-6, |
| 538 | + .expected = { |
| 539 | + 9.949874371066199, |
| 540 | + 12.529964086141668, |
| 541 | + 13.674794331177344, |
| 542 | + 4.472135954999580, |
| 543 | + 15.556349186104045, |
| 544 | + 12.806248474865697, |
| 545 | + 11.704699910719626, |
| 546 | + 13.601470508735444, |
| 547 | + 12.124355652982141, |
| 548 | + 4.242640687119285 |
| 549 | + } |
| 550 | + }, |
| 551 | + { |
| 552 | + .distance_name = "SQUARED_L2", |
| 553 | + .eps_i8 = 1e-6, |
| 554 | + .eps_u8 = 1e-6, |
| 555 | + .expected = {99.0, 157.0, 187.0, 20.0, 242.0, 164.0, 137.0, 185.0, 147.0, 18.0} |
| 556 | + }, |
| 557 | + { |
| 558 | + .distance_name = "COSINE", |
| 559 | + .eps_i8 = 1e-6, |
| 560 | + .eps_u8 = 1e-6, |
| 561 | + .expected = { |
| 562 | + 0.197058901598547, |
| 563 | + 0.278725549597720, |
| 564 | + 0.161317797973194, |
| 565 | + 0.036913175313846, |
| 566 | + 0.449627749704491, |
| 567 | + 0.182558273343614, |
| 568 | + 0.126858993881120, |
| 569 | + 0.205091387999948, |
| 570 | + 0.144927951966812, |
| 571 | + 0.000283884548207 |
| 572 | + } |
| 573 | + }, |
| 574 | + { |
| 575 | + .distance_name = "DOT", |
| 576 | + .eps_i8 = 1e-6, |
| 577 | + .eps_u8 = 1e-6, |
| 578 | + .expected = {-165.0, -203.0, -337.0, -256.0, -148.0, -300.0, -333.0, -295.0, -323.0, -346.0} |
| 579 | + }, |
| 580 | + { |
| 581 | + .distance_name = "L1", |
| 582 | + .eps_i8 = 1e-6, |
| 583 | + .eps_u8 = 1e-6, |
| 584 | + .expected = {19.0, 23.0, 19.0, 8.0, 30.0, 24.0, 21.0, 25.0, 19.0, 8.0} |
| 585 | + } |
| 586 | + }; |
| 587 | + const int ncases = (int)(sizeof(cases) / sizeof(cases[0])); |
| 588 | + const char *types[] = {"i8", "u8"}; |
| 589 | + |
| 590 | + for (int t = 0; t < 2; t++) { |
| 591 | + for (int i = 0; i < ncases; i++) { |
| 592 | + test_one_int_distance_case(db, types[t], &cases[i]); |
| 593 | + } |
| 594 | + } |
| 595 | +} |
| 596 | + |
| 597 | +static void test_distance_functions_hamming(sqlite3 *db) { |
| 598 | + const char *tbl = "tdist_hamming_bit"; |
| 599 | + const double expected[] = {3.0, 5.0, 3.0, 5.0, 4.0, 4.0, 4.0, 3.0, 5.0, 4.0}; |
| 600 | + char sql[1024]; |
| 601 | + char msg[256]; |
| 602 | + |
| 603 | + if (setup_table(db, tbl, "bit", "HAMMING", 8, bit_vecs, bit_nvecs) != 0) { |
| 604 | + ASSERT(0, "bit/HAMMING distance setup"); |
| 605 | + return; |
| 606 | + } |
| 607 | + |
| 608 | + scan_result r = {0}; |
| 609 | + snprintf(sql, sizeof(sql), |
| 610 | + "SELECT id, distance FROM vector_full_scan('%s', 'v', vector_as_bit('%s')) ORDER BY id;", |
| 611 | + tbl, bit_query); |
| 612 | + char *err = NULL; |
| 613 | + int rc = sqlite3_exec(db, sql, scan_cb, &r, &err); |
| 614 | + ASSERT(rc == SQLITE_OK, "bit/HAMMING distance query executes"); |
| 615 | + if (err) { printf(" err: %s\n", err); sqlite3_free(err); } |
| 616 | + if (rc != SQLITE_OK) return; |
| 617 | + |
| 618 | + ASSERT(r.count == bit_nvecs, "bit/HAMMING distance query returns all rows"); |
| 619 | + if (r.count != bit_nvecs) return; |
| 620 | + |
| 621 | + for (int i = 0; i < bit_nvecs; i++) { |
| 622 | + int id_ok = (r.ids[i] == (i + 1)); |
| 623 | + snprintf(msg, sizeof(msg), "bit/HAMMING row id matches expected (row %d)", i + 1); |
| 624 | + ASSERT(id_ok, msg); |
| 625 | + |
| 626 | + int exact_match = (r.distances[i] == expected[i]); |
| 627 | + snprintf(msg, sizeof(msg), "bit/HAMMING distance matches exactly (id=%d)", i + 1); |
| 628 | + ASSERT(exact_match, msg); |
| 629 | + } |
| 630 | +} |
| 631 | + |
314 | 632 | /* ---------- Main ---------- */ |
315 | 633 |
|
316 | 634 | int main(void) { |
@@ -426,6 +744,16 @@ int main(void) { |
426 | 744 | } |
427 | 745 | } |
428 | 746 |
|
| 747 | + |
| 748 | + /* 5. distance functions */ |
| 749 | + printf("\n=== distance_functions ===\n"); |
| 750 | + { |
| 751 | + test_distance_functions_float(db); |
| 752 | + test_distance_functions_int(db); |
| 753 | + test_distance_functions_hamming(db); |
| 754 | + } |
| 755 | + |
| 756 | + |
429 | 757 | sqlite3_close(db); |
430 | 758 |
|
431 | 759 | /* Summary */ |
|
0 commit comments