diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh
index 97e4cffa..c93c8100 100755
--- a/scripts/test-commands.sh
+++ b/scripts/test-commands.sh
@@ -63,6 +63,7 @@ while [[ $# -gt 0 ]]; do
echo " pubsub - Pub/Sub commands (SUBSCRIBE, PUBLISH, etc.)"
echo " transaction - Transaction commands (MULTI, EXEC, DISCARD)"
echo " scripting - Lua scripting (EVAL, EVALSHA)"
+ echo " vector - Vector search commands (FT.CREATE, FT.SEARCH, FT.INFO, FT.DROPINDEX)"
echo " persistence - Persistence commands (BGSAVE, BGREWRITEAOF, etc.)"
echo " blocking - Blocking commands (BLPOP, BRPOP, BZPOPMIN, etc.)"
echo " benchmark - redis-benchmark throughput for all benchmarkable commands"
@@ -665,6 +666,42 @@ fi
# PERSISTENCE COMMANDS
# ===========================================================================
+# ===========================================================================
+# VECTOR SEARCH COMMANDS (moon-only — Redis uses different syntax)
+# ===========================================================================
+
+if should_run "vector"; then
+    echo ""
+    echo "=== VECTOR SEARCH COMMANDS ==="
+    mcli FLUSHALL >/dev/null 2>&1
+
+    # FT.CREATE — create a 4-dimensional FLAT/L2 vector index over hashes prefixed "doc:"
+    assert_moon "FT.CREATE basic" "OK" FT.CREATE myidx ON HASH PREFIX 1 doc: SCHEMA embedding VECTOR FLAT 6 DIM 4 DISTANCE_METRIC L2 TYPE FLOAT32
+
+    # FT.INFO — the reply must mention the index name
+    TOTAL=$((TOTAL + 1)); FT_INFO=$(mcli FT.INFO myidx 2>&1)
+    if echo "$FT_INFO" | grep -q "myidx"; then PASS=$((PASS + 1)); echo " PASS: FT.INFO returns index name"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO returns index name"; fi
+
+    # Insert vectors via HSET (auto-indexed) — use python3 to avoid null byte stripping in bash
+    python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',1.0,0.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET doc:1 embedding >/dev/null 2>&1
+    python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',0.0,1.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET doc:2 embedding >/dev/null 2>&1
+
+    # FT.SEARCH — verify command doesn't error (redis-cli can't pass binary args directly)
+    TOTAL=$((TOTAL + 1)); FT_SEARCH=$(mcli FT.SEARCH myidx "*" 2>&1)
+    if ! echo "$FT_SEARCH" | grep -qi "err"; then PASS=$((PASS + 1)); echo " PASS: FT.SEARCH does not error"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.SEARCH returned error"; fi
+
+    # FT.DROPINDEX — remove index
+    assert_moon "FT.DROPINDEX" "OK" FT.DROPINDEX myidx
+
+    # FT.INFO after drop should error (grep -E: "\|" in a basic regex is a GNU-only extension)
+    TOTAL=$((TOTAL + 1)); FT_INFO_AFTER=$(mcli FT.INFO myidx 2>&1)
+    if echo "$FT_INFO_AFTER" | grep -qiE "err|not found"; then PASS=$((PASS + 1)); echo " PASS: FT.INFO after drop errors"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO after drop errors"; fi
+fi
+
+# ===========================================================================
+# PERSISTENCE COMMANDS
+# ===========================================================================
+
if should_run "persistence"; then
echo ""
echo "=== PERSISTENCE COMMANDS ==="
diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh
index 3ca8873a..f854cb50 100755
--- a/scripts/test-consistency.sh
+++ b/scripts/test-consistency.sh
@@ -481,6 +481,31 @@ assert_both "GET with 500-char key" GET "$LONGKEY"
# ===========================================================================
echo ""
+# ===========================================================================
+# Vector Search (moon-only — FT.* not available in Redis)
+# ===========================================================================
+log "=== Vector Search (moon-only) ==="
+
+# The index is created on the moon server only (PORT_RUST)
+FT_CREATE=$(redis-cli -p "$PORT_RUST" FT.CREATE vecidx ON HASH PREFIX 1 vec: SCHEMA embedding VECTOR FLAT 6 DIM 4 DISTANCE_METRIC L2 TYPE FLOAT32 2>&1)
+assert_eq "FT.CREATE" "OK" "$FT_CREATE"
+
+# Vector blobs are produced by python3 and piped to redis-cli -x: bash would strip the null bytes
+python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',1.0,0.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET vec:1 embedding >/dev/null 2>&1
+python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',0.0,1.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET vec:2 embedding >/dev/null 2>&1
+
+# FT.INFO passes when the index name appears anywhere in the reply
+FT_INFO=$(redis-cli -p "$PORT_RUST" FT.INFO vecidx 2>&1)
+# Substring match on the captured reply; nothing is printed on success
+case "$FT_INFO" in
+    *vecidx*) PASS=$((PASS + 1)) ;;
+    *) FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO should show vecidx" ;;
+esac
+
+# FT.DROPINDEX removes the index again
+FT_DROP=$(redis-cli -p "$PORT_RUST" FT.DROPINDEX vecidx 2>&1)
+assert_eq "FT.DROPINDEX" "OK" "$FT_DROP"
+
echo "============================================"
echo " Data Consistency Test Results"
echo "============================================"
diff --git a/src/command/vector_search.rs b/src/command/vector_search/mod.rs
similarity index 53%
rename from src/command/vector_search.rs
rename to src/command/vector_search/mod.rs
index 23684046..8f5e334c 100644
--- a/src/command/vector_search.rs
+++ b/src/command/vector_search/mod.rs
@@ -212,7 +212,10 @@ pub fn ft_create(store: &mut VectorStore, args: &[Frame]) -> Frame {
}
let dim = match dimension {
- Some(d) if d > 0 => d,
+ Some(d) if d > 0 && d <= 65536 => d,
+ Some(d) if d > 65536 => {
+ return Frame::Error(Bytes::from_static(b"ERR DIM must be between 1 and 65536"));
+ }
_ => return Frame::Error(Bytes::from_static(b"ERR DIM is required and must be > 0")),
};
@@ -236,7 +239,12 @@ pub fn ft_create(store: &mut VectorStore, args: &[Frame]) -> Frame {
crate::vector::metrics::increment_indexes();
Frame::SimpleString(Bytes::from_static(b"OK"))
}
- Err(msg) => Frame::Error(Bytes::from(format!("ERR {msg}"))),
+ Err(msg) => {
+ let mut buf = Vec::with_capacity(4 + msg.len());
+ buf.extend_from_slice(b"ERR ");
+ buf.extend_from_slice(msg.as_bytes());
+ Frame::Error(Bytes::from(buf))
+ }
}
}
@@ -304,15 +312,26 @@ pub fn ft_info(store: &VectorStore, args: &[Frame]) -> Frame {
let snap = idx.segments.load();
let num_docs = snap.mutable.len();
- let ef_rt_str = if idx.meta.hnsw_ef_runtime > 0 {
- format!("{}", idx.meta.hnsw_ef_runtime)
+ // Use itoa for numeric formatting — no format!() on hot path.
+ let ef_rt_bytes: Bytes = if idx.meta.hnsw_ef_runtime > 0 {
+ let mut buf = itoa::Buffer::new();
+ Bytes::copy_from_slice(buf.format(idx.meta.hnsw_ef_runtime).as_bytes())
} else {
- "auto".to_string()
+ Bytes::from_static(b"auto")
};
- let ct_str = if idx.meta.compact_threshold > 0 {
- format!("{}", idx.meta.compact_threshold)
+ let ct_bytes: Bytes = if idx.meta.compact_threshold > 0 {
+ let mut buf = itoa::Buffer::new();
+ Bytes::copy_from_slice(buf.format(idx.meta.compact_threshold).as_bytes())
} else {
- "1000".to_string()
+ Bytes::from_static(b"1000")
+ };
+ let quant_bytes: Bytes = match idx.meta.quantization {
+ QuantizationConfig::Sq8 => Bytes::from_static(b"SQ8"),
+ QuantizationConfig::TurboQuant4 => Bytes::from_static(b"TurboQuant4"),
+ QuantizationConfig::TurboQuantProd4 => Bytes::from_static(b"TurboQuantProd4"),
+ QuantizationConfig::TurboQuant1 => Bytes::from_static(b"TurboQuant1"),
+ QuantizationConfig::TurboQuant2 => Bytes::from_static(b"TurboQuant2"),
+ QuantizationConfig::TurboQuant3 => Bytes::from_static(b"TurboQuant3"),
};
let items = vec![
@@ -337,11 +356,11 @@ pub fn ft_info(store: &VectorStore, args: &[Frame]) -> Frame {
Frame::BulkString(Bytes::from_static(b"EF_CONSTRUCTION")),
Frame::Integer(idx.meta.hnsw_ef_construction as i64),
Frame::BulkString(Bytes::from_static(b"EF_RUNTIME")),
- Frame::BulkString(Bytes::from(ef_rt_str)),
+ Frame::BulkString(ef_rt_bytes),
Frame::BulkString(Bytes::from_static(b"COMPACT_THRESHOLD")),
- Frame::BulkString(Bytes::from(ct_str)),
+ Frame::BulkString(ct_bytes),
Frame::BulkString(Bytes::from_static(b"QUANTIZATION")),
- Frame::BulkString(Bytes::from(format!("{:?}", idx.meta.quantization))),
+ Frame::BulkString(quant_bytes),
];
Frame::Array(items.into())
}
@@ -399,10 +418,8 @@ pub fn ft_search(store: &mut VectorStore, args: &[Frame]) -> Frame {
// Parse optional FILTER clause
let filter_expr = parse_filter_clause(args);
- let start = std::time::Instant::now();
let result = search_local_filtered(store, &index_name, &query_blob, k, filter_expr.as_ref());
crate::vector::metrics::increment_search();
- crate::vector::metrics::record_search_latency(start.elapsed().as_micros() as u64);
result
}
@@ -558,8 +575,11 @@ fn build_search_response(results: &SmallVec<[SearchResult; 32]>) -> Frame {
doc_id.extend_from_slice(id_str.as_bytes());
items.push(Frame::BulkString(Bytes::from(doc_id)));
- // Score as nested array (format! acceptable -- end of command path)
- let score_str = format!("{}", r.distance);
+ // Score as nested array — use write! to pre-allocated buffer
+ let mut score_buf = String::with_capacity(16);
+ use std::fmt::Write;
+ let _ = write!(score_buf, "{}", r.distance);
+ let score_str = score_buf;
let fields = vec![
Frame::BulkString(Bytes::from_static(b"__vec_score")),
Frame::BulkString(Bytes::from(score_str)),
@@ -835,692 +855,4 @@ fn metric_to_bytes(m: DistanceMetric) -> Bytes {
}
#[cfg(test)]
-mod tests {
- use super::*;
-
- fn bulk(s: &[u8]) -> Frame {
- Frame::BulkString(Bytes::from(s.to_vec()))
- }
-
- /// Build a valid FT.CREATE argument list.
- fn ft_create_args() -> Vec<Frame> {
- vec![
- bulk(b"myidx"), // index name
- bulk(b"ON"),
- bulk(b"HASH"),
- bulk(b"PREFIX"),
- bulk(b"1"),
- bulk(b"doc:"),
- bulk(b"SCHEMA"),
- bulk(b"vec"),
- bulk(b"VECTOR"),
- bulk(b"HNSW"),
- bulk(b"6"), // 6 params = 3 key-value pairs
- bulk(b"TYPE"),
- bulk(b"FLOAT32"),
- bulk(b"DIM"),
- bulk(b"128"),
- bulk(b"DISTANCE_METRIC"),
- bulk(b"L2"),
- ]
- }
-
- #[test]
- fn test_ft_create_parse_full_syntax() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- let result = ft_create(&mut store, &args);
- match &result {
- Frame::SimpleString(s) => assert_eq!(&s[..], b"OK"),
- other => panic!("expected OK, got {other:?}"),
- }
- assert_eq!(store.len(), 1);
- let idx = store.get_index(b"myidx").unwrap();
- assert_eq!(idx.meta.dimension, 128);
- assert_eq!(idx.meta.metric, DistanceMetric::L2);
- assert_eq!(idx.meta.key_prefixes.len(), 1);
- assert_eq!(&idx.meta.key_prefixes[0][..], b"doc:");
- }
-
- #[test]
- fn test_ft_create_missing_dim() {
- let mut store = VectorStore::new();
- // Remove DIM param pair: keep TYPE FLOAT32 and DISTANCE_METRIC L2 (4 params = 2 pairs)
- let args = vec![
- bulk(b"myidx"),
- bulk(b"ON"),
- bulk(b"HASH"),
- bulk(b"PREFIX"),
- bulk(b"1"),
- bulk(b"doc:"),
- bulk(b"SCHEMA"),
- bulk(b"vec"),
- bulk(b"VECTOR"),
- bulk(b"HNSW"),
- bulk(b"4"), // 4 params = 2 key-value pairs
- bulk(b"TYPE"),
- bulk(b"FLOAT32"),
- bulk(b"DISTANCE_METRIC"),
- bulk(b"L2"),
- ];
- let result = ft_create(&mut store, &args);
- match &result {
- Frame::Error(_) => {} // expected
- other => panic!("expected error, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_create_duplicate() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- let r1 = ft_create(&mut store, &args);
- assert!(matches!(r1, Frame::SimpleString(_)));
-
- let args2 = ft_create_args();
- let r2 = ft_create(&mut store, &args2);
- match &r2 {
- Frame::Error(e) => assert!(e.starts_with(b"ERR")),
- other => panic!("expected error, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_dropindex() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
-
- // Drop existing
- let result = ft_dropindex(&mut store, &[bulk(b"myidx")]);
- assert!(matches!(result, Frame::SimpleString(_)));
- assert!(store.is_empty());
-
- // Drop non-existing
- let result = ft_dropindex(&mut store, &[bulk(b"myidx")]);
- assert!(matches!(result, Frame::Error(_)));
- }
-
- #[test]
- fn test_parse_knn_query() {
- let query = b"*=>[KNN 10 @vec $query]";
- let (k, param) = parse_knn_query(query).unwrap();
- assert_eq!(k, 10);
- assert_eq!(&param[..], b"query");
- }
-
- #[test]
- fn test_parse_knn_query_different_k() {
- let query = b"*=>[KNN 5 @embedding $blob]";
- let (k, param) = parse_knn_query(query).unwrap();
- assert_eq!(k, 5);
- assert_eq!(&param[..], b"blob");
- }
-
- #[test]
- fn test_parse_knn_query_invalid() {
- assert!(parse_knn_query(b"*").is_none());
- assert!(parse_knn_query(b"*=>[NOTAKNN]").is_none());
- }
-
- #[test]
- fn test_extract_param_blob() {
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 10 @vec $query]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"query"),
- bulk(b"blobdata"),
- ];
- let blob = extract_param_blob(&args, b"query").unwrap();
- assert_eq!(&blob[..], b"blobdata");
- }
-
- #[test]
- fn test_extract_param_blob_missing() {
- let args = vec![bulk(b"idx"), bulk(b"*=>[KNN 10 @vec $query]")];
- assert!(extract_param_blob(&args, b"query").is_none());
- }
-
- #[test]
- fn test_quantize_f32_to_sq() {
- let input = [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0];
- let mut output = [0i8; 7];
- quantize_f32_to_sq(&input, &mut output);
- assert_eq!(output[0], 0); // 0.0 -> 0
- assert_eq!(output[1], 127); // 1.0 -> 127
- assert_eq!(output[2], -127); // -1.0 -> -127
- assert_eq!(output[3], 63); // 0.5 -> 63 (truncated from 63.5)
- assert_eq!(output[4], -63); // -0.5 -> -63
- assert_eq!(output[5], 127); // 2.0 clamped to 1.0 -> 127
- assert_eq!(output[6], -127); // -2.0 clamped to -1.0 -> -127
- }
-
- #[test]
- fn test_merge_search_results_combines_shards() {
- // Shard 0 returns: [2, "vec:0", ["__vec_score", "0.1"], "vec:1", ["__vec_score", "0.5"]]
- // Shard 1 returns: [2, "vec:10", ["__vec_score", "0.3"], "vec:11", ["__vec_score", "0.9"]]
- // Global top-2 should be: vec:0 (0.1), vec:10 (0.3)
-
- let shard0 = Frame::Array(
- vec![
- Frame::Integer(2),
- bulk(b"vec:0"),
- Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.1")].into()),
- bulk(b"vec:1"),
- Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.5")].into()),
- ]
- .into(),
- );
-
- let shard1 = Frame::Array(
- vec![
- Frame::Integer(2),
- bulk(b"vec:10"),
- Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.3")].into()),
- bulk(b"vec:11"),
- Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.9")].into()),
- ]
- .into(),
- );
-
- let result = merge_search_results(&[shard0, shard1], 2);
- match result {
- Frame::Array(items) => {
- assert_eq!(items[0], Frame::Integer(2));
- assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:0")));
- assert_eq!(items[3], Frame::BulkString(Bytes::from("vec:10")));
- }
- other => panic!("expected Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_merge_search_results_handles_errors() {
- // One shard returns error, one returns valid results
- let shard0 = Frame::Error(Bytes::from_static(b"ERR shard unavailable"));
- let shard1 = Frame::Array(
- vec![
- Frame::Integer(1),
- bulk(b"vec:5"),
- Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.2")].into()),
- ]
- .into(),
- );
-
- let result = merge_search_results(&[shard0, shard1], 5);
- match result {
- Frame::Array(items) => {
- assert_eq!(items[0], Frame::Integer(1));
- assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:5")));
- }
- other => panic!("expected Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_merge_search_results_empty() {
- // No results from any shard
- let shard0 = Frame::Array(vec![Frame::Integer(0)].into());
- let shard1 = Frame::Array(vec![Frame::Integer(0)].into());
-
- let result = merge_search_results(&[shard0, shard1], 10);
- match result {
- Frame::Array(items) => {
- assert_eq!(items.len(), 1);
- assert_eq!(items[0], Frame::Integer(0));
- }
- other => panic!("expected Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_search_dimension_mismatch() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
-
- // Build a query with wrong dimension (4 bytes instead of 128*4)
- let search_args = vec![
- bulk(b"myidx"),
- bulk(b"*=>[KNN 10 @vec $query]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"query"),
- bulk(b"tooshort"),
- ];
- let result = ft_search(&mut store, &search_args);
- match &result {
- Frame::Error(e) => assert!(
- e.starts_with(b"ERR query vector dimension"),
- "expected dimension mismatch error, got {:?}",
- std::str::from_utf8(e)
- ),
- other => panic!("expected error, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_search_empty_index() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
-
- // Build valid query for dim=128
- let query_vec: Vec<u8> = vec![0u8; 128 * 4]; // 128 floats, all zero
- let search_args = vec![
- bulk(b"myidx"),
- bulk(b"*=>[KNN 5 @vec $query]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"query"),
- Frame::BulkString(Bytes::from(query_vec)),
- ];
- crate::vector::distance::init();
- let result = ft_search(&mut store, &search_args);
- match result {
- Frame::Array(items) => {
- assert_eq!(items[0], Frame::Integer(0)); // no results
- }
- other => panic!("expected Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_info() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
-
- let result = ft_info(&store, &[bulk(b"myidx")]);
- match result {
- Frame::Array(items) => {
- // Should have 20 items (10 key-value pairs)
- assert!(
- items.len() >= 20,
- "FT.INFO should return at least 20 items, got {}",
- items.len()
- );
- assert_eq!(
- items[0],
- Frame::BulkString(Bytes::from_static(b"index_name"))
- );
- assert_eq!(items[1], Frame::BulkString(Bytes::from("myidx")));
- assert_eq!(items[5], Frame::Integer(0)); // num_docs = 0
- assert_eq!(items[7], Frame::Integer(128)); // dimension
- // New fields
- assert_eq!(items[10], Frame::BulkString(Bytes::from_static(b"M")));
- assert_eq!(items[11], Frame::Integer(16)); // default M
- assert_eq!(
- items[14],
- Frame::BulkString(Bytes::from_static(b"EF_RUNTIME"))
- );
- }
- other => panic!("expected Array, got {other:?}"),
- }
-
- // Non-existing index
- let result = ft_info(&store, &[bulk(b"nonexistent")]);
- assert!(matches!(result, Frame::Error(_)));
- }
-
- /// Helper to build FT.CREATE args with custom parameters.
- fn build_ft_create_args(
- name: &str,
- prefix: &str,
- field: &str,
- dim: u32,
- metric: &str,
- ) -> Vec<Frame> {
- vec![
- Frame::BulkString(Bytes::from(name.to_owned())),
- Frame::BulkString(Bytes::from_static(b"ON")),
- Frame::BulkString(Bytes::from_static(b"HASH")),
- Frame::BulkString(Bytes::from_static(b"PREFIX")),
- Frame::BulkString(Bytes::from_static(b"1")),
- Frame::BulkString(Bytes::from(prefix.to_owned())),
- Frame::BulkString(Bytes::from_static(b"SCHEMA")),
- Frame::BulkString(Bytes::from(field.to_owned())),
- Frame::BulkString(Bytes::from_static(b"VECTOR")),
- Frame::BulkString(Bytes::from_static(b"HNSW")),
- Frame::BulkString(Bytes::from_static(b"6")),
- Frame::BulkString(Bytes::from_static(b"TYPE")),
- Frame::BulkString(Bytes::from_static(b"FLOAT32")),
- Frame::BulkString(Bytes::from_static(b"DIM")),
- Frame::BulkString(Bytes::from(dim.to_string())),
- Frame::BulkString(Bytes::from_static(b"DISTANCE_METRIC")),
- Frame::BulkString(Bytes::from(metric.to_owned())),
- ]
- }
-
- #[test]
- fn test_end_to_end_create_insert_search() {
- // Initialize distance functions (required before any search)
- crate::vector::distance::init();
-
- let mut store = VectorStore::new();
- let dim: usize = 4;
-
- // 1. FT.CREATE
- let create_args = build_ft_create_args("e2eidx", "doc:", "embedding", dim as u32, "L2");
- let result = ft_create(&mut store, &create_args);
- assert!(
- matches!(result, Frame::SimpleString(_)),
- "FT.CREATE should return OK, got {result:?}"
- );
-
- // 2. Insert vectors directly into the mutable segment
- let idx = store.get_index_mut(b"e2eidx").unwrap();
- let vectors: Vec<[f32; 4]> = vec![
- [1.0, 0.0, 0.0, 0.0], // vec:0 -- exact match for query (L2=0)
- [-1.0, 0.0, 0.0, 0.0], // vec:1 -- opposite direction (L2=4.0)
- [0.5, 0.0, 0.0, 0.0], // vec:2 -- same direction, half magnitude (L2=0.25)
- ];
-
- let snap = idx.segments.load();
- for (i, v) in vectors.iter().enumerate() {
- let mut sq = vec![0i8; dim];
- quantize_f32_to_sq(v, &mut sq);
- let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
- snap.mutable.append(i as u64, v, &sq, norm, i as u64);
- }
- drop(snap);
-
- // 3. FT.SEARCH for vector close to [1.0, 0.0, 0.0, 0.0]
- let query_vec: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
- let query_blob: Vec<u8> = query_vec.iter().flat_map(|f| f.to_le_bytes()).collect();
-
- let search_args = vec![
- Frame::BulkString(Bytes::from_static(b"e2eidx")),
- Frame::BulkString(Bytes::from_static(b"*=>[KNN 2 @embedding $query]")),
- Frame::BulkString(Bytes::from_static(b"PARAMS")),
- Frame::BulkString(Bytes::from_static(b"2")),
- Frame::BulkString(Bytes::from_static(b"query")),
- Frame::BulkString(Bytes::from(query_blob)),
- ];
-
- let result = ft_search(&mut store, &search_args);
- match &result {
- Frame::Array(items) => {
- // First element is count
- assert!(
- matches!(&items[0], Frame::Integer(n) if *n >= 1),
- "Should find at least 1 result, got {result:?}"
- );
- // vec:0 should be in top-2 results (at dim=4, TQ-4bit quantization
- // noise can swap rankings of very close vectors in Light mode)
- let mut found_vec0 = false;
- for idx in [1, 3].iter() {
- if let Some(Frame::BulkString(doc_id)) = items.get(*idx) {
- if doc_id.as_ref() == b"vec:0" {
- found_vec0 = true;
- }
- }
- }
- assert!(
- found_vec0,
- "vec:0 should be in top-2 results, got {result:?}"
- );
- // vec:2 should be in top-2 (at dim=4, TQ noise may reorder)
- let mut found_vec2 = false;
- for idx in [1, 3].iter() {
- if let Some(Frame::BulkString(doc_id)) = items.get(*idx) {
- if doc_id.as_ref() == b"vec:2" {
- found_vec2 = true;
- }
- }
- }
- assert!(
- found_vec2,
- "vec:2 should be in top-2 results, got {result:?}"
- );
- }
- Frame::Error(e) => panic!("FT.SEARCH returned error: {:?}", std::str::from_utf8(e)),
- _ => panic!("FT.SEARCH should return Array, got {result:?}"),
- }
- }
-
- #[test]
- fn test_ft_info_returns_correct_data() {
- let mut store = VectorStore::new();
- let args = build_ft_create_args("testidx", "test:", "vec", 128, "COSINE");
- ft_create(&mut store, &args);
-
- let info_args = [Frame::BulkString(Bytes::from_static(b"testidx"))];
- let result = ft_info(&store, &info_args);
- match result {
- Frame::Array(items) => {
- assert!(items.len() >= 6, "FT.INFO should return at least 6 items");
- // Check dimension
- let mut found_dim = false;
- for pair in items.chunks(2) {
- if let Frame::BulkString(key) = &pair[0] {
- if key.as_ref() == b"dimension" {
- if let Frame::Integer(d) = &pair[1] {
- assert_eq!(*d, 128);
- found_dim = true;
- }
- }
- }
- }
- assert!(found_dim, "FT.INFO should return dimension");
- }
- other => panic!("FT.INFO should return Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_ft_search_unknown_index() {
- let mut store = VectorStore::new();
- let args = [
- Frame::BulkString(Bytes::from_static(b"nonexistent")),
- Frame::BulkString(Bytes::from_static(b"*=>[KNN 5 @vec $query]")),
- Frame::BulkString(Bytes::from_static(b"PARAMS")),
- Frame::BulkString(Bytes::from_static(b"2")),
- Frame::BulkString(Bytes::from_static(b"query")),
- Frame::BulkString(Bytes::from(vec![0u8; 16])),
- ];
- let result = ft_search(&mut store, &args);
- assert!(
- matches!(result, Frame::Error(_)),
- "Should error on unknown index, got {result:?}"
- );
- }
-
- #[test]
- fn test_parse_filter_clause_tag() {
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 10 @vec $q]"),
- bulk(b"FILTER"),
- bulk(b"@category:{electronics}"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"q"),
- bulk(b"blob"),
- ];
- let filter = parse_filter_clause(&args);
- assert!(filter.is_some(), "should parse @category:{{electronics}}");
- match filter.unwrap() {
- crate::vector::filter::FilterExpr::TagEq { field, value } => {
- assert_eq!(&field[..], b"category");
- assert_eq!(&value[..], b"electronics");
- }
- other => panic!("expected TagEq, got {other:?}"),
- }
- }
-
- #[test]
- fn test_parse_filter_clause_numeric_range() {
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 5 @vec $q]"),
- bulk(b"FILTER"),
- bulk(b"@price:[10 100]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"q"),
- bulk(b"blob"),
- ];
- let filter = parse_filter_clause(&args);
- assert!(filter.is_some());
- match filter.unwrap() {
- crate::vector::filter::FilterExpr::NumRange { field, min, max } => {
- assert_eq!(&field[..], b"price");
- assert_eq!(*min, 10.0);
- assert_eq!(*max, 100.0);
- }
- other => panic!("expected NumRange, got {other:?}"),
- }
- }
-
- #[test]
- fn test_parse_filter_clause_numeric_eq() {
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 5 @vec $q]"),
- bulk(b"FILTER"),
- bulk(b"@price:[50 50]"),
- ];
- let filter = parse_filter_clause(&args);
- assert!(filter.is_some());
- match filter.unwrap() {
- crate::vector::filter::FilterExpr::NumEq { field, value } => {
- assert_eq!(&field[..], b"price");
- assert_eq!(*value, 50.0);
- }
- other => panic!("expected NumEq, got {other:?}"),
- }
- }
-
- #[test]
- fn test_parse_filter_clause_compound() {
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 5 @vec $q]"),
- bulk(b"FILTER"),
- bulk(b"@a:{x} @b:[1 10]"),
- ];
- let filter = parse_filter_clause(&args);
- assert!(filter.is_some());
- match filter.unwrap() {
- crate::vector::filter::FilterExpr::And(left, right) => {
- assert!(matches!(
- *left,
- crate::vector::filter::FilterExpr::TagEq { .. }
- ));
- assert!(matches!(
- *right,
- crate::vector::filter::FilterExpr::NumRange { .. }
- ));
- }
- other => panic!("expected And, got {other:?}"),
- }
- }
-
- #[test]
- fn test_parse_filter_clause_none() {
- // No FILTER keyword
- let args = vec![
- bulk(b"idx"),
- bulk(b"*=>[KNN 10 @vec $q]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"q"),
- bulk(b"blob"),
- ];
- let filter = parse_filter_clause(&args);
- assert!(filter.is_none());
- }
-
- #[test]
- fn test_ft_search_with_filter_no_regression() {
- // Unfiltered FT.SEARCH still works identically
- crate::vector::distance::init();
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
-
- let query_vec: Vec<u8> = vec![0u8; 128 * 4];
- let search_args = vec![
- bulk(b"myidx"),
- bulk(b"*=>[KNN 5 @vec $query]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"query"),
- Frame::BulkString(Bytes::from(query_vec)),
- ];
- let result = ft_search(&mut store, &search_args);
- match result {
- Frame::Array(items) => {
- assert_eq!(items[0], Frame::Integer(0));
- }
- other => panic!("expected Array, got {other:?}"),
- }
- }
-
- #[test]
- fn test_vector_index_has_payload_index() {
- let mut store = VectorStore::new();
- let args = ft_create_args();
- ft_create(&mut store, &args);
- let idx = store.get_index(b"myidx").unwrap();
- // payload_index should exist -- insert and evaluate should work
- let _ = &idx.payload_index;
- }
-
- #[test]
- fn test_vector_metrics_increment_decrement() {
- use std::sync::atomic::Ordering;
-
- // Capture before-snapshot immediately before each operation to handle
- // parallel test interference on global atomics.
- let mut store = VectorStore::new();
- let args = ft_create_args();
-
- // FT.CREATE should increment VECTOR_INDEXES
- let before_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
- ft_create(&mut store, &args);
- let after_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
- assert!(
- after_create > before_create,
- "FT.CREATE should increment VECTOR_INDEXES"
- );
-
- // FT.SEARCH should increment VECTOR_SEARCH_TOTAL
- crate::vector::distance::init();
- let before_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed);
- let query_vec: Vec<u8> = vec![0u8; 128 * 4];
- let search_args = vec![
- bulk(b"myidx"),
- bulk(b"*=>[KNN 5 @vec $query]"),
- bulk(b"PARAMS"),
- bulk(b"2"),
- bulk(b"query"),
- Frame::BulkString(Bytes::from(query_vec)),
- ];
- ft_search(&mut store, &search_args);
- let after_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed);
- assert!(
- after_search > before_search,
- "FT.SEARCH should increment VECTOR_SEARCH_TOTAL"
- );
-
- // Latency should be non-zero after a search
- let latency = crate::vector::metrics::VECTOR_SEARCH_LATENCY_US.load(Ordering::Relaxed);
- // latency may be 0 on very fast machines, so just check it was written (could be 0 if sub-microsecond)
-
- // FT.DROPINDEX should decrement VECTOR_INDEXES
- let before_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
- ft_dropindex(&mut store, &[bulk(b"myidx")]);
- let after_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
- assert!(
- after_drop < before_drop,
- "FT.DROPINDEX should decrement VECTOR_INDEXES"
- );
-
- // Suppress unused variable warning
- let _ = latency;
- }
-}
+mod tests;
diff --git a/src/command/vector_search/tests.rs b/src/command/vector_search/tests.rs
new file mode 100644
index 00000000..c9290a39
--- /dev/null
+++ b/src/command/vector_search/tests.rs
@@ -0,0 +1,684 @@
+use super::*;
+use std::sync::Mutex;
+
+/// Serialize tests that touch global atomic metrics to avoid flaky interference.
+static METRICS_LOCK: Mutex<()> = Mutex::new(());
+
+fn bulk(s: &[u8]) -> Frame {
+ Frame::BulkString(Bytes::from(s.to_vec()))
+}
+
+/// Build valid FT.CREATE args: `myidx` ON HASH, prefix `doc:`, HNSW field `vec` (FLOAT32, DIM 128, L2).
+fn ft_create_args() -> Vec<Frame> {
+    vec![
+        bulk(b"myidx"), // index name
+        bulk(b"ON"),
+        bulk(b"HASH"),
+        bulk(b"PREFIX"),
+        bulk(b"1"),
+        bulk(b"doc:"),
+        bulk(b"SCHEMA"),
+        bulk(b"vec"),
+        bulk(b"VECTOR"),
+        bulk(b"HNSW"),
+        bulk(b"6"), // 6 params = 3 key-value pairs
+        bulk(b"TYPE"),
+        bulk(b"FLOAT32"),
+        bulk(b"DIM"),
+        bulk(b"128"),
+        bulk(b"DISTANCE_METRIC"),
+        bulk(b"L2"),
+    ]
+}
+
+#[test]
+fn test_ft_create_parse_full_syntax() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ let result = ft_create(&mut store, &args);
+ match &result {
+ Frame::SimpleString(s) => assert_eq!(&s[..], b"OK"),
+ other => panic!("expected OK, got {other:?}"),
+ }
+ assert_eq!(store.len(), 1);
+ let idx = store.get_index(b"myidx").unwrap();
+ assert_eq!(idx.meta.dimension, 128);
+ assert_eq!(idx.meta.metric, DistanceMetric::L2);
+ assert_eq!(idx.meta.key_prefixes.len(), 1);
+ assert_eq!(&idx.meta.key_prefixes[0][..], b"doc:");
+}
+
+#[test]
+fn test_ft_create_missing_dim() {
+ let mut store = VectorStore::new();
+ // Remove DIM param pair: keep TYPE FLOAT32 and DISTANCE_METRIC L2 (4 params = 2 pairs)
+ let args = vec![
+ bulk(b"myidx"),
+ bulk(b"ON"),
+ bulk(b"HASH"),
+ bulk(b"PREFIX"),
+ bulk(b"1"),
+ bulk(b"doc:"),
+ bulk(b"SCHEMA"),
+ bulk(b"vec"),
+ bulk(b"VECTOR"),
+ bulk(b"HNSW"),
+ bulk(b"4"), // 4 params = 2 key-value pairs
+ bulk(b"TYPE"),
+ bulk(b"FLOAT32"),
+ bulk(b"DISTANCE_METRIC"),
+ bulk(b"L2"),
+ ];
+ let result = ft_create(&mut store, &args);
+ match &result {
+ Frame::Error(_) => {} // expected
+ other => panic!("expected error, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_ft_create_duplicate() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ let r1 = ft_create(&mut store, &args);
+ assert!(matches!(r1, Frame::SimpleString(_)));
+
+ let args2 = ft_create_args();
+ let r2 = ft_create(&mut store, &args2);
+ match &r2 {
+ Frame::Error(e) => assert!(e.starts_with(b"ERR")),
+ other => panic!("expected error, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_ft_dropindex() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ ft_create(&mut store, &args);
+
+ // Drop existing
+ let result = ft_dropindex(&mut store, &[bulk(b"myidx")]);
+ assert!(matches!(result, Frame::SimpleString(_)));
+ assert!(store.is_empty());
+
+ // Drop non-existing
+ let result = ft_dropindex(&mut store, &[bulk(b"myidx")]);
+ assert!(matches!(result, Frame::Error(_)));
+}
+
+#[test]
+fn test_parse_knn_query() {
+    let query = b"*=>[KNN 10 @vec $query]";
+    let (k, param) = parse_knn_query(query).unwrap();
+    assert_eq!(k, 10);
+    assert_eq!(&param[..], b"query"); // parameter name is returned without the leading `$`
+}
+
+#[test]
+fn test_parse_knn_query_different_k() {
+    let query = b"*=>[KNN 5 @embedding $blob]";
+    let (k, param) = parse_knn_query(query).unwrap();
+    assert_eq!(k, 5);
+    assert_eq!(&param[..], b"blob"); // parameter name is returned without the leading `$`
+}
+
+#[test]
+fn test_parse_knn_query_invalid() {
+ assert!(parse_knn_query(b"*").is_none());
+ assert!(parse_knn_query(b"*=>[NOTAKNN]").is_none());
+}
+
+#[test]
+fn test_extract_param_blob() {
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 10 @vec $query]"),
+ bulk(b"PARAMS"),
+ bulk(b"2"),
+ bulk(b"query"),
+ bulk(b"blobdata"),
+ ];
+ let blob = extract_param_blob(&args, b"query").unwrap();
+ assert_eq!(&blob[..], b"blobdata");
+}
+
+#[test]
+fn test_extract_param_blob_missing() {
+ let args = vec![bulk(b"idx"), bulk(b"*=>[KNN 10 @vec $query]")];
+ assert!(extract_param_blob(&args, b"query").is_none());
+}
+
+#[test]
+fn test_quantize_f32_to_sq() {
+ let input = [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0];
+ let mut output = [0i8; 7];
+ quantize_f32_to_sq(&input, &mut output);
+ assert_eq!(output[0], 0); // 0.0 -> 0
+ assert_eq!(output[1], 127); // 1.0 -> 127
+ assert_eq!(output[2], -127); // -1.0 -> -127
+ assert_eq!(output[3], 63); // 0.5 -> 63 (truncated from 63.5)
+ assert_eq!(output[4], -63); // -0.5 -> -63
+ assert_eq!(output[5], 127); // 2.0 clamped to 1.0 -> 127
+ assert_eq!(output[6], -127); // -2.0 clamped to -1.0 -> -127
+}
+
+#[test]
+fn test_merge_search_results_combines_shards() {
+ // Shard 0 returns: [2, "vec:0", ["__vec_score", "0.1"], "vec:1", ["__vec_score", "0.5"]]
+ // Shard 1 returns: [2, "vec:10", ["__vec_score", "0.3"], "vec:11", ["__vec_score", "0.9"]]
+ // Global top-2 should be: vec:0 (0.1), vec:10 (0.3)
+
+ let shard0 = Frame::Array(
+ vec![
+ Frame::Integer(2),
+ bulk(b"vec:0"),
+ Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.1")].into()),
+ bulk(b"vec:1"),
+ Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.5")].into()),
+ ]
+ .into(),
+ );
+
+ let shard1 = Frame::Array(
+ vec![
+ Frame::Integer(2),
+ bulk(b"vec:10"),
+ Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.3")].into()),
+ bulk(b"vec:11"),
+ Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.9")].into()),
+ ]
+ .into(),
+ );
+
+ let result = merge_search_results(&[shard0, shard1], 2);
+ match result {
+ Frame::Array(items) => {
+ assert_eq!(items[0], Frame::Integer(2));
+ assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:0")));
+ assert_eq!(items[3], Frame::BulkString(Bytes::from("vec:10")));
+ }
+ other => panic!("expected Array, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_merge_search_results_handles_errors() {
+ // One shard returns error, one returns valid results
+ let shard0 = Frame::Error(Bytes::from_static(b"ERR shard unavailable"));
+ let shard1 = Frame::Array(
+ vec![
+ Frame::Integer(1),
+ bulk(b"vec:5"),
+ Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.2")].into()),
+ ]
+ .into(),
+ );
+
+ let result = merge_search_results(&[shard0, shard1], 5);
+ match result {
+ Frame::Array(items) => {
+ assert_eq!(items[0], Frame::Integer(1));
+ assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:5")));
+ }
+ other => panic!("expected Array, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_merge_search_results_empty() {
+ // No results from any shard
+ let shard0 = Frame::Array(vec![Frame::Integer(0)].into());
+ let shard1 = Frame::Array(vec![Frame::Integer(0)].into());
+
+ let result = merge_search_results(&[shard0, shard1], 10);
+ match result {
+ Frame::Array(items) => {
+ assert_eq!(items.len(), 1);
+ assert_eq!(items[0], Frame::Integer(0));
+ }
+ other => panic!("expected Array, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_ft_search_dimension_mismatch() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ ft_create(&mut store, &args);
+
+ // Build a query with wrong dimension (4 bytes instead of 128*4)
+ let search_args = vec![
+ bulk(b"myidx"),
+ bulk(b"*=>[KNN 10 @vec $query]"),
+ bulk(b"PARAMS"),
+ bulk(b"2"),
+ bulk(b"query"),
+ bulk(b"tooshort"),
+ ];
+ let result = ft_search(&mut store, &search_args);
+ match &result {
+ Frame::Error(e) => assert!(
+ e.starts_with(b"ERR query vector dimension"),
+ "expected dimension mismatch error, got {:?}",
+ std::str::from_utf8(e)
+ ),
+ other => panic!("expected error, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_ft_search_empty_index() {
+    let mut store = VectorStore::new();
+    let args = ft_create_args();
+    ft_create(&mut store, &args);
+
+    // Build valid query for dim=128
+    let query_vec: Vec<u8> = vec![0u8; 128 * 4]; // 128 floats, all zero
+    let search_args = vec![
+        bulk(b"myidx"),
+        bulk(b"*=>[KNN 5 @vec $query]"),
+        bulk(b"PARAMS"),
+        bulk(b"2"),
+        bulk(b"query"),
+        Frame::BulkString(Bytes::from(query_vec)),
+    ];
+    crate::vector::distance::init();
+    let result = ft_search(&mut store, &search_args);
+    match result {
+        Frame::Array(items) => {
+            assert_eq!(items[0], Frame::Integer(0)); // no results
+        }
+        other => panic!("expected Array, got {other:?}"),
+    }
+}
+
+#[test]
+fn test_ft_info() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ ft_create(&mut store, &args);
+
+ let result = ft_info(&store, &[bulk(b"myidx")]);
+ match result {
+ Frame::Array(items) => {
+ // Should have 20 items (10 key-value pairs)
+ assert!(
+ items.len() >= 20,
+ "FT.INFO should return at least 20 items, got {}",
+ items.len()
+ );
+ assert_eq!(
+ items[0],
+ Frame::BulkString(Bytes::from_static(b"index_name"))
+ );
+ assert_eq!(items[1], Frame::BulkString(Bytes::from("myidx")));
+ assert_eq!(items[5], Frame::Integer(0)); // num_docs = 0
+ assert_eq!(items[7], Frame::Integer(128)); // dimension
+ // New fields
+ assert_eq!(items[10], Frame::BulkString(Bytes::from_static(b"M")));
+ assert_eq!(items[11], Frame::Integer(16)); // default M
+ assert_eq!(
+ items[14],
+ Frame::BulkString(Bytes::from_static(b"EF_RUNTIME"))
+ );
+ }
+ other => panic!("expected Array, got {other:?}"),
+ }
+
+ // Non-existing index
+ let result = ft_info(&store, &[bulk(b"nonexistent")]);
+ assert!(matches!(result, Frame::Error(_)));
+}
+
+/// Build FT.CREATE args for one HNSW FLOAT32 vector field with the given name, prefix, DIM and metric.
+fn build_ft_create_args(
+    name: &str,
+    prefix: &str,
+    field: &str,
+    dim: u32,
+    metric: &str,
+) -> Vec<Frame> {
+    vec![
+        Frame::BulkString(Bytes::from(name.to_owned())),
+        Frame::BulkString(Bytes::from_static(b"ON")),
+        Frame::BulkString(Bytes::from_static(b"HASH")),
+        Frame::BulkString(Bytes::from_static(b"PREFIX")),
+        Frame::BulkString(Bytes::from_static(b"1")),
+        Frame::BulkString(Bytes::from(prefix.to_owned())),
+        Frame::BulkString(Bytes::from_static(b"SCHEMA")),
+        Frame::BulkString(Bytes::from(field.to_owned())),
+        Frame::BulkString(Bytes::from_static(b"VECTOR")),
+        Frame::BulkString(Bytes::from_static(b"HNSW")),
+        Frame::BulkString(Bytes::from_static(b"6")),
+        Frame::BulkString(Bytes::from_static(b"TYPE")),
+        Frame::BulkString(Bytes::from_static(b"FLOAT32")),
+        Frame::BulkString(Bytes::from_static(b"DIM")),
+        Frame::BulkString(Bytes::from(dim.to_string())),
+        Frame::BulkString(Bytes::from_static(b"DISTANCE_METRIC")),
+        Frame::BulkString(Bytes::from(metric.to_owned())),
+    ]
+}
+
+#[test]
+fn test_end_to_end_create_insert_search() {
+    // Initialize distance functions (required before any search)
+    crate::vector::distance::init();
+
+    let mut store = VectorStore::new();
+    let dim: usize = 4;
+
+    // 1. FT.CREATE
+    let create_args = build_ft_create_args("e2eidx", "doc:", "embedding", dim as u32, "L2");
+    let result = ft_create(&mut store, &create_args);
+    assert!(
+        matches!(result, Frame::SimpleString(_)),
+        "FT.CREATE should return OK, got {result:?}"
+    );
+
+    // 2. Insert vectors directly into the mutable segment
+    let idx = store.get_index_mut(b"e2eidx").unwrap();
+    let vectors: Vec<[f32; 4]> = vec![
+        [1.0, 0.0, 0.0, 0.0],  // vec:0 -- exact match for query (L2=0)
+        [-1.0, 0.0, 0.0, 0.0], // vec:1 -- opposite direction (L2=4.0)
+        [0.5, 0.0, 0.0, 0.0],  // vec:2 -- same direction, half magnitude (L2=0.25)
+    ];
+
+    let snap = idx.segments.load();
+    for (i, v) in vectors.iter().enumerate() {
+        let mut sq = vec![0i8; dim];
+        quantize_f32_to_sq(v, &mut sq);
+        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
+        snap.mutable.append(i as u64, v, &sq, norm, i as u64);
+    }
+    drop(snap);
+
+    // 3. FT.SEARCH for vector close to [1.0, 0.0, 0.0, 0.0]
+    let query_vec: [f32; 4] = [1.0, 0.0, 0.0, 0.0];
+    let query_blob: Vec<u8> = query_vec.iter().flat_map(|f| f.to_le_bytes()).collect();
+
+    let search_args = vec![
+        Frame::BulkString(Bytes::from_static(b"e2eidx")),
+        Frame::BulkString(Bytes::from_static(b"*=>[KNN 2 @embedding $query]")),
+        Frame::BulkString(Bytes::from_static(b"PARAMS")),
+        Frame::BulkString(Bytes::from_static(b"2")),
+        Frame::BulkString(Bytes::from_static(b"query")),
+        Frame::BulkString(Bytes::from(query_blob)),
+    ];
+
+    let result = ft_search(&mut store, &search_args);
+    match &result {
+        Frame::Array(items) => {
+            // First element is count
+            assert!(
+                matches!(&items[0], Frame::Integer(n) if *n >= 1),
+                "Should find at least 1 result, got {result:?}"
+            );
+            // vec:0 should be in top-2 results (at dim=4, TQ-4bit quantization
+            // noise can swap rankings of very close vectors in Light mode)
+            let mut found_vec0 = false;
+            for idx in [1, 3].iter() {
+                if let Some(Frame::BulkString(doc_id)) = items.get(*idx) {
+                    if doc_id.as_ref() == b"vec:0" {
+                        found_vec0 = true;
+                    }
+                }
+            }
+            assert!(
+                found_vec0,
+                "vec:0 should be in top-2 results, got {result:?}"
+            );
+            // vec:2 should be in top-2 (at dim=4, TQ noise may reorder)
+            let mut found_vec2 = false;
+            for idx in [1, 3].iter() {
+                if let Some(Frame::BulkString(doc_id)) = items.get(*idx) {
+                    if doc_id.as_ref() == b"vec:2" {
+                        found_vec2 = true;
+                    }
+                }
+            }
+            assert!(
+                found_vec2,
+                "vec:2 should be in top-2 results, got {result:?}"
+            );
+        }
+        Frame::Error(e) => panic!("FT.SEARCH returned error: {:?}", std::str::from_utf8(e)),
+        _ => panic!("FT.SEARCH should return Array, got {result:?}"),
+    }
+}
+
+#[test]
+fn test_ft_info_returns_correct_data() {
+ let mut store = VectorStore::new();
+ let args = build_ft_create_args("testidx", "test:", "vec", 128, "COSINE");
+ ft_create(&mut store, &args);
+
+ let info_args = [Frame::BulkString(Bytes::from_static(b"testidx"))];
+ let result = ft_info(&store, &info_args);
+ match result {
+ Frame::Array(items) => {
+ assert!(items.len() >= 6, "FT.INFO should return at least 6 items");
+ // Check dimension
+ let mut found_dim = false;
+ for pair in items.chunks(2) {
+ if let Frame::BulkString(key) = &pair[0] {
+ if key.as_ref() == b"dimension" {
+ if let Frame::Integer(d) = &pair[1] {
+ assert_eq!(*d, 128);
+ found_dim = true;
+ }
+ }
+ }
+ }
+ assert!(found_dim, "FT.INFO should return dimension");
+ }
+ other => panic!("FT.INFO should return Array, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_ft_search_unknown_index() {
+ let mut store = VectorStore::new();
+ let args = [
+ Frame::BulkString(Bytes::from_static(b"nonexistent")),
+ Frame::BulkString(Bytes::from_static(b"*=>[KNN 5 @vec $query]")),
+ Frame::BulkString(Bytes::from_static(b"PARAMS")),
+ Frame::BulkString(Bytes::from_static(b"2")),
+ Frame::BulkString(Bytes::from_static(b"query")),
+ Frame::BulkString(Bytes::from(vec![0u8; 16])),
+ ];
+ let result = ft_search(&mut store, &args);
+ assert!(
+ matches!(result, Frame::Error(_)),
+ "Should error on unknown index, got {result:?}"
+ );
+}
+
+#[test]
+fn test_parse_filter_clause_tag() {
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 10 @vec $q]"),
+ bulk(b"FILTER"),
+ bulk(b"@category:{electronics}"),
+ bulk(b"PARAMS"),
+ bulk(b"2"),
+ bulk(b"q"),
+ bulk(b"blob"),
+ ];
+ let filter = parse_filter_clause(&args);
+ assert!(filter.is_some(), "should parse @category:{{electronics}}");
+ match filter.unwrap() {
+ crate::vector::filter::FilterExpr::TagEq { field, value } => {
+ assert_eq!(&field[..], b"category");
+ assert_eq!(&value[..], b"electronics");
+ }
+ other => panic!("expected TagEq, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_parse_filter_clause_numeric_range() {
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 5 @vec $q]"),
+ bulk(b"FILTER"),
+ bulk(b"@price:[10 100]"),
+ bulk(b"PARAMS"),
+ bulk(b"2"),
+ bulk(b"q"),
+ bulk(b"blob"),
+ ];
+ let filter = parse_filter_clause(&args);
+ assert!(filter.is_some());
+ match filter.unwrap() {
+ crate::vector::filter::FilterExpr::NumRange { field, min, max } => {
+ assert_eq!(&field[..], b"price");
+ assert_eq!(*min, 10.0);
+ assert_eq!(*max, 100.0);
+ }
+ other => panic!("expected NumRange, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_parse_filter_clause_numeric_eq() {
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 5 @vec $q]"),
+ bulk(b"FILTER"),
+ bulk(b"@price:[50 50]"),
+ ];
+ let filter = parse_filter_clause(&args);
+ assert!(filter.is_some());
+ match filter.unwrap() {
+ crate::vector::filter::FilterExpr::NumEq { field, value } => {
+ assert_eq!(&field[..], b"price");
+ assert_eq!(*value, 50.0);
+ }
+ other => panic!("expected NumEq, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_parse_filter_clause_compound() {
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 5 @vec $q]"),
+ bulk(b"FILTER"),
+ bulk(b"@a:{x} @b:[1 10]"),
+ ];
+ let filter = parse_filter_clause(&args);
+ assert!(filter.is_some());
+ match filter.unwrap() {
+ crate::vector::filter::FilterExpr::And(left, right) => {
+ assert!(matches!(
+ *left,
+ crate::vector::filter::FilterExpr::TagEq { .. }
+ ));
+ assert!(matches!(
+ *right,
+ crate::vector::filter::FilterExpr::NumRange { .. }
+ ));
+ }
+ other => panic!("expected And, got {other:?}"),
+ }
+}
+
+#[test]
+fn test_parse_filter_clause_none() {
+ // No FILTER keyword
+ let args = vec![
+ bulk(b"idx"),
+ bulk(b"*=>[KNN 10 @vec $q]"),
+ bulk(b"PARAMS"),
+ bulk(b"2"),
+ bulk(b"q"),
+ bulk(b"blob"),
+ ];
+ let filter = parse_filter_clause(&args);
+ assert!(filter.is_none());
+}
+
+#[test]
+fn test_ft_search_with_filter_no_regression() {
+    // Unfiltered FT.SEARCH still works identically
+    crate::vector::distance::init();
+    let mut store = VectorStore::new();
+    let args = ft_create_args();
+    ft_create(&mut store, &args);
+
+    let query_vec: Vec<u8> = vec![0u8; 128 * 4];
+    let search_args = vec![
+        bulk(b"myidx"),
+        bulk(b"*=>[KNN 5 @vec $query]"),
+        bulk(b"PARAMS"),
+        bulk(b"2"),
+        bulk(b"query"),
+        Frame::BulkString(Bytes::from(query_vec)),
+    ];
+    let result = ft_search(&mut store, &search_args);
+    match result {
+        Frame::Array(items) => {
+            assert_eq!(items[0], Frame::Integer(0));
+        }
+        other => panic!("expected Array, got {other:?}"),
+    }
+}
+
+#[test]
+fn test_vector_index_has_payload_index() {
+ let mut store = VectorStore::new();
+ let args = ft_create_args();
+ ft_create(&mut store, &args);
+ let idx = store.get_index(b"myidx").unwrap();
+ // payload_index should exist -- insert and evaluate should work
+ let _ = &idx.payload_index;
+}
+
+#[test]
+fn test_vector_metrics_increment_decrement() {
+    use std::sync::atomic::Ordering;
+    // Tolerate poisoning: a failed holder must not cascade into unrelated PoisonError panics.
+    let _guard = METRICS_LOCK.lock().unwrap_or_else(|e| e.into_inner());
+
+    let mut store = VectorStore::new();
+    let args = ft_create_args();
+
+    // FT.CREATE should increment VECTOR_INDEXES
+    let before_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
+    ft_create(&mut store, &args);
+    let after_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
+    assert!(
+        after_create > before_create,
+        "FT.CREATE should increment VECTOR_INDEXES"
+    );
+
+    // FT.SEARCH should increment VECTOR_SEARCH_TOTAL
+    crate::vector::distance::init();
+    let before_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed);
+    let query_vec: Vec<u8> = vec![0u8; 128 * 4];
+    let search_args = vec![
+        bulk(b"myidx"),
+        bulk(b"*=>[KNN 5 @vec $query]"),
+        bulk(b"PARAMS"),
+        bulk(b"2"),
+        bulk(b"query"),
+        Frame::BulkString(Bytes::from(query_vec)),
+    ];
+    ft_search(&mut store, &search_args);
+    let after_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed);
+    assert!(
+        after_search > before_search,
+        "FT.SEARCH should increment VECTOR_SEARCH_TOTAL"
+    );
+
+    // FT.DROPINDEX should decrement VECTOR_INDEXES
+    let before_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
+    ft_dropindex(&mut store, &[bulk(b"myidx")]);
+    let after_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed);
+    assert!(
+        after_drop < before_drop,
+        "FT.DROPINDEX should decrement VECTOR_INDEXES"
+    );
+}
diff --git a/src/server/conn/handler_monoio.rs b/src/server/conn/handler_monoio.rs
index b13e945d..220eb9a9 100644
--- a/src/server/conn/handler_monoio.rs
+++ b/src/server/conn/handler_monoio.rs
@@ -1414,62 +1414,92 @@ pub async fn handle_connection_sharded_monoio<
// Local shard: direct VectorStore access via shard_databases.
// Remote shards: SPSC dispatch. Works with any shard count (including 1).
if cmd.len() > 3 && cmd[..3].eq_ignore_ascii_case(b"FT.") {
- if cmd.eq_ignore_ascii_case(b"FT.SEARCH") {
- let response =
- match crate::command::vector_search::parse_ft_search_args(cmd_args) {
- Ok((index_name, query_blob, k, _filter)) => {
- crate::shard::coordinator::scatter_vector_search_remote(
- index_name,
- query_blob,
- k,
- shard_id,
- num_shards,
- &shard_databases,
- &dispatch_tx,
- &spsc_notifiers,
- )
- .await
- }
- Err(err_frame) => err_frame,
+ if num_shards > 1 {
+ // Multi-shard: dispatch via SPSC
+ if cmd.eq_ignore_ascii_case(b"FT.SEARCH") {
+ let response =
+ match crate::command::vector_search::parse_ft_search_args(cmd_args) {
+ Ok((index_name, query_blob, k, filter)) => {
+ if filter.is_some() {
+ Frame::Error(Bytes::from_static(
+ b"ERR FILTER not supported in multi-shard mode yet",
+ ))
+ } else {
+ crate::shard::coordinator::scatter_vector_search_remote(
+ index_name,
+ query_blob,
+ k,
+ shard_id,
+ num_shards,
+ &shard_databases,
+ &dispatch_tx,
+ &spsc_notifiers,
+ )
+ .await
+ }
+ }
+ Err(err_frame) => err_frame,
+ };
+ responses.push(response);
+ continue;
+ }
+ if cmd.eq_ignore_ascii_case(b"FT.CREATE")
+ || cmd.eq_ignore_ascii_case(b"FT.DROPINDEX")
+ {
+ // Broadcast to ALL shards so every shard has the index
+ let response = crate::shard::coordinator::broadcast_vector_command(
+ std::sync::Arc::new(frame),
+ shard_id,
+ num_shards,
+ &shard_databases,
+ &dispatch_tx,
+ &spsc_notifiers,
+ )
+ .await;
+ responses.push(response);
+ continue;
+ }
+ if cmd.eq_ignore_ascii_case(b"FT.INFO") {
+ let response = {
+ let vs = shard_databases.vector_store(shard_id);
+ crate::command::vector_search::ft_info(&vs, cmd_args)
};
- responses.push(response);
- continue;
- }
- if cmd.eq_ignore_ascii_case(b"FT.CREATE")
- || cmd.eq_ignore_ascii_case(b"FT.DROPINDEX")
- {
- // Broadcast to ALL shards so every shard has the index
- let response = crate::shard::coordinator::broadcast_vector_command(
- std::sync::Arc::new(frame),
- shard_id,
- num_shards,
- &shard_databases,
- &dispatch_tx,
- &spsc_notifiers,
- )
- .await;
- responses.push(response);
- continue;
- }
- if cmd.eq_ignore_ascii_case(b"FT.INFO") {
- // Read-only: local shard is sufficient
- let response = {
- let vs = shard_databases.vector_store(shard_id);
- crate::command::vector_search::ft_info(&vs, cmd_args)
- };
- responses.push(response);
+ responses.push(response);
+ continue;
+ }
+ if cmd.eq_ignore_ascii_case(b"FT.COMPACT") {
+ let response = {
+ let mut vs = shard_databases.vector_store(shard_id);
+ crate::command::vector_search::ft_compact(&mut vs, cmd_args)
+ };
+ responses.push(response);
+ continue;
+ }
+ responses.push(Frame::Error(Bytes::from_static(b"ERR unknown FT command")));
continue;
- }
- if cmd.eq_ignore_ascii_case(b"FT.COMPACT") {
+ } else {
+ // Single-shard: no SPSC channels needed.
+ // Dispatch directly to shard's VectorStore via shared access.
let response = {
- let mut vs = shard_databases.vector_store(shard_id);
- crate::command::vector_search::ft_compact(&mut vs, cmd_args)
+ let shard_databases_ref = &shard_databases;
+ let mut vs = shard_databases_ref.vector_store(shard_id);
+ if cmd.eq_ignore_ascii_case(b"FT.CREATE") {
+ crate::command::vector_search::ft_create(&mut vs, cmd_args)
+ } else if cmd.eq_ignore_ascii_case(b"FT.SEARCH") {
+ crate::command::vector_search::ft_search(&mut vs, cmd_args)
+ } else if cmd.eq_ignore_ascii_case(b"FT.DROPINDEX") {
+ crate::command::vector_search::ft_dropindex(&mut vs, cmd_args)
+ } else if cmd.eq_ignore_ascii_case(b"FT.INFO") {
+ crate::command::vector_search::ft_info(&vs, cmd_args)
+ } else if cmd.eq_ignore_ascii_case(b"FT.COMPACT") {
+ crate::command::vector_search::ft_compact(&mut vs, cmd_args)
+ } else {
+ Frame::Error(Bytes::from_static(b"ERR unknown FT.* command"))
+ }
};
responses.push(response);
continue;
}
- responses.push(Frame::Error(Bytes::from_static(b"ERR unknown FT command")));
- continue;
}
// --- Routing: keyless, local, or remote ---
diff --git a/src/server/conn/handler_sharded.rs b/src/server/conn/handler_sharded.rs
index 68e7e1c8..404643f8 100644
--- a/src/server/conn/handler_sharded.rs
+++ b/src/server/conn/handler_sharded.rs
@@ -1258,13 +1258,19 @@ pub async fn handle_connection_sharded_inner<
// Multi-shard: dispatch via SPSC
if cmd.eq_ignore_ascii_case(b"FT.SEARCH") {
let response = match crate::command::vector_search::parse_ft_search_args(cmd_args) {
- Ok((index_name, query_blob, k, _filter)) => {
- crate::shard::coordinator::scatter_vector_search_remote(
- index_name, query_blob, k,
- shard_id, num_shards,
- &shard_databases,
- &dispatch_tx, &spsc_notifiers,
- ).await
+ Ok((index_name, query_blob, k, filter)) => {
+ if filter.is_some() {
+ Frame::Error(Bytes::from_static(
+ b"ERR FILTER not supported in multi-shard mode yet",
+ ))
+ } else {
+ crate::shard::coordinator::scatter_vector_search_remote(
+ index_name, query_blob, k,
+ shard_id, num_shards,
+ &shard_databases,
+ &dispatch_tx, &spsc_notifiers,
+ ).await
+ }
}
Err(err_frame) => err_frame,
};
@@ -1359,24 +1365,6 @@ pub async fn handle_connection_sharded_inner<
DispatchResult::Response(f) => f,
DispatchResult::Quit(f) => { should_quit = true; f }
};
- // Auto-index vectors on successful HSET (local write path)
- if !matches!(response, Frame::Error(_))
- && (cmd.eq_ignore_ascii_case(b"HSET") || cmd.eq_ignore_ascii_case(b"HMSET"))
- {
- if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) {
- let mut vs = shard_databases.vector_store(shard_id);
- crate::shard::spsc_handler::auto_index_hset_public(&mut vs, &key, cmd_args);
- }
- }
- // Auto-delete vectors on DEL/HDEL/UNLINK (local write path)
- if !matches!(response, Frame::Error(_))
- && (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK") || cmd.eq_ignore_ascii_case(b"HDEL"))
- {
- if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) {
- let mut vs = shard_databases.vector_store(shard_id);
- vs.mark_deleted_for_key(&key);
- }
- }
if !matches!(response, Frame::Error(_)) {
let needs_wake = cmd.eq_ignore_ascii_case(b"LPUSH") || cmd.eq_ignore_ascii_case(b"RPUSH")
|| cmd.eq_ignore_ascii_case(b"LMOVE") || cmd.eq_ignore_ascii_case(b"ZADD");
@@ -1392,6 +1380,30 @@ pub async fn handle_connection_sharded_inner<
}
}
drop(guard);
+ // Auto-index vectors on successful HSET (local write path)
+ // Placed AFTER drop(guard) to avoid DB→vector_store lock order
+ // inversion with the shard event loop (vector_store→DB).
+ if !matches!(response, Frame::Error(_))
+ && (cmd.eq_ignore_ascii_case(b"HSET") || cmd.eq_ignore_ascii_case(b"HMSET"))
+ {
+ if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) {
+ let mut vs = shard_databases.vector_store(shard_id);
+ crate::shard::spsc_handler::auto_index_hset_public(&mut vs, &key, cmd_args);
+ }
+ }
+ // Auto-delete vectors on DEL/UNLINK (local write path)
+ // Note: HDEL removes fields, not keys — it should NOT trigger
+ // vector deletion unless the entire key is removed.
+ if !matches!(response, Frame::Error(_))
+ && (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK"))
+ {
+ let mut vs = shard_databases.vector_store(shard_id);
+ for arg in cmd_args.iter() {
+ if let Some(key) = extract_bytes(arg) {
+ vs.mark_deleted_for_key(key.as_ref());
+ }
+ }
+ }
if let Some(bytes) = aof_bytes {
if !matches!(response, Frame::Error(_)) {
if let Some(ref tx) = aof_tx { let _ = tx.try_send(AofMessage::Append(bytes)); }
diff --git a/src/shard/spsc_handler.rs b/src/shard/spsc_handler.rs
index a9877212..4e277ded 100644
--- a/src/shard/spsc_handler.rs
+++ b/src/shard/spsc_handler.rs
@@ -272,26 +272,17 @@ pub(crate) fn handle_shard_message_shared(
}
}
- // Auto-delete: if DEL/HDEL/UNLINK succeeded and key matches a vector
+ // Auto-delete: if DEL/UNLINK succeeded and key matches a vector
// index prefix, mark stale vectors as deleted in matching indexes.
- if (cmd.eq_ignore_ascii_case(b"DEL")
- || cmd.eq_ignore_ascii_case(b"HDEL")
- || cmd.eq_ignore_ascii_case(b"UNLINK"))
+ // Note: HDEL removes fields, not keys — it should NOT trigger vector
+ // deletion unless the entire key is removed.
+ if (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK"))
&& !matches!(frame, crate::protocol::Frame::Error(_))
{
- // DEL/UNLINK: args are keys (args[0], args[1], ...).
- // HDEL: args[0] is the hash key, remaining are fields.
- // For HDEL we only mark the hash key itself (the vector source).
- if cmd.eq_ignore_ascii_case(b"HDEL") {
- if let Some(crate::protocol::Frame::BulkString(key_bytes)) = args.first() {
+ for arg in args {
+ if let crate::protocol::Frame::BulkString(key_bytes) = arg {
vector_store.mark_deleted_for_key(key_bytes);
}
- } else {
- for arg in args {
- if let crate::protocol::Frame::BulkString(key_bytes) = arg {
- vector_store.mark_deleted_for_key(key_bytes);
- }
- }
}
}
diff --git a/src/vector/distance/fastscan.rs b/src/vector/distance/fastscan.rs
index 6b30d9dc..bffd7ca6 100644
--- a/src/vector/distance/fastscan.rs
+++ b/src/vector/distance/fastscan.rs
@@ -48,12 +48,16 @@ pub fn init_fastscan() {
/// Get the static FastScan dispatch table.
///
-/// # Safety contract
-/// Caller must ensure [`init_fastscan()`] has been called before first use.
+/// Auto-initializes on first use if [`init_fastscan()`] was not called explicitly.
+/// After the first call the hot path is two atomic loads (both always succeed).
#[inline(always)]
pub fn fastscan_dispatch() -> &'static FastScanDispatch {
- // SAFETY: init_fastscan() is called from distance::init() at startup.
- unsafe { FASTSCAN_DISPATCH.get().unwrap_unchecked() }
+ if FASTSCAN_DISPATCH.get().is_none() {
+ init_fastscan();
+ }
+ FASTSCAN_DISPATCH
+ .get()
+ .expect("fastscan dispatch initialized by init_fastscan()")
}
/// Scalar FastScan: compute distances for 32 vectors in one interleaved block.
diff --git a/src/vector/distance/mod.rs b/src/vector/distance/mod.rs
index 2e7e9c4f..6a5fe569 100644
--- a/src/vector/distance/mod.rs
+++ b/src/vector/distance/mod.rs
@@ -141,20 +141,17 @@ pub fn init() {
/// Returns the table initialized by [`init()`]. This is a single pointer load
/// followed by a direct function call — at most 1 cache miss per call site.
///
-/// # Safety contract
-/// Caller must ensure [`init()`] has been called before the first call to `table()`.
-/// In practice, `init()` is called from `main()` at startup.
+/// Auto-initializes on first use if [`init()`] was not called explicitly.
+/// After the first call the hot path is two atomic loads (both always succeed).
#[inline(always)]
pub fn table() -> &'static DistanceTable {
- // SAFETY: init() is called from main() at startup before any search operation.
- // The OnceLock is guaranteed to be initialized by the time any search
- // path reaches this function. Using unwrap_unchecked avoids a branch
- // on the hot path.
- debug_assert!(
- DISTANCE_TABLE.get().is_some(),
- "distance::init() was not called before table()"
- );
- unsafe { DISTANCE_TABLE.get().unwrap_unchecked() }
+ if DISTANCE_TABLE.get().is_none() {
+ init();
+ }
+ // After init(), DISTANCE_TABLE is guaranteed to be set.
+ DISTANCE_TABLE
+ .get()
+ .expect("distance table initialized by init()")
}
#[cfg(test)]
diff --git a/src/vector/hnsw/search.rs b/src/vector/hnsw/search.rs
index 796eae53..5f69eca2 100644
--- a/src/vector/hnsw/search.rs
+++ b/src/vector/hnsw/search.rs
@@ -281,13 +281,13 @@ pub fn hnsw_search_filtered(
let original_dim = query.len();
let padded_dim = q_rotated.len();
let _active_code_bytes = original_dim / 2; // nibble-packed bytes for original dim
- let entries_per_coord: usize = if use_subcent { 32 } else { 16 };
-
let sub_table = collection.sub_centroid_table.as_ref();
+ // Guard use_subcent on sub_table availability to avoid panic
+ let use_subcent = use_subcent && sub_table.is_some();
+ let entries_per_coord: usize = if use_subcent { 32 } else { 16 };
let mut adc_lut = Vec::with_capacity(padded_dim * entries_per_coord);
- if use_subcent {
- let st = sub_table.unwrap();
+ if let Some(st) = sub_table.filter(|_| use_subcent) {
for j in 0..padded_dim {
let q = q_rotated[j];
for e in 0..32 {
diff --git a/src/vector/turbo_quant/codebook.rs b/src/vector/turbo_quant/codebook.rs
index a1b4173c..c4690c89 100644
--- a/src/vector/turbo_quant/codebook.rs
+++ b/src/vector/turbo_quant/codebook.rs
@@ -128,32 +128,36 @@ pub const RAW_BOUNDARIES_3BIT: [f32; 7] = [-1.7480, -1.0500, -0.5006, 0.0, 0.500
///
/// Returns a Vec because the size varies by bit width.
/// sigma = 1/sqrt(padded_dim), matching FWHT normalization.
-pub fn scaled_centroids_n(padded_dim: u32, bits: u8) -> Vec<f32> {
+///
+/// Returns `Err` for unsupported bit widths (anything outside 1-4).
+pub fn scaled_centroids_n(padded_dim: u32, bits: u8) -> Result<Vec<f32>, &'static str> {
let sigma = 1.0 / (padded_dim as f32).sqrt();
match bits {
- 1 => RAW_CENTROIDS_1BIT.iter().map(|&c| c * sigma).collect(),
- 2 => RAW_CENTROIDS_2BIT.iter().map(|&c| c * sigma).collect(),
- 3 => RAW_CENTROIDS_3BIT.iter().map(|&c| c * sigma).collect(),
+ 1 => Ok(RAW_CENTROIDS_1BIT.iter().map(|&c| c * sigma).collect()),
+ 2 => Ok(RAW_CENTROIDS_2BIT.iter().map(|&c| c * sigma).collect()),
+ 3 => Ok(RAW_CENTROIDS_3BIT.iter().map(|&c| c * sigma).collect()),
4 => {
let sc = scaled_centroids(padded_dim);
- sc.to_vec()
+ Ok(sc.to_vec())
}
- _ => panic!("unsupported bit width: {bits}"),
+ _ => Err("unsupported bit width"),
}
}
/// Compute dimension-scaled boundaries for any bit width (1-4).
-pub fn scaled_boundaries_n(padded_dim: u32, bits: u8) -> Vec<f32> {
+///
+/// Returns `Err` for unsupported bit widths (anything outside 1-4).
+pub fn scaled_boundaries_n(padded_dim: u32, bits: u8) -> Result<Vec<f32>, &'static str> {
let sigma = 1.0 / (padded_dim as f32).sqrt();
match bits {
- 1 => RAW_BOUNDARIES_1BIT.iter().map(|&b| b * sigma).collect(),
- 2 => RAW_BOUNDARIES_2BIT.iter().map(|&b| b * sigma).collect(),
- 3 => RAW_BOUNDARIES_3BIT.iter().map(|&b| b * sigma).collect(),
+ 1 => Ok(RAW_BOUNDARIES_1BIT.iter().map(|&b| b * sigma).collect()),
+ 2 => Ok(RAW_BOUNDARIES_2BIT.iter().map(|&b| b * sigma).collect()),
+ 3 => Ok(RAW_BOUNDARIES_3BIT.iter().map(|&b| b * sigma).collect()),
4 => {
let sb = scaled_boundaries(padded_dim);
- sb.to_vec()
+ Ok(sb.to_vec())
}
- _ => panic!("unsupported bit width: {bits}"),
+ _ => Err("unsupported bit width"),
}
}
@@ -186,7 +190,10 @@ pub fn code_bytes_per_vector(padded_dim: u32, bits: u8) -> usize {
2 => pd / 4,
3 => (pd * 3 + 7) / 8,
4 => pd / 2,
- _ => panic!("unsupported bit width: {bits}"),
+ _ => {
+ tracing::error!("unsupported bit width {bits} for code_bytes_per_vector");
+ 0
+ }
}
}
@@ -377,19 +384,20 @@ mod tests {
#[test]
fn test_scaled_centroids_n_sizes() {
let pdim = 1024u32;
- assert_eq!(scaled_centroids_n(pdim, 1).len(), 2);
- assert_eq!(scaled_centroids_n(pdim, 2).len(), 4);
- assert_eq!(scaled_centroids_n(pdim, 3).len(), 8);
- assert_eq!(scaled_centroids_n(pdim, 4).len(), 16);
+ assert_eq!(scaled_centroids_n(pdim, 1).unwrap().len(), 2);
+ assert_eq!(scaled_centroids_n(pdim, 2).unwrap().len(), 4);
+ assert_eq!(scaled_centroids_n(pdim, 3).unwrap().len(), 8);
+ assert_eq!(scaled_centroids_n(pdim, 4).unwrap().len(), 16);
+ assert!(scaled_centroids_n(pdim, 5).is_err());
}
#[test]
fn test_scaled_centroids_n_values() {
let pdim = 1024u32;
let sigma = 1.0 / (pdim as f32).sqrt();
- let c1 = scaled_centroids_n(pdim, 1);
+ let c1 = scaled_centroids_n(pdim, 1).unwrap();
assert!((c1[1] - 0.7979 * sigma).abs() < 1e-6);
- let c2 = scaled_centroids_n(pdim, 2);
+ let c2 = scaled_centroids_n(pdim, 2).unwrap();
assert!((c2[3] - 1.5104 * sigma).abs() < 1e-5);
}
diff --git a/src/vector/turbo_quant/collection.rs b/src/vector/turbo_quant/collection.rs
index 39450c84..f1c91e58 100644
--- a/src/vector/turbo_quant/collection.rs
+++ b/src/vector/turbo_quant/collection.rs
@@ -218,13 +218,17 @@ impl CollectionMetadata {
fwht_sign_flips: sign_flips,
codebook_version: CODEBOOK_VERSION,
codebook: if quantization.is_turbo_quant() {
+ // Fail fast on invalid bit width — this is a programming invariant,
+ // not user input. Valid bit widths (1-4) are guaranteed by QuantizationConfig.
scaled_centroids_n(padded, quantization.bits())
+ .expect("codebook centroids: invalid bit width is a programming bug")
} else {
// SQ8 doesn't use codebooks -- store empty Vec
Vec::new()
},
codebook_boundaries: if quantization.is_turbo_quant() {
scaled_boundaries_n(padded, quantization.bits())
+ .expect("codebook boundaries: invalid bit width is a programming bug")
} else {
Vec::new()
},
@@ -276,32 +280,56 @@ impl CollectionMetadata {
code_bytes_per_vector(self.padded_dimension, self.quantization.bits())
}
- /// Convenience accessor: returns the codebook boundaries as a `&[f32; 15]` reference.
+ /// Returns the codebook boundaries as a `&[f32; 15]` reference.
///
- /// Panics if quantization is not 4-bit (only valid for TurboQuant4 / TurboQuantProd4).
- /// Used by legacy `encode_tq_mse_scaled` which requires fixed-size array.
+ /// Only valid for 4-bit quantization (TurboQuant4 / TurboQuantProd4).
+ /// The codebook is guaranteed to have exactly 15 boundaries at construction
+ /// for 4-bit configs. If the invariant is violated (programming bug), logs
+ /// an error and returns a zeroed fallback to avoid panicking in production.
pub fn codebook_boundaries_15(&self) -> &[f32; 15] {
- assert_eq!(
+ debug_assert_eq!(
self.codebook_boundaries.len(),
15,
- "codebook_boundaries_15 requires 4-bit quantization (15 boundaries), got {}",
+ "codebook_boundaries_15 called on non-4-bit quantization (len={})",
self.codebook_boundaries.len()
);
- self.codebook_boundaries[..15].try_into().unwrap()
+ match self.codebook_boundaries.as_slice().try_into() {
+ Ok(arr) => arr,
+ Err(_) => {
+ tracing::error!(
+ "codebook_boundaries has {} entries, expected 15 — construction invariant violated",
+ self.codebook_boundaries.len()
+ );
+ static ZERO: [f32; 15] = [0.0; 15];
+ &ZERO
+ }
+ }
}
- /// Convenience accessor: returns the codebook as a `&[f32; 16]` reference.
+ /// Returns the codebook as a `&[f32; 16]` reference.
///
- /// Panics if quantization is not 4-bit (only valid for TurboQuant4 / TurboQuantProd4).
- /// Used by legacy `tq_l2_adc_scaled` which requires fixed-size array.
+ /// Only valid for 4-bit quantization (TurboQuant4 / TurboQuantProd4).
+ /// The codebook is guaranteed to have exactly 16 centroids at construction
+ /// for 4-bit configs. If the invariant is violated (programming bug), logs
+ /// an error and returns a zeroed fallback to avoid panicking in production.
pub fn codebook_16(&self) -> &[f32; 16] {
- assert_eq!(
+ debug_assert_eq!(
self.codebook.len(),
16,
- "codebook_16 requires 4-bit quantization (16 centroids), got {}",
+ "codebook_16 called on non-4-bit quantization (len={})",
self.codebook.len()
);
- self.codebook[..16].try_into().unwrap()
+ match self.codebook.as_slice().try_into() {
+ Ok(arr) => arr,
+ Err(_) => {
+ tracing::error!(
+ "codebook has {} entries, expected 16 — construction invariant violated",
+ self.codebook.len()
+ );
+ static ZERO: [f32; 16] = [0.0; 16];
+ &ZERO
+ }
+ }
}
/// Verify metadata integrity. Returns Err if checksum mismatch.
diff --git a/src/vector/turbo_quant/encoder.rs b/src/vector/turbo_quant/encoder.rs
index 1d982fa4..69bdbe4d 100644
--- a/src/vector/turbo_quant/encoder.rs
+++ b/src/vector/turbo_quant/encoder.rs
@@ -782,7 +782,7 @@ mod tests {
normalize_to_unit(&mut v);
for bits in [1u8, 2, 3, 4] {
- let boundaries = scaled_boundaries_n(padded, bits);
+ let boundaries = scaled_boundaries_n(padded, bits).unwrap();
let code = encode_tq_mse_multibit(&v, &signs, &boundaries, bits, &mut work);
let expected = code_bytes_per_vector(padded, bits);
assert_eq!(
@@ -794,15 +794,15 @@ mod tests {
}
// Specific sizes for 768d (padded=1024)
- let b1 = scaled_boundaries_n(padded, 1);
+ let b1 = scaled_boundaries_n(padded, 1).unwrap();
let c1 = encode_tq_mse_multibit(&v, &signs, &b1, 1, &mut work);
assert_eq!(c1.codes.len(), 128); // 1024/8
- let b2 = scaled_boundaries_n(padded, 2);
+ let b2 = scaled_boundaries_n(padded, 2).unwrap();
let c2 = encode_tq_mse_multibit(&v, &signs, &b2, 2, &mut work);
assert_eq!(c2.codes.len(), 256); // 1024/4
- let b3 = scaled_boundaries_n(padded, 3);
+ let b3 = scaled_boundaries_n(padded, 3).unwrap();
let c3 = encode_tq_mse_multibit(&v, &signs, &b3, 3, &mut work);
assert_eq!(c3.codes.len(), 384); // 1024*3/8
}
@@ -813,8 +813,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32);
let signs = test_sign_flips(padded as usize, 12345);
- let boundaries = scaled_boundaries_n(padded, 1);
- let centroids = scaled_centroids_n(padded, 1);
+ let boundaries = scaled_boundaries_n(padded, 1).unwrap();
+ let centroids = scaled_centroids_n(padded, 1).unwrap();
let mut work_enc = vec![0.0f32; padded as usize];
let mut work_dec = vec![0.0f32; padded as usize];
@@ -839,8 +839,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32);
let signs = test_sign_flips(padded as usize, 12345);
- let boundaries = scaled_boundaries_n(padded, 2);
- let centroids = scaled_centroids_n(padded, 2);
+ let boundaries = scaled_boundaries_n(padded, 2).unwrap();
+ let centroids = scaled_centroids_n(padded, 2).unwrap();
let mut work_enc = vec![0.0f32; padded as usize];
let mut work_dec = vec![0.0f32; padded as usize];
@@ -864,8 +864,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32);
let signs = test_sign_flips(padded as usize, 12345);
- let boundaries = scaled_boundaries_n(padded, 3);
- let centroids = scaled_centroids_n(padded, 3);
+ let boundaries = scaled_boundaries_n(padded, 3).unwrap();
+ let centroids = scaled_centroids_n(padded, 3).unwrap();
let mut work_enc = vec![0.0f32; padded as usize];
let mut work_dec = vec![0.0f32; padded as usize];
diff --git a/src/vector/turbo_quant/inner_product.rs b/src/vector/turbo_quant/inner_product.rs
index 5504de0e..49b99821 100644
--- a/src/vector/turbo_quant/inner_product.rs
+++ b/src/vector/turbo_quant/inner_product.rs
@@ -608,9 +608,9 @@ mod tests {
// v2: 3-bit MSE + QJL signs (paper-correct)
let boundaries_3 =
- crate::vector::turbo_quant::codebook::scaled_boundaries_n(padded as u32, 3);
+ crate::vector::turbo_quant::codebook::scaled_boundaries_n(padded as u32, 3).unwrap();
let centroids_3 =
- crate::vector::turbo_quant::codebook::scaled_centroids_n(padded as u32, 3);
+ crate::vector::turbo_quant::codebook::scaled_centroids_n(padded as u32, 3).unwrap();
let code_v2 = encode_tq_prod_v2(
&vec,
&sign_flips,
diff --git a/src/vector/turbo_quant/tq_adc.rs b/src/vector/turbo_quant/tq_adc.rs
index acb124c9..83c4de38 100644
--- a/src/vector/turbo_quant/tq_adc.rs
+++ b/src/vector/turbo_quant/tq_adc.rs
@@ -317,15 +317,21 @@ pub fn tq_l2_adc_multibit(
4 => {
// Delegate to existing optimized 4-bit path
debug_assert_eq!(centroids.len(), 16);
- let c: &[f32; 16] = centroids.try_into().unwrap_or_else(|_| {
- panic!(
+ if let Ok(c) = centroids.try_into() {
+ tq_l2_adc_scaled(q_rotated, code, norm, c)
+ } else {
+ // Invariant violated — return max distance rather than panic
+ tracing::error!(
"4-bit ADC requires exactly 16 centroids, got {}",
centroids.len()
- )
- });
- tq_l2_adc_scaled(q_rotated, code, norm, c)
+ );
+ f32::MAX
+ }
+ }
+ _ => {
+ tracing::error!("unsupported bit width: {bits}");
+ f32::MAX
}
- _ => panic!("unsupported bit width: {bits}"),
}
}
@@ -593,7 +599,7 @@ pub fn brute_force_tq_adc_multibit(
}
let mut results: Vec<(f32, u32)> = heap.into_iter().map(|(d, id)| (d.0, id)).collect();
- results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ results.sort_by(|a, b| a.0.total_cmp(&b.0));
results
.into_iter()
@@ -673,7 +679,7 @@ pub fn brute_force_tq_adc(
// Extract sorted results
let mut results: Vec<(f32, u32)> = heap.into_iter().map(|(d, id)| (d.0, id)).collect();
- results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+ results.sort_by(|a, b| a.0.total_cmp(&b.0));
results
.into_iter()
@@ -1033,8 +1039,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32) as usize;
let signs = test_sign_flips(padded, 42);
- let boundaries = scaled_boundaries_n(padded as u32, 1);
- let centroids = scaled_centroids_n(padded as u32, 1);
+ let boundaries = scaled_boundaries_n(padded as u32, 1).unwrap();
+ let centroids = scaled_centroids_n(padded as u32, 1).unwrap();
let mut work = vec![0.0f32; padded];
let mut v = lcg_f32(dim, 99);
@@ -1062,8 +1068,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32) as usize;
let signs = test_sign_flips(padded, 42);
- let boundaries = scaled_boundaries_n(padded as u32, 2);
- let centroids = scaled_centroids_n(padded as u32, 2);
+ let boundaries = scaled_boundaries_n(padded as u32, 2).unwrap();
+ let centroids = scaled_centroids_n(padded as u32, 2).unwrap();
let mut work = vec![0.0f32; padded];
let mut v = lcg_f32(dim, 99);
@@ -1090,8 +1096,8 @@ mod tests {
let dim = 768;
let padded = padded_dimension(dim as u32) as usize;
let signs = test_sign_flips(padded, 42);
- let boundaries = scaled_boundaries_n(padded as u32, 3);
- let centroids = scaled_centroids_n(padded as u32, 3);
+ let boundaries = scaled_boundaries_n(padded as u32, 3).unwrap();
+ let centroids = scaled_centroids_n(padded as u32, 3).unwrap();
let mut work = vec![0.0f32; padded];
let mut v = lcg_f32(dim, 99);
@@ -1120,8 +1126,8 @@ mod tests {
let signs = test_sign_flips(padded, 42);
for bits in [1u8, 2, 3] {
- let boundaries = scaled_boundaries_n(padded as u32, bits);
- let centroids = scaled_centroids_n(padded as u32, bits);
+ let boundaries = scaled_boundaries_n(padded as u32, bits).unwrap();
+ let centroids = scaled_centroids_n(padded as u32, bits).unwrap();
let mut work_enc = vec![0.0f32; padded];
let mut work_dec = vec![0.0f32; padded];
@@ -1197,8 +1203,8 @@ mod tests {
let signs = test_sign_flips(padded, 42);
for bits in [1u8, 2, 3] {
- let boundaries = scaled_boundaries_n(padded as u32, bits);
- let centroids = scaled_centroids_n(padded as u32, bits);
+ let boundaries = scaled_boundaries_n(padded as u32, bits).unwrap();
+ let centroids = scaled_centroids_n(padded as u32, bits).unwrap();
let mut work = vec![0.0f32; padded];
let mut v = lcg_f32(dim, 99);