diff --git a/scripts/test-commands.sh b/scripts/test-commands.sh index 97e4cffa..c93c8100 100755 --- a/scripts/test-commands.sh +++ b/scripts/test-commands.sh @@ -63,6 +63,7 @@ while [[ $# -gt 0 ]]; do echo " pubsub - Pub/Sub commands (SUBSCRIBE, PUBLISH, etc.)" echo " transaction - Transaction commands (MULTI, EXEC, DISCARD)" echo " scripting - Lua scripting (EVAL, EVALSHA)" + echo " vector - Vector search commands (FT.CREATE, FT.SEARCH, FT.INFO, FT.DROPINDEX)" echo " persistence - Persistence commands (BGSAVE, BGREWRITEAOF, etc.)" echo " blocking - Blocking commands (BLPOP, BRPOP, BZPOPMIN, etc.)" echo " benchmark - redis-benchmark throughput for all benchmarkable commands" @@ -665,6 +666,42 @@ fi # PERSISTENCE COMMANDS # =========================================================================== +# =========================================================================== +# VECTOR SEARCH COMMANDS (moon-only — Redis uses different syntax) +# =========================================================================== + +if should_run "vector"; then + echo "" + echo "=== VECTOR SEARCH COMMANDS ===" + mcli FLUSHALL >/dev/null 2>&1 + + # FT.CREATE — create a vector index + assert_moon "FT.CREATE basic" "OK" FT.CREATE myidx ON HASH PREFIX 1 doc: SCHEMA embedding VECTOR FLAT 6 DIM 4 DISTANCE_METRIC L2 TYPE FLOAT32 + + # FT.INFO — index metadata + TOTAL=$((TOTAL + 1)); FT_INFO=$(mcli FT.INFO myidx 2>&1) + if echo "$FT_INFO" | grep -q "myidx"; then PASS=$((PASS + 1)); echo " PASS: FT.INFO returns index name"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO returns index name"; fi + + # Insert vectors via HSET (auto-indexed) — use python3 to avoid null byte stripping in bash + python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',1.0,0.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET doc:1 embedding >/dev/null 2>&1 + python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',0.0,1.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET doc:2 
embedding >/dev/null 2>&1 + + # FT.SEARCH — verify command doesn't error (redis-cli can't pass binary args directly) + TOTAL=$((TOTAL + 1)); FT_SEARCH=$(mcli FT.SEARCH myidx "*" 2>&1) + if ! echo "$FT_SEARCH" | grep -qi "err"; then PASS=$((PASS + 1)); echo " PASS: FT.SEARCH does not error"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.SEARCH returned error"; fi + + # FT.DROPINDEX — remove index + assert_moon "FT.DROPINDEX" "OK" FT.DROPINDEX myidx + + # FT.INFO after drop should error + TOTAL=$((TOTAL + 1)); FT_INFO_AFTER=$(mcli FT.INFO myidx 2>&1) + if echo "$FT_INFO_AFTER" | grep -qi "err\|not found"; then PASS=$((PASS + 1)); echo " PASS: FT.INFO after drop errors"; else FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO after drop errors"; fi +fi + +# =========================================================================== +# PERSISTENCE COMMANDS +# =========================================================================== + if should_run "persistence"; then echo "" echo "=== PERSISTENCE COMMANDS ===" diff --git a/scripts/test-consistency.sh b/scripts/test-consistency.sh index 3ca8873a..f854cb50 100755 --- a/scripts/test-consistency.sh +++ b/scripts/test-consistency.sh @@ -481,6 +481,31 @@ assert_both "GET with 500-char key" GET "$LONGKEY" # =========================================================================== echo "" +# =========================================================================== +# Vector Search (moon-only — FT.* not available in Redis) +# =========================================================================== +log "=== Vector Search (moon-only) ===" + +# Create index on moon only +FT_CREATE=$(redis-cli -p "$PORT_RUST" FT.CREATE vecidx ON HASH PREFIX 1 vec: SCHEMA embedding VECTOR FLAT 6 DIM 4 DISTANCE_METRIC L2 TYPE FLOAT32 2>&1) +assert_eq "FT.CREATE" "OK" "$FT_CREATE" + +# Insert vectors — use python3 to avoid null byte stripping in bash +python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',1.0,0.0,0.0,0.0))" | redis-cli -x 
-p "$PORT_RUST" HSET vec:1 embedding >/dev/null 2>&1 +python3 -c "import struct,sys; sys.stdout.buffer.write(struct.pack('<4f',0.0,1.0,0.0,0.0))" | redis-cli -x -p "$PORT_RUST" HSET vec:2 embedding >/dev/null 2>&1 + +# FT.INFO should show index +FT_INFO=$(redis-cli -p "$PORT_RUST" FT.INFO vecidx 2>&1) +if echo "$FT_INFO" | grep -q "vecidx"; then + PASS=$((PASS + 1)) +else + FAIL=$((FAIL + 1)); echo " FAIL: FT.INFO should show vecidx" +fi + +# FT.DROPINDEX +FT_DROP=$(redis-cli -p "$PORT_RUST" FT.DROPINDEX vecidx 2>&1) +assert_eq "FT.DROPINDEX" "OK" "$FT_DROP" + echo "============================================" echo " Data Consistency Test Results" echo "============================================" diff --git a/src/command/vector_search.rs b/src/command/vector_search/mod.rs similarity index 53% rename from src/command/vector_search.rs rename to src/command/vector_search/mod.rs index 23684046..8f5e334c 100644 --- a/src/command/vector_search.rs +++ b/src/command/vector_search/mod.rs @@ -212,7 +212,10 @@ pub fn ft_create(store: &mut VectorStore, args: &[Frame]) -> Frame { } let dim = match dimension { - Some(d) if d > 0 => d, + Some(d) if d > 0 && d <= 65536 => d, + Some(d) if d > 65536 => { + return Frame::Error(Bytes::from_static(b"ERR DIM must be between 1 and 65536")); + } _ => return Frame::Error(Bytes::from_static(b"ERR DIM is required and must be > 0")), }; @@ -236,7 +239,12 @@ pub fn ft_create(store: &mut VectorStore, args: &[Frame]) -> Frame { crate::vector::metrics::increment_indexes(); Frame::SimpleString(Bytes::from_static(b"OK")) } - Err(msg) => Frame::Error(Bytes::from(format!("ERR {msg}"))), + Err(msg) => { + let mut buf = Vec::with_capacity(4 + msg.len()); + buf.extend_from_slice(b"ERR "); + buf.extend_from_slice(msg.as_bytes()); + Frame::Error(Bytes::from(buf)) + } } } @@ -304,15 +312,26 @@ pub fn ft_info(store: &VectorStore, args: &[Frame]) -> Frame { let snap = idx.segments.load(); let num_docs = snap.mutable.len(); - let ef_rt_str = if 
idx.meta.hnsw_ef_runtime > 0 { - format!("{}", idx.meta.hnsw_ef_runtime) + // Use itoa for numeric formatting — no format!() on hot path. + let ef_rt_bytes: Bytes = if idx.meta.hnsw_ef_runtime > 0 { + let mut buf = itoa::Buffer::new(); + Bytes::copy_from_slice(buf.format(idx.meta.hnsw_ef_runtime).as_bytes()) } else { - "auto".to_string() + Bytes::from_static(b"auto") }; - let ct_str = if idx.meta.compact_threshold > 0 { - format!("{}", idx.meta.compact_threshold) + let ct_bytes: Bytes = if idx.meta.compact_threshold > 0 { + let mut buf = itoa::Buffer::new(); + Bytes::copy_from_slice(buf.format(idx.meta.compact_threshold).as_bytes()) } else { - "1000".to_string() + Bytes::from_static(b"1000") + }; + let quant_bytes: Bytes = match idx.meta.quantization { + QuantizationConfig::Sq8 => Bytes::from_static(b"SQ8"), + QuantizationConfig::TurboQuant4 => Bytes::from_static(b"TurboQuant4"), + QuantizationConfig::TurboQuantProd4 => Bytes::from_static(b"TurboQuantProd4"), + QuantizationConfig::TurboQuant1 => Bytes::from_static(b"TurboQuant1"), + QuantizationConfig::TurboQuant2 => Bytes::from_static(b"TurboQuant2"), + QuantizationConfig::TurboQuant3 => Bytes::from_static(b"TurboQuant3"), }; let items = vec![ @@ -337,11 +356,11 @@ pub fn ft_info(store: &VectorStore, args: &[Frame]) -> Frame { Frame::BulkString(Bytes::from_static(b"EF_CONSTRUCTION")), Frame::Integer(idx.meta.hnsw_ef_construction as i64), Frame::BulkString(Bytes::from_static(b"EF_RUNTIME")), - Frame::BulkString(Bytes::from(ef_rt_str)), + Frame::BulkString(ef_rt_bytes), Frame::BulkString(Bytes::from_static(b"COMPACT_THRESHOLD")), - Frame::BulkString(Bytes::from(ct_str)), + Frame::BulkString(ct_bytes), Frame::BulkString(Bytes::from_static(b"QUANTIZATION")), - Frame::BulkString(Bytes::from(format!("{:?}", idx.meta.quantization))), + Frame::BulkString(quant_bytes), ]; Frame::Array(items.into()) } @@ -399,10 +418,8 @@ pub fn ft_search(store: &mut VectorStore, args: &[Frame]) -> Frame { // Parse optional FILTER clause 
let filter_expr = parse_filter_clause(args); - let start = std::time::Instant::now(); let result = search_local_filtered(store, &index_name, &query_blob, k, filter_expr.as_ref()); crate::vector::metrics::increment_search(); - crate::vector::metrics::record_search_latency(start.elapsed().as_micros() as u64); result } @@ -558,8 +575,11 @@ fn build_search_response(results: &SmallVec<[SearchResult; 32]>) -> Frame { doc_id.extend_from_slice(id_str.as_bytes()); items.push(Frame::BulkString(Bytes::from(doc_id))); - // Score as nested array (format! acceptable -- end of command path) - let score_str = format!("{}", r.distance); + // Score as nested array — use write! to pre-allocated buffer + let mut score_buf = String::with_capacity(16); + use std::fmt::Write; + let _ = write!(score_buf, "{}", r.distance); + let score_str = score_buf; let fields = vec![ Frame::BulkString(Bytes::from_static(b"__vec_score")), Frame::BulkString(Bytes::from(score_str)), @@ -835,692 +855,4 @@ fn metric_to_bytes(m: DistanceMetric) -> Bytes { } #[cfg(test)] -mod tests { - use super::*; - - fn bulk(s: &[u8]) -> Frame { - Frame::BulkString(Bytes::from(s.to_vec())) - } - - /// Build a valid FT.CREATE argument list. 
- fn ft_create_args() -> Vec<Frame> { - vec![ - bulk(b"myidx"), // index name - bulk(b"ON"), - bulk(b"HASH"), - bulk(b"PREFIX"), - bulk(b"1"), - bulk(b"doc:"), - bulk(b"SCHEMA"), - bulk(b"vec"), - bulk(b"VECTOR"), - bulk(b"HNSW"), - bulk(b"6"), // 6 params = 3 key-value pairs - bulk(b"TYPE"), - bulk(b"FLOAT32"), - bulk(b"DIM"), - bulk(b"128"), - bulk(b"DISTANCE_METRIC"), - bulk(b"L2"), - ] - } - - #[test] - fn test_ft_create_parse_full_syntax() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - let result = ft_create(&mut store, &args); - match &result { - Frame::SimpleString(s) => assert_eq!(&s[..], b"OK"), - other => panic!("expected OK, got {other:?}"), - } - assert_eq!(store.len(), 1); - let idx = store.get_index(b"myidx").unwrap(); - assert_eq!(idx.meta.dimension, 128); - assert_eq!(idx.meta.metric, DistanceMetric::L2); - assert_eq!(idx.meta.key_prefixes.len(), 1); - assert_eq!(&idx.meta.key_prefixes[0][..], b"doc:"); - } - - #[test] - fn test_ft_create_missing_dim() { - let mut store = VectorStore::new(); - // Remove DIM param pair: keep TYPE FLOAT32 and DISTANCE_METRIC L2 (4 params = 2 pairs) - let args = vec![ - bulk(b"myidx"), - bulk(b"ON"), - bulk(b"HASH"), - bulk(b"PREFIX"), - bulk(b"1"), - bulk(b"doc:"), - bulk(b"SCHEMA"), - bulk(b"vec"), - bulk(b"VECTOR"), - bulk(b"HNSW"), - bulk(b"4"), // 4 params = 2 key-value pairs - bulk(b"TYPE"), - bulk(b"FLOAT32"), - bulk(b"DISTANCE_METRIC"), - bulk(b"L2"), - ]; - let result = ft_create(&mut store, &args); - match &result { - Frame::Error(_) => {} // expected - other => panic!("expected error, got {other:?}"), - } - } - - #[test] - fn test_ft_create_duplicate() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - let r1 = ft_create(&mut store, &args); - assert!(matches!(r1, Frame::SimpleString(_))); - - let args2 = ft_create_args(); - let r2 = ft_create(&mut store, &args2); - match &r2 { - Frame::Error(e) => assert!(e.starts_with(b"ERR")), - other => panic!("expected error, got 
{other:?}"), - } - } - - #[test] - fn test_ft_dropindex() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - - // Drop existing - let result = ft_dropindex(&mut store, &[bulk(b"myidx")]); - assert!(matches!(result, Frame::SimpleString(_))); - assert!(store.is_empty()); - - // Drop non-existing - let result = ft_dropindex(&mut store, &[bulk(b"myidx")]); - assert!(matches!(result, Frame::Error(_))); - } - - #[test] - fn test_parse_knn_query() { - let query = b"*=>[KNN 10 @vec $query]"; - let (k, param) = parse_knn_query(query).unwrap(); - assert_eq!(k, 10); - assert_eq!(&param[..], b"query"); - } - - #[test] - fn test_parse_knn_query_different_k() { - let query = b"*=>[KNN 5 @embedding $blob]"; - let (k, param) = parse_knn_query(query).unwrap(); - assert_eq!(k, 5); - assert_eq!(&param[..], b"blob"); - } - - #[test] - fn test_parse_knn_query_invalid() { - assert!(parse_knn_query(b"*").is_none()); - assert!(parse_knn_query(b"*=>[NOTAKNN]").is_none()); - } - - #[test] - fn test_extract_param_blob() { - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 10 @vec $query]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"query"), - bulk(b"blobdata"), - ]; - let blob = extract_param_blob(&args, b"query").unwrap(); - assert_eq!(&blob[..], b"blobdata"); - } - - #[test] - fn test_extract_param_blob_missing() { - let args = vec![bulk(b"idx"), bulk(b"*=>[KNN 10 @vec $query]")]; - assert!(extract_param_blob(&args, b"query").is_none()); - } - - #[test] - fn test_quantize_f32_to_sq() { - let input = [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0]; - let mut output = [0i8; 7]; - quantize_f32_to_sq(&input, &mut output); - assert_eq!(output[0], 0); // 0.0 -> 0 - assert_eq!(output[1], 127); // 1.0 -> 127 - assert_eq!(output[2], -127); // -1.0 -> -127 - assert_eq!(output[3], 63); // 0.5 -> 63 (truncated from 63.5) - assert_eq!(output[4], -63); // -0.5 -> -63 - assert_eq!(output[5], 127); // 2.0 clamped to 1.0 -> 127 - assert_eq!(output[6], -127); // -2.0 
clamped to -1.0 -> -127 - } - - #[test] - fn test_merge_search_results_combines_shards() { - // Shard 0 returns: [2, "vec:0", ["__vec_score", "0.1"], "vec:1", ["__vec_score", "0.5"]] - // Shard 1 returns: [2, "vec:10", ["__vec_score", "0.3"], "vec:11", ["__vec_score", "0.9"]] - // Global top-2 should be: vec:0 (0.1), vec:10 (0.3) - - let shard0 = Frame::Array( - vec![ - Frame::Integer(2), - bulk(b"vec:0"), - Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.1")].into()), - bulk(b"vec:1"), - Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.5")].into()), - ] - .into(), - ); - - let shard1 = Frame::Array( - vec![ - Frame::Integer(2), - bulk(b"vec:10"), - Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.3")].into()), - bulk(b"vec:11"), - Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.9")].into()), - ] - .into(), - ); - - let result = merge_search_results(&[shard0, shard1], 2); - match result { - Frame::Array(items) => { - assert_eq!(items[0], Frame::Integer(2)); - assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:0"))); - assert_eq!(items[3], Frame::BulkString(Bytes::from("vec:10"))); - } - other => panic!("expected Array, got {other:?}"), - } - } - - #[test] - fn test_merge_search_results_handles_errors() { - // One shard returns error, one returns valid results - let shard0 = Frame::Error(Bytes::from_static(b"ERR shard unavailable")); - let shard1 = Frame::Array( - vec![ - Frame::Integer(1), - bulk(b"vec:5"), - Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.2")].into()), - ] - .into(), - ); - - let result = merge_search_results(&[shard0, shard1], 5); - match result { - Frame::Array(items) => { - assert_eq!(items[0], Frame::Integer(1)); - assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:5"))); - } - other => panic!("expected Array, got {other:?}"), - } - } - - #[test] - fn test_merge_search_results_empty() { - // No results from any shard - let shard0 = Frame::Array(vec![Frame::Integer(0)].into()); - let shard1 = 
Frame::Array(vec![Frame::Integer(0)].into()); - - let result = merge_search_results(&[shard0, shard1], 10); - match result { - Frame::Array(items) => { - assert_eq!(items.len(), 1); - assert_eq!(items[0], Frame::Integer(0)); - } - other => panic!("expected Array, got {other:?}"), - } - } - - #[test] - fn test_ft_search_dimension_mismatch() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - - // Build a query with wrong dimension (4 bytes instead of 128*4) - let search_args = vec![ - bulk(b"myidx"), - bulk(b"*=>[KNN 10 @vec $query]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"query"), - bulk(b"tooshort"), - ]; - let result = ft_search(&mut store, &search_args); - match &result { - Frame::Error(e) => assert!( - e.starts_with(b"ERR query vector dimension"), - "expected dimension mismatch error, got {:?}", - std::str::from_utf8(e) - ), - other => panic!("expected error, got {other:?}"), - } - } - - #[test] - fn test_ft_search_empty_index() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - - // Build valid query for dim=128 - let query_vec: Vec<u8> = vec![0u8; 128 * 4]; // 128 floats, all zero - let search_args = vec![ - bulk(b"myidx"), - bulk(b"*=>[KNN 5 @vec $query]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"query"), - Frame::BulkString(Bytes::from(query_vec)), - ]; - crate::vector::distance::init(); - let result = ft_search(&mut store, &search_args); - match result { - Frame::Array(items) => { - assert_eq!(items[0], Frame::Integer(0)); // no results - } - other => panic!("expected Array, got {other:?}"), - } - } - - #[test] - fn test_ft_info() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - - let result = ft_info(&store, &[bulk(b"myidx")]); - match result { - Frame::Array(items) => { - // Should have 20 items (10 key-value pairs) - assert!( - items.len() >= 20, - "FT.INFO should return at least 20 items, 
got {}", - items.len() - ); - assert_eq!( - items[0], - Frame::BulkString(Bytes::from_static(b"index_name")) - ); - assert_eq!(items[1], Frame::BulkString(Bytes::from("myidx"))); - assert_eq!(items[5], Frame::Integer(0)); // num_docs = 0 - assert_eq!(items[7], Frame::Integer(128)); // dimension - // New fields - assert_eq!(items[10], Frame::BulkString(Bytes::from_static(b"M"))); - assert_eq!(items[11], Frame::Integer(16)); // default M - assert_eq!( - items[14], - Frame::BulkString(Bytes::from_static(b"EF_RUNTIME")) - ); - } - other => panic!("expected Array, got {other:?}"), - } - - // Non-existing index - let result = ft_info(&store, &[bulk(b"nonexistent")]); - assert!(matches!(result, Frame::Error(_))); - } - - /// Helper to build FT.CREATE args with custom parameters. - fn build_ft_create_args( - name: &str, - prefix: &str, - field: &str, - dim: u32, - metric: &str, - ) -> Vec<Frame> { - vec![ - Frame::BulkString(Bytes::from(name.to_owned())), - Frame::BulkString(Bytes::from_static(b"ON")), - Frame::BulkString(Bytes::from_static(b"HASH")), - Frame::BulkString(Bytes::from_static(b"PREFIX")), - Frame::BulkString(Bytes::from_static(b"1")), - Frame::BulkString(Bytes::from(prefix.to_owned())), - Frame::BulkString(Bytes::from_static(b"SCHEMA")), - Frame::BulkString(Bytes::from(field.to_owned())), - Frame::BulkString(Bytes::from_static(b"VECTOR")), - Frame::BulkString(Bytes::from_static(b"HNSW")), - Frame::BulkString(Bytes::from_static(b"6")), - Frame::BulkString(Bytes::from_static(b"TYPE")), - Frame::BulkString(Bytes::from_static(b"FLOAT32")), - Frame::BulkString(Bytes::from_static(b"DIM")), - Frame::BulkString(Bytes::from(dim.to_string())), - Frame::BulkString(Bytes::from_static(b"DISTANCE_METRIC")), - Frame::BulkString(Bytes::from(metric.to_owned())), - ] - } - - #[test] - fn test_end_to_end_create_insert_search() { - // Initialize distance functions (required before any search) - crate::vector::distance::init(); - - let mut store = VectorStore::new(); - let dim: usize = 
4; - - // 1. FT.CREATE - let create_args = build_ft_create_args("e2eidx", "doc:", "embedding", dim as u32, "L2"); - let result = ft_create(&mut store, &create_args); - assert!( - matches!(result, Frame::SimpleString(_)), - "FT.CREATE should return OK, got {result:?}" - ); - - // 2. Insert vectors directly into the mutable segment - let idx = store.get_index_mut(b"e2eidx").unwrap(); - let vectors: Vec<[f32; 4]> = vec![ - [1.0, 0.0, 0.0, 0.0], // vec:0 -- exact match for query (L2=0) - [-1.0, 0.0, 0.0, 0.0], // vec:1 -- opposite direction (L2=4.0) - [0.5, 0.0, 0.0, 0.0], // vec:2 -- same direction, half magnitude (L2=0.25) - ]; - - let snap = idx.segments.load(); - for (i, v) in vectors.iter().enumerate() { - let mut sq = vec![0i8; dim]; - quantize_f32_to_sq(v, &mut sq); - let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt(); - snap.mutable.append(i as u64, v, &sq, norm, i as u64); - } - drop(snap); - - // 3. FT.SEARCH for vector close to [1.0, 0.0, 0.0, 0.0] - let query_vec: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; - let query_blob: Vec<u8> = query_vec.iter().flat_map(|f| f.to_le_bytes()).collect(); - - let search_args = vec![ - Frame::BulkString(Bytes::from_static(b"e2eidx")), - Frame::BulkString(Bytes::from_static(b"*=>[KNN 2 @embedding $query]")), - Frame::BulkString(Bytes::from_static(b"PARAMS")), - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"query")), - Frame::BulkString(Bytes::from(query_blob)), - ]; - - let result = ft_search(&mut store, &search_args); - match &result { - Frame::Array(items) => { - // First element is count - assert!( - matches!(&items[0], Frame::Integer(n) if *n >= 1), - "Should find at least 1 result, got {result:?}" - ); - // vec:0 should be in top-2 results (at dim=4, TQ-4bit quantization - // noise can swap rankings of very close vectors in Light mode) - let mut found_vec0 = false; - for idx in [1, 3].iter() { - if let Some(Frame::BulkString(doc_id)) = items.get(*idx) { - if doc_id.as_ref() == b"vec:0" { - 
found_vec0 = true; - } - } - } - assert!( - found_vec0, - "vec:0 should be in top-2 results, got {result:?}" - ); - // vec:2 should be in top-2 (at dim=4, TQ noise may reorder) - let mut found_vec2 = false; - for idx in [1, 3].iter() { - if let Some(Frame::BulkString(doc_id)) = items.get(*idx) { - if doc_id.as_ref() == b"vec:2" { - found_vec2 = true; - } - } - } - assert!( - found_vec2, - "vec:2 should be in top-2 results, got {result:?}" - ); - } - Frame::Error(e) => panic!("FT.SEARCH returned error: {:?}", std::str::from_utf8(e)), - _ => panic!("FT.SEARCH should return Array, got {result:?}"), - } - } - - #[test] - fn test_ft_info_returns_correct_data() { - let mut store = VectorStore::new(); - let args = build_ft_create_args("testidx", "test:", "vec", 128, "COSINE"); - ft_create(&mut store, &args); - - let info_args = [Frame::BulkString(Bytes::from_static(b"testidx"))]; - let result = ft_info(&store, &info_args); - match result { - Frame::Array(items) => { - assert!(items.len() >= 6, "FT.INFO should return at least 6 items"); - // Check dimension - let mut found_dim = false; - for pair in items.chunks(2) { - if let Frame::BulkString(key) = &pair[0] { - if key.as_ref() == b"dimension" { - if let Frame::Integer(d) = &pair[1] { - assert_eq!(*d, 128); - found_dim = true; - } - } - } - } - assert!(found_dim, "FT.INFO should return dimension"); - } - other => panic!("FT.INFO should return Array, got {other:?}"), - } - } - - #[test] - fn test_ft_search_unknown_index() { - let mut store = VectorStore::new(); - let args = [ - Frame::BulkString(Bytes::from_static(b"nonexistent")), - Frame::BulkString(Bytes::from_static(b"*=>[KNN 5 @vec $query]")), - Frame::BulkString(Bytes::from_static(b"PARAMS")), - Frame::BulkString(Bytes::from_static(b"2")), - Frame::BulkString(Bytes::from_static(b"query")), - Frame::BulkString(Bytes::from(vec![0u8; 16])), - ]; - let result = ft_search(&mut store, &args); - assert!( - matches!(result, Frame::Error(_)), - "Should error on unknown index, 
got {result:?}" - ); - } - - #[test] - fn test_parse_filter_clause_tag() { - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 10 @vec $q]"), - bulk(b"FILTER"), - bulk(b"@category:{electronics}"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"q"), - bulk(b"blob"), - ]; - let filter = parse_filter_clause(&args); - assert!(filter.is_some(), "should parse @category:{{electronics}}"); - match filter.unwrap() { - crate::vector::filter::FilterExpr::TagEq { field, value } => { - assert_eq!(&field[..], b"category"); - assert_eq!(&value[..], b"electronics"); - } - other => panic!("expected TagEq, got {other:?}"), - } - } - - #[test] - fn test_parse_filter_clause_numeric_range() { - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 5 @vec $q]"), - bulk(b"FILTER"), - bulk(b"@price:[10 100]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"q"), - bulk(b"blob"), - ]; - let filter = parse_filter_clause(&args); - assert!(filter.is_some()); - match filter.unwrap() { - crate::vector::filter::FilterExpr::NumRange { field, min, max } => { - assert_eq!(&field[..], b"price"); - assert_eq!(*min, 10.0); - assert_eq!(*max, 100.0); - } - other => panic!("expected NumRange, got {other:?}"), - } - } - - #[test] - fn test_parse_filter_clause_numeric_eq() { - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 5 @vec $q]"), - bulk(b"FILTER"), - bulk(b"@price:[50 50]"), - ]; - let filter = parse_filter_clause(&args); - assert!(filter.is_some()); - match filter.unwrap() { - crate::vector::filter::FilterExpr::NumEq { field, value } => { - assert_eq!(&field[..], b"price"); - assert_eq!(*value, 50.0); - } - other => panic!("expected NumEq, got {other:?}"), - } - } - - #[test] - fn test_parse_filter_clause_compound() { - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 5 @vec $q]"), - bulk(b"FILTER"), - bulk(b"@a:{x} @b:[1 10]"), - ]; - let filter = parse_filter_clause(&args); - assert!(filter.is_some()); - match filter.unwrap() { - crate::vector::filter::FilterExpr::And(left, right) => { - assert!(matches!( 
- *left, - crate::vector::filter::FilterExpr::TagEq { .. } - )); - assert!(matches!( - *right, - crate::vector::filter::FilterExpr::NumRange { .. } - )); - } - other => panic!("expected And, got {other:?}"), - } - } - - #[test] - fn test_parse_filter_clause_none() { - // No FILTER keyword - let args = vec![ - bulk(b"idx"), - bulk(b"*=>[KNN 10 @vec $q]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"q"), - bulk(b"blob"), - ]; - let filter = parse_filter_clause(&args); - assert!(filter.is_none()); - } - - #[test] - fn test_ft_search_with_filter_no_regression() { - // Unfiltered FT.SEARCH still works identically - crate::vector::distance::init(); - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - - let query_vec: Vec<u8> = vec![0u8; 128 * 4]; - let search_args = vec![ - bulk(b"myidx"), - bulk(b"*=>[KNN 5 @vec $query]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"query"), - Frame::BulkString(Bytes::from(query_vec)), - ]; - let result = ft_search(&mut store, &search_args); - match result { - Frame::Array(items) => { - assert_eq!(items[0], Frame::Integer(0)); - } - other => panic!("expected Array, got {other:?}"), - } - } - - #[test] - fn test_vector_index_has_payload_index() { - let mut store = VectorStore::new(); - let args = ft_create_args(); - ft_create(&mut store, &args); - let idx = store.get_index(b"myidx").unwrap(); - // payload_index should exist -- insert and evaluate should work - let _ = &idx.payload_index; - } - - #[test] - fn test_vector_metrics_increment_decrement() { - use std::sync::atomic::Ordering; - - // Capture before-snapshot immediately before each operation to handle - // parallel test interference on global atomics. 
- let mut store = VectorStore::new(); - let args = ft_create_args(); - - // FT.CREATE should increment VECTOR_INDEXES - let before_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); - ft_create(&mut store, &args); - let after_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); - assert!( - after_create > before_create, - "FT.CREATE should increment VECTOR_INDEXES" - ); - - // FT.SEARCH should increment VECTOR_SEARCH_TOTAL - crate::vector::distance::init(); - let before_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed); - let query_vec: Vec<u8> = vec![0u8; 128 * 4]; - let search_args = vec![ - bulk(b"myidx"), - bulk(b"*=>[KNN 5 @vec $query]"), - bulk(b"PARAMS"), - bulk(b"2"), - bulk(b"query"), - Frame::BulkString(Bytes::from(query_vec)), - ]; - ft_search(&mut store, &search_args); - let after_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed); - assert!( - after_search > before_search, - "FT.SEARCH should increment VECTOR_SEARCH_TOTAL" - ); - - // Latency should be non-zero after a search - let latency = crate::vector::metrics::VECTOR_SEARCH_LATENCY_US.load(Ordering::Relaxed); - // latency may be 0 on very fast machines, so just check it was written (could be 0 if sub-microsecond) - - // FT.DROPINDEX should decrement VECTOR_INDEXES - let before_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); - ft_dropindex(&mut store, &[bulk(b"myidx")]); - let after_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); - assert!( - after_drop < before_drop, - "FT.DROPINDEX should decrement VECTOR_INDEXES" - ); - - // Suppress unused variable warning - let _ = latency; - } -} +mod tests; diff --git a/src/command/vector_search/tests.rs b/src/command/vector_search/tests.rs new file mode 100644 index 00000000..c9290a39 --- /dev/null +++ b/src/command/vector_search/tests.rs @@ -0,0 +1,684 @@ +use super::*; +use std::sync::Mutex; + +/// Serialize tests that 
touch global atomic metrics to avoid flaky interference. +static METRICS_LOCK: Mutex<()> = Mutex::new(()); + +fn bulk(s: &[u8]) -> Frame { + Frame::BulkString(Bytes::from(s.to_vec())) +} + +/// Build a valid FT.CREATE argument list. +fn ft_create_args() -> Vec<Frame> { + vec![ + bulk(b"myidx"), // index name + bulk(b"ON"), + bulk(b"HASH"), + bulk(b"PREFIX"), + bulk(b"1"), + bulk(b"doc:"), + bulk(b"SCHEMA"), + bulk(b"vec"), + bulk(b"VECTOR"), + bulk(b"HNSW"), + bulk(b"6"), // 6 params = 3 key-value pairs + bulk(b"TYPE"), + bulk(b"FLOAT32"), + bulk(b"DIM"), + bulk(b"128"), + bulk(b"DISTANCE_METRIC"), + bulk(b"L2"), + ] +} + +#[test] +fn test_ft_create_parse_full_syntax() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + let result = ft_create(&mut store, &args); + match &result { + Frame::SimpleString(s) => assert_eq!(&s[..], b"OK"), + other => panic!("expected OK, got {other:?}"), + } + assert_eq!(store.len(), 1); + let idx = store.get_index(b"myidx").unwrap(); + assert_eq!(idx.meta.dimension, 128); + assert_eq!(idx.meta.metric, DistanceMetric::L2); + assert_eq!(idx.meta.key_prefixes.len(), 1); + assert_eq!(&idx.meta.key_prefixes[0][..], b"doc:"); +} + +#[test] +fn test_ft_create_missing_dim() { + let mut store = VectorStore::new(); + // Remove DIM param pair: keep TYPE FLOAT32 and DISTANCE_METRIC L2 (4 params = 2 pairs) + let args = vec![ + bulk(b"myidx"), + bulk(b"ON"), + bulk(b"HASH"), + bulk(b"PREFIX"), + bulk(b"1"), + bulk(b"doc:"), + bulk(b"SCHEMA"), + bulk(b"vec"), + bulk(b"VECTOR"), + bulk(b"HNSW"), + bulk(b"4"), // 4 params = 2 key-value pairs + bulk(b"TYPE"), + bulk(b"FLOAT32"), + bulk(b"DISTANCE_METRIC"), + bulk(b"L2"), + ]; + let result = ft_create(&mut store, &args); + match &result { + Frame::Error(_) => {} // expected + other => panic!("expected error, got {other:?}"), + } +} + +#[test] +fn test_ft_create_duplicate() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + let r1 = ft_create(&mut store, &args); + 
assert!(matches!(r1, Frame::SimpleString(_))); + + let args2 = ft_create_args(); + let r2 = ft_create(&mut store, &args2); + match &r2 { + Frame::Error(e) => assert!(e.starts_with(b"ERR")), + other => panic!("expected error, got {other:?}"), + } +} + +#[test] +fn test_ft_dropindex() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + + // Drop existing + let result = ft_dropindex(&mut store, &[bulk(b"myidx")]); + assert!(matches!(result, Frame::SimpleString(_))); + assert!(store.is_empty()); + + // Drop non-existing + let result = ft_dropindex(&mut store, &[bulk(b"myidx")]); + assert!(matches!(result, Frame::Error(_))); +} + +#[test] +fn test_parse_knn_query() { + let query = b"*=>[KNN 10 @vec $query]"; + let (k, param) = parse_knn_query(query).unwrap(); + assert_eq!(k, 10); + assert_eq!(&param[..], b"query"); +} + +#[test] +fn test_parse_knn_query_different_k() { + let query = b"*=>[KNN 5 @embedding $blob]"; + let (k, param) = parse_knn_query(query).unwrap(); + assert_eq!(k, 5); + assert_eq!(&param[..], b"blob"); +} + +#[test] +fn test_parse_knn_query_invalid() { + assert!(parse_knn_query(b"*").is_none()); + assert!(parse_knn_query(b"*=>[NOTAKNN]").is_none()); +} + +#[test] +fn test_extract_param_blob() { + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 10 @vec $query]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"query"), + bulk(b"blobdata"), + ]; + let blob = extract_param_blob(&args, b"query").unwrap(); + assert_eq!(&blob[..], b"blobdata"); +} + +#[test] +fn test_extract_param_blob_missing() { + let args = vec![bulk(b"idx"), bulk(b"*=>[KNN 10 @vec $query]")]; + assert!(extract_param_blob(&args, b"query").is_none()); +} + +#[test] +fn test_quantize_f32_to_sq() { + let input = [0.0, 1.0, -1.0, 0.5, -0.5, 2.0, -2.0]; + let mut output = [0i8; 7]; + quantize_f32_to_sq(&input, &mut output); + assert_eq!(output[0], 0); // 0.0 -> 0 + assert_eq!(output[1], 127); // 1.0 -> 127 + assert_eq!(output[2], -127); // -1.0 -> -127 + 
assert_eq!(output[3], 63); // 0.5 -> 63 (truncated from 63.5) + assert_eq!(output[4], -63); // -0.5 -> -63 + assert_eq!(output[5], 127); // 2.0 clamped to 1.0 -> 127 + assert_eq!(output[6], -127); // -2.0 clamped to -1.0 -> -127 +} + +#[test] +fn test_merge_search_results_combines_shards() { + // Shard 0 returns: [2, "vec:0", ["__vec_score", "0.1"], "vec:1", ["__vec_score", "0.5"]] + // Shard 1 returns: [2, "vec:10", ["__vec_score", "0.3"], "vec:11", ["__vec_score", "0.9"]] + // Global top-2 should be: vec:0 (0.1), vec:10 (0.3) + + let shard0 = Frame::Array( + vec![ + Frame::Integer(2), + bulk(b"vec:0"), + Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.1")].into()), + bulk(b"vec:1"), + Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.5")].into()), + ] + .into(), + ); + + let shard1 = Frame::Array( + vec![ + Frame::Integer(2), + bulk(b"vec:10"), + Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.3")].into()), + bulk(b"vec:11"), + Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.9")].into()), + ] + .into(), + ); + + let result = merge_search_results(&[shard0, shard1], 2); + match result { + Frame::Array(items) => { + assert_eq!(items[0], Frame::Integer(2)); + assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:0"))); + assert_eq!(items[3], Frame::BulkString(Bytes::from("vec:10"))); + } + other => panic!("expected Array, got {other:?}"), + } +} + +#[test] +fn test_merge_search_results_handles_errors() { + // One shard returns error, one returns valid results + let shard0 = Frame::Error(Bytes::from_static(b"ERR shard unavailable")); + let shard1 = Frame::Array( + vec![ + Frame::Integer(1), + bulk(b"vec:5"), + Frame::Array(vec![bulk(b"__vec_score"), bulk(b"0.2")].into()), + ] + .into(), + ); + + let result = merge_search_results(&[shard0, shard1], 5); + match result { + Frame::Array(items) => { + assert_eq!(items[0], Frame::Integer(1)); + assert_eq!(items[1], Frame::BulkString(Bytes::from("vec:5"))); + } + other => panic!("expected Array, got {other:?}"), + } +} + 
+#[test] +fn test_merge_search_results_empty() { + // No results from any shard + let shard0 = Frame::Array(vec![Frame::Integer(0)].into()); + let shard1 = Frame::Array(vec![Frame::Integer(0)].into()); + + let result = merge_search_results(&[shard0, shard1], 10); + match result { + Frame::Array(items) => { + assert_eq!(items.len(), 1); + assert_eq!(items[0], Frame::Integer(0)); + } + other => panic!("expected Array, got {other:?}"), + } +} + +#[test] +fn test_ft_search_dimension_mismatch() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + + // Build a query with wrong dimension (4 bytes instead of 128*4) + let search_args = vec![ + bulk(b"myidx"), + bulk(b"*=>[KNN 10 @vec $query]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"query"), + bulk(b"tooshort"), + ]; + let result = ft_search(&mut store, &search_args); + match &result { + Frame::Error(e) => assert!( + e.starts_with(b"ERR query vector dimension"), + "expected dimension mismatch error, got {:?}", + std::str::from_utf8(e) + ), + other => panic!("expected error, got {other:?}"), + } +} + +#[test] +fn test_ft_search_empty_index() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + + // Build valid query for dim=128 + let query_vec: Vec = vec![0u8; 128 * 4]; // 128 floats, all zero + let search_args = vec![ + bulk(b"myidx"), + bulk(b"*=>[KNN 5 @vec $query]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"query"), + Frame::BulkString(Bytes::from(query_vec)), + ]; + crate::vector::distance::init(); + let result = ft_search(&mut store, &search_args); + match result { + Frame::Array(items) => { + assert_eq!(items[0], Frame::Integer(0)); // no results + } + other => panic!("expected Array, got {other:?}"), + } +} + +#[test] +fn test_ft_info() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + + let result = ft_info(&store, &[bulk(b"myidx")]); + match result { + 
Frame::Array(items) => { + // Should have 20 items (10 key-value pairs) + assert!( + items.len() >= 20, + "FT.INFO should return at least 20 items, got {}", + items.len() + ); + assert_eq!( + items[0], + Frame::BulkString(Bytes::from_static(b"index_name")) + ); + assert_eq!(items[1], Frame::BulkString(Bytes::from("myidx"))); + assert_eq!(items[5], Frame::Integer(0)); // num_docs = 0 + assert_eq!(items[7], Frame::Integer(128)); // dimension + // New fields + assert_eq!(items[10], Frame::BulkString(Bytes::from_static(b"M"))); + assert_eq!(items[11], Frame::Integer(16)); // default M + assert_eq!( + items[14], + Frame::BulkString(Bytes::from_static(b"EF_RUNTIME")) + ); + } + other => panic!("expected Array, got {other:?}"), + } + + // Non-existing index + let result = ft_info(&store, &[bulk(b"nonexistent")]); + assert!(matches!(result, Frame::Error(_))); +} + +/// Helper to build FT.CREATE args with custom parameters. +fn build_ft_create_args( + name: &str, + prefix: &str, + field: &str, + dim: u32, + metric: &str, +) -> Vec { + vec![ + Frame::BulkString(Bytes::from(name.to_owned())), + Frame::BulkString(Bytes::from_static(b"ON")), + Frame::BulkString(Bytes::from_static(b"HASH")), + Frame::BulkString(Bytes::from_static(b"PREFIX")), + Frame::BulkString(Bytes::from_static(b"1")), + Frame::BulkString(Bytes::from(prefix.to_owned())), + Frame::BulkString(Bytes::from_static(b"SCHEMA")), + Frame::BulkString(Bytes::from(field.to_owned())), + Frame::BulkString(Bytes::from_static(b"VECTOR")), + Frame::BulkString(Bytes::from_static(b"HNSW")), + Frame::BulkString(Bytes::from_static(b"6")), + Frame::BulkString(Bytes::from_static(b"TYPE")), + Frame::BulkString(Bytes::from_static(b"FLOAT32")), + Frame::BulkString(Bytes::from_static(b"DIM")), + Frame::BulkString(Bytes::from(dim.to_string())), + Frame::BulkString(Bytes::from_static(b"DISTANCE_METRIC")), + Frame::BulkString(Bytes::from(metric.to_owned())), + ] +} + +#[test] +fn test_end_to_end_create_insert_search() { + // Initialize 
distance functions (required before any search) + crate::vector::distance::init(); + + let mut store = VectorStore::new(); + let dim: usize = 4; + + // 1. FT.CREATE + let create_args = build_ft_create_args("e2eidx", "doc:", "embedding", dim as u32, "L2"); + let result = ft_create(&mut store, &create_args); + assert!( + matches!(result, Frame::SimpleString(_)), + "FT.CREATE should return OK, got {result:?}" + ); + + // 2. Insert vectors directly into the mutable segment + let idx = store.get_index_mut(b"e2eidx").unwrap(); + let vectors: Vec<[f32; 4]> = vec![ + [1.0, 0.0, 0.0, 0.0], // vec:0 -- exact match for query (L2=0) + [-1.0, 0.0, 0.0, 0.0], // vec:1 -- opposite direction (L2=4.0) + [0.5, 0.0, 0.0, 0.0], // vec:2 -- same direction, half magnitude (L2=0.25) + ]; + + let snap = idx.segments.load(); + for (i, v) in vectors.iter().enumerate() { + let mut sq = vec![0i8; dim]; + quantize_f32_to_sq(v, &mut sq); + let norm = v.iter().map(|x| x * x).sum::().sqrt(); + snap.mutable.append(i as u64, v, &sq, norm, i as u64); + } + drop(snap); + + // 3. 
FT.SEARCH for vector close to [1.0, 0.0, 0.0, 0.0] + let query_vec: [f32; 4] = [1.0, 0.0, 0.0, 0.0]; + let query_blob: Vec = query_vec.iter().flat_map(|f| f.to_le_bytes()).collect(); + + let search_args = vec![ + Frame::BulkString(Bytes::from_static(b"e2eidx")), + Frame::BulkString(Bytes::from_static(b"*=>[KNN 2 @embedding $query]")), + Frame::BulkString(Bytes::from_static(b"PARAMS")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"query")), + Frame::BulkString(Bytes::from(query_blob)), + ]; + + let result = ft_search(&mut store, &search_args); + match &result { + Frame::Array(items) => { + // First element is count + assert!( + matches!(&items[0], Frame::Integer(n) if *n >= 1), + "Should find at least 1 result, got {result:?}" + ); + // vec:0 should be in top-2 results (at dim=4, TQ-4bit quantization + // noise can swap rankings of very close vectors in Light mode) + let mut found_vec0 = false; + for idx in [1, 3].iter() { + if let Some(Frame::BulkString(doc_id)) = items.get(*idx) { + if doc_id.as_ref() == b"vec:0" { + found_vec0 = true; + } + } + } + assert!( + found_vec0, + "vec:0 should be in top-2 results, got {result:?}" + ); + // vec:2 should be in top-2 (at dim=4, TQ noise may reorder) + let mut found_vec2 = false; + for idx in [1, 3].iter() { + if let Some(Frame::BulkString(doc_id)) = items.get(*idx) { + if doc_id.as_ref() == b"vec:2" { + found_vec2 = true; + } + } + } + assert!( + found_vec2, + "vec:2 should be in top-2 results, got {result:?}" + ); + } + Frame::Error(e) => panic!("FT.SEARCH returned error: {:?}", std::str::from_utf8(e)), + _ => panic!("FT.SEARCH should return Array, got {result:?}"), + } +} + +#[test] +fn test_ft_info_returns_correct_data() { + let mut store = VectorStore::new(); + let args = build_ft_create_args("testidx", "test:", "vec", 128, "COSINE"); + ft_create(&mut store, &args); + + let info_args = [Frame::BulkString(Bytes::from_static(b"testidx"))]; + let result = ft_info(&store, 
&info_args); + match result { + Frame::Array(items) => { + assert!(items.len() >= 6, "FT.INFO should return at least 6 items"); + // Check dimension + let mut found_dim = false; + for pair in items.chunks(2) { + if let Frame::BulkString(key) = &pair[0] { + if key.as_ref() == b"dimension" { + if let Frame::Integer(d) = &pair[1] { + assert_eq!(*d, 128); + found_dim = true; + } + } + } + } + assert!(found_dim, "FT.INFO should return dimension"); + } + other => panic!("FT.INFO should return Array, got {other:?}"), + } +} + +#[test] +fn test_ft_search_unknown_index() { + let mut store = VectorStore::new(); + let args = [ + Frame::BulkString(Bytes::from_static(b"nonexistent")), + Frame::BulkString(Bytes::from_static(b"*=>[KNN 5 @vec $query]")), + Frame::BulkString(Bytes::from_static(b"PARAMS")), + Frame::BulkString(Bytes::from_static(b"2")), + Frame::BulkString(Bytes::from_static(b"query")), + Frame::BulkString(Bytes::from(vec![0u8; 16])), + ]; + let result = ft_search(&mut store, &args); + assert!( + matches!(result, Frame::Error(_)), + "Should error on unknown index, got {result:?}" + ); +} + +#[test] +fn test_parse_filter_clause_tag() { + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 10 @vec $q]"), + bulk(b"FILTER"), + bulk(b"@category:{electronics}"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"q"), + bulk(b"blob"), + ]; + let filter = parse_filter_clause(&args); + assert!(filter.is_some(), "should parse @category:{{electronics}}"); + match filter.unwrap() { + crate::vector::filter::FilterExpr::TagEq { field, value } => { + assert_eq!(&field[..], b"category"); + assert_eq!(&value[..], b"electronics"); + } + other => panic!("expected TagEq, got {other:?}"), + } +} + +#[test] +fn test_parse_filter_clause_numeric_range() { + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 5 @vec $q]"), + bulk(b"FILTER"), + bulk(b"@price:[10 100]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"q"), + bulk(b"blob"), + ]; + let filter = parse_filter_clause(&args); + 
assert!(filter.is_some()); + match filter.unwrap() { + crate::vector::filter::FilterExpr::NumRange { field, min, max } => { + assert_eq!(&field[..], b"price"); + assert_eq!(*min, 10.0); + assert_eq!(*max, 100.0); + } + other => panic!("expected NumRange, got {other:?}"), + } +} + +#[test] +fn test_parse_filter_clause_numeric_eq() { + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 5 @vec $q]"), + bulk(b"FILTER"), + bulk(b"@price:[50 50]"), + ]; + let filter = parse_filter_clause(&args); + assert!(filter.is_some()); + match filter.unwrap() { + crate::vector::filter::FilterExpr::NumEq { field, value } => { + assert_eq!(&field[..], b"price"); + assert_eq!(*value, 50.0); + } + other => panic!("expected NumEq, got {other:?}"), + } +} + +#[test] +fn test_parse_filter_clause_compound() { + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 5 @vec $q]"), + bulk(b"FILTER"), + bulk(b"@a:{x} @b:[1 10]"), + ]; + let filter = parse_filter_clause(&args); + assert!(filter.is_some()); + match filter.unwrap() { + crate::vector::filter::FilterExpr::And(left, right) => { + assert!(matches!( + *left, + crate::vector::filter::FilterExpr::TagEq { .. } + )); + assert!(matches!( + *right, + crate::vector::filter::FilterExpr::NumRange { .. 
} + )); + } + other => panic!("expected And, got {other:?}"), + } +} + +#[test] +fn test_parse_filter_clause_none() { + // No FILTER keyword + let args = vec![ + bulk(b"idx"), + bulk(b"*=>[KNN 10 @vec $q]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"q"), + bulk(b"blob"), + ]; + let filter = parse_filter_clause(&args); + assert!(filter.is_none()); +} + +#[test] +fn test_ft_search_with_filter_no_regression() { + // Unfiltered FT.SEARCH still works identically + crate::vector::distance::init(); + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + + let query_vec: Vec = vec![0u8; 128 * 4]; + let search_args = vec![ + bulk(b"myidx"), + bulk(b"*=>[KNN 5 @vec $query]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"query"), + Frame::BulkString(Bytes::from(query_vec)), + ]; + let result = ft_search(&mut store, &search_args); + match result { + Frame::Array(items) => { + assert_eq!(items[0], Frame::Integer(0)); + } + other => panic!("expected Array, got {other:?}"), + } +} + +#[test] +fn test_vector_index_has_payload_index() { + let mut store = VectorStore::new(); + let args = ft_create_args(); + ft_create(&mut store, &args); + let idx = store.get_index(b"myidx").unwrap(); + // payload_index should exist -- insert and evaluate should work + let _ = &idx.payload_index; +} + +#[test] +fn test_vector_metrics_increment_decrement() { + use std::sync::atomic::Ordering; + + let _guard = METRICS_LOCK.lock().unwrap(); + + let mut store = VectorStore::new(); + let args = ft_create_args(); + + // FT.CREATE should increment VECTOR_INDEXES + let before_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); + ft_create(&mut store, &args); + let after_create = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); + assert!( + after_create > before_create, + "FT.CREATE should increment VECTOR_INDEXES" + ); + + // FT.SEARCH should increment VECTOR_SEARCH_TOTAL + crate::vector::distance::init(); + let before_search = 
crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed); + let query_vec: Vec = vec![0u8; 128 * 4]; + let search_args = vec![ + bulk(b"myidx"), + bulk(b"*=>[KNN 5 @vec $query]"), + bulk(b"PARAMS"), + bulk(b"2"), + bulk(b"query"), + Frame::BulkString(Bytes::from(query_vec)), + ]; + ft_search(&mut store, &search_args); + let after_search = crate::vector::metrics::VECTOR_SEARCH_TOTAL.load(Ordering::Relaxed); + assert!( + after_search > before_search, + "FT.SEARCH should increment VECTOR_SEARCH_TOTAL" + ); + + // FT.DROPINDEX should decrement VECTOR_INDEXES + let before_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); + ft_dropindex(&mut store, &[bulk(b"myidx")]); + let after_drop = crate::vector::metrics::VECTOR_INDEXES.load(Ordering::Relaxed); + assert!( + after_drop < before_drop, + "FT.DROPINDEX should decrement VECTOR_INDEXES" + ); +} diff --git a/src/server/conn/handler_monoio.rs b/src/server/conn/handler_monoio.rs index b13e945d..220eb9a9 100644 --- a/src/server/conn/handler_monoio.rs +++ b/src/server/conn/handler_monoio.rs @@ -1414,62 +1414,92 @@ pub async fn handle_connection_sharded_monoio< // Local shard: direct VectorStore access via shard_databases. // Remote shards: SPSC dispatch. Works with any shard count (including 1). 
if cmd.len() > 3 && cmd[..3].eq_ignore_ascii_case(b"FT.") { - if cmd.eq_ignore_ascii_case(b"FT.SEARCH") { - let response = - match crate::command::vector_search::parse_ft_search_args(cmd_args) { - Ok((index_name, query_blob, k, _filter)) => { - crate::shard::coordinator::scatter_vector_search_remote( - index_name, - query_blob, - k, - shard_id, - num_shards, - &shard_databases, - &dispatch_tx, - &spsc_notifiers, - ) - .await - } - Err(err_frame) => err_frame, + if num_shards > 1 { + // Multi-shard: dispatch via SPSC + if cmd.eq_ignore_ascii_case(b"FT.SEARCH") { + let response = + match crate::command::vector_search::parse_ft_search_args(cmd_args) { + Ok((index_name, query_blob, k, filter)) => { + if filter.is_some() { + Frame::Error(Bytes::from_static( + b"ERR FILTER not supported in multi-shard mode yet", + )) + } else { + crate::shard::coordinator::scatter_vector_search_remote( + index_name, + query_blob, + k, + shard_id, + num_shards, + &shard_databases, + &dispatch_tx, + &spsc_notifiers, + ) + .await + } + } + Err(err_frame) => err_frame, + }; + responses.push(response); + continue; + } + if cmd.eq_ignore_ascii_case(b"FT.CREATE") + || cmd.eq_ignore_ascii_case(b"FT.DROPINDEX") + { + // Broadcast to ALL shards so every shard has the index + let response = crate::shard::coordinator::broadcast_vector_command( + std::sync::Arc::new(frame), + shard_id, + num_shards, + &shard_databases, + &dispatch_tx, + &spsc_notifiers, + ) + .await; + responses.push(response); + continue; + } + if cmd.eq_ignore_ascii_case(b"FT.INFO") { + let response = { + let vs = shard_databases.vector_store(shard_id); + crate::command::vector_search::ft_info(&vs, cmd_args) }; - responses.push(response); - continue; - } - if cmd.eq_ignore_ascii_case(b"FT.CREATE") - || cmd.eq_ignore_ascii_case(b"FT.DROPINDEX") - { - // Broadcast to ALL shards so every shard has the index - let response = crate::shard::coordinator::broadcast_vector_command( - std::sync::Arc::new(frame), - shard_id, - num_shards, - 
&shard_databases, - &dispatch_tx, - &spsc_notifiers, - ) - .await; - responses.push(response); - continue; - } - if cmd.eq_ignore_ascii_case(b"FT.INFO") { - // Read-only: local shard is sufficient - let response = { - let vs = shard_databases.vector_store(shard_id); - crate::command::vector_search::ft_info(&vs, cmd_args) - }; - responses.push(response); + responses.push(response); + continue; + } + if cmd.eq_ignore_ascii_case(b"FT.COMPACT") { + let response = { + let mut vs = shard_databases.vector_store(shard_id); + crate::command::vector_search::ft_compact(&mut vs, cmd_args) + }; + responses.push(response); + continue; + } + responses.push(Frame::Error(Bytes::from_static(b"ERR unknown FT command"))); continue; - } - if cmd.eq_ignore_ascii_case(b"FT.COMPACT") { + } else { + // Single-shard: no SPSC channels needed. + // Dispatch directly to shard's VectorStore via shared access. let response = { - let mut vs = shard_databases.vector_store(shard_id); - crate::command::vector_search::ft_compact(&mut vs, cmd_args) + let shard_databases_ref = &shard_databases; + let mut vs = shard_databases_ref.vector_store(shard_id); + if cmd.eq_ignore_ascii_case(b"FT.CREATE") { + crate::command::vector_search::ft_create(&mut vs, cmd_args) + } else if cmd.eq_ignore_ascii_case(b"FT.SEARCH") { + crate::command::vector_search::ft_search(&mut vs, cmd_args) + } else if cmd.eq_ignore_ascii_case(b"FT.DROPINDEX") { + crate::command::vector_search::ft_dropindex(&mut vs, cmd_args) + } else if cmd.eq_ignore_ascii_case(b"FT.INFO") { + crate::command::vector_search::ft_info(&vs, cmd_args) + } else if cmd.eq_ignore_ascii_case(b"FT.COMPACT") { + crate::command::vector_search::ft_compact(&mut vs, cmd_args) + } else { + Frame::Error(Bytes::from_static(b"ERR unknown FT.* command")) + } }; responses.push(response); continue; } - responses.push(Frame::Error(Bytes::from_static(b"ERR unknown FT command"))); - continue; } // --- Routing: keyless, local, or remote --- diff --git 
a/src/server/conn/handler_sharded.rs b/src/server/conn/handler_sharded.rs index 68e7e1c8..404643f8 100644 --- a/src/server/conn/handler_sharded.rs +++ b/src/server/conn/handler_sharded.rs @@ -1258,13 +1258,19 @@ pub async fn handle_connection_sharded_inner< // Multi-shard: dispatch via SPSC if cmd.eq_ignore_ascii_case(b"FT.SEARCH") { let response = match crate::command::vector_search::parse_ft_search_args(cmd_args) { - Ok((index_name, query_blob, k, _filter)) => { - crate::shard::coordinator::scatter_vector_search_remote( - index_name, query_blob, k, - shard_id, num_shards, - &shard_databases, - &dispatch_tx, &spsc_notifiers, - ).await + Ok((index_name, query_blob, k, filter)) => { + if filter.is_some() { + Frame::Error(Bytes::from_static( + b"ERR FILTER not supported in multi-shard mode yet", + )) + } else { + crate::shard::coordinator::scatter_vector_search_remote( + index_name, query_blob, k, + shard_id, num_shards, + &shard_databases, + &dispatch_tx, &spsc_notifiers, + ).await + } } Err(err_frame) => err_frame, }; @@ -1359,24 +1365,6 @@ pub async fn handle_connection_sharded_inner< DispatchResult::Response(f) => f, DispatchResult::Quit(f) => { should_quit = true; f } }; - // Auto-index vectors on successful HSET (local write path) - if !matches!(response, Frame::Error(_)) - && (cmd.eq_ignore_ascii_case(b"HSET") || cmd.eq_ignore_ascii_case(b"HMSET")) - { - if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) { - let mut vs = shard_databases.vector_store(shard_id); - crate::shard::spsc_handler::auto_index_hset_public(&mut vs, &key, cmd_args); - } - } - // Auto-delete vectors on DEL/HDEL/UNLINK (local write path) - if !matches!(response, Frame::Error(_)) - && (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK") || cmd.eq_ignore_ascii_case(b"HDEL")) - { - if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) { - let mut vs = shard_databases.vector_store(shard_id); - vs.mark_deleted_for_key(&key); - } - } if 
!matches!(response, Frame::Error(_)) { let needs_wake = cmd.eq_ignore_ascii_case(b"LPUSH") || cmd.eq_ignore_ascii_case(b"RPUSH") || cmd.eq_ignore_ascii_case(b"LMOVE") || cmd.eq_ignore_ascii_case(b"ZADD"); @@ -1392,6 +1380,30 @@ pub async fn handle_connection_sharded_inner< } } drop(guard); + // Auto-index vectors on successful HSET (local write path) + // Placed AFTER drop(guard) to avoid DB→vector_store lock order + // inversion with the shard event loop (vector_store→DB). + if !matches!(response, Frame::Error(_)) + && (cmd.eq_ignore_ascii_case(b"HSET") || cmd.eq_ignore_ascii_case(b"HMSET")) + { + if let Some(key) = cmd_args.first().and_then(|f| extract_bytes(f)) { + let mut vs = shard_databases.vector_store(shard_id); + crate::shard::spsc_handler::auto_index_hset_public(&mut vs, &key, cmd_args); + } + } + // Auto-delete vectors on DEL/UNLINK (local write path) + // Note: HDEL removes fields, not keys — it should NOT trigger + // vector deletion unless the entire key is removed. + if !matches!(response, Frame::Error(_)) + && (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK")) + { + let mut vs = shard_databases.vector_store(shard_id); + for arg in cmd_args.iter() { + if let Some(key) = extract_bytes(arg) { + vs.mark_deleted_for_key(key.as_ref()); + } + } + } if let Some(bytes) = aof_bytes { if !matches!(response, Frame::Error(_)) { if let Some(ref tx) = aof_tx { let _ = tx.try_send(AofMessage::Append(bytes)); } diff --git a/src/shard/spsc_handler.rs b/src/shard/spsc_handler.rs index a9877212..4e277ded 100644 --- a/src/shard/spsc_handler.rs +++ b/src/shard/spsc_handler.rs @@ -272,26 +272,17 @@ pub(crate) fn handle_shard_message_shared( } } - // Auto-delete: if DEL/HDEL/UNLINK succeeded and key matches a vector + // Auto-delete: if DEL/UNLINK succeeded and key matches a vector // index prefix, mark stale vectors as deleted in matching indexes. 
- if (cmd.eq_ignore_ascii_case(b"DEL") - || cmd.eq_ignore_ascii_case(b"HDEL") - || cmd.eq_ignore_ascii_case(b"UNLINK")) + // Note: HDEL removes fields, not keys — it should NOT trigger vector + // deletion unless the entire key is removed. + if (cmd.eq_ignore_ascii_case(b"DEL") || cmd.eq_ignore_ascii_case(b"UNLINK")) && !matches!(frame, crate::protocol::Frame::Error(_)) { - // DEL/UNLINK: args are keys (args[0], args[1], ...). - // HDEL: args[0] is the hash key, remaining are fields. - // For HDEL we only mark the hash key itself (the vector source). - if cmd.eq_ignore_ascii_case(b"HDEL") { - if let Some(crate::protocol::Frame::BulkString(key_bytes)) = args.first() { + for arg in args { + if let crate::protocol::Frame::BulkString(key_bytes) = arg { vector_store.mark_deleted_for_key(key_bytes); } - } else { - for arg in args { - if let crate::protocol::Frame::BulkString(key_bytes) = arg { - vector_store.mark_deleted_for_key(key_bytes); - } - } } } diff --git a/src/vector/distance/fastscan.rs b/src/vector/distance/fastscan.rs index 6b30d9dc..bffd7ca6 100644 --- a/src/vector/distance/fastscan.rs +++ b/src/vector/distance/fastscan.rs @@ -48,12 +48,16 @@ pub fn init_fastscan() { /// Get the static FastScan dispatch table. /// -/// # Safety contract -/// Caller must ensure [`init_fastscan()`] has been called before first use. +/// Auto-initializes on first use if [`init_fastscan()`] was not called explicitly. +/// After the first call the hot path is two atomic loads (both always succeed). #[inline(always)] pub fn fastscan_dispatch() -> &'static FastScanDispatch { - // SAFETY: init_fastscan() is called from distance::init() at startup. - unsafe { FASTSCAN_DISPATCH.get().unwrap_unchecked() } + if FASTSCAN_DISPATCH.get().is_none() { + init_fastscan(); + } + FASTSCAN_DISPATCH + .get() + .expect("fastscan dispatch initialized by init_fastscan()") } /// Scalar FastScan: compute distances for 32 vectors in one interleaved block. 
diff --git a/src/vector/distance/mod.rs b/src/vector/distance/mod.rs index 2e7e9c4f..6a5fe569 100644 --- a/src/vector/distance/mod.rs +++ b/src/vector/distance/mod.rs @@ -141,20 +141,17 @@ pub fn init() { /// Returns the table initialized by [`init()`]. This is a single pointer load /// followed by a direct function call — at most 1 cache miss per call site. /// -/// # Safety contract -/// Caller must ensure [`init()`] has been called before the first call to `table()`. -/// In practice, `init()` is called from `main()` at startup. +/// Auto-initializes on first use if [`init()`] was not called explicitly. +/// After the first call the hot path is two atomic loads (both always succeed). #[inline(always)] pub fn table() -> &'static DistanceTable { - // SAFETY: init() is called from main() at startup before any search operation. - // The OnceLock is guaranteed to be initialized by the time any search - // path reaches this function. Using unwrap_unchecked avoids a branch - // on the hot path. - debug_assert!( - DISTANCE_TABLE.get().is_some(), - "distance::init() was not called before table()" - ); - unsafe { DISTANCE_TABLE.get().unwrap_unchecked() } + if DISTANCE_TABLE.get().is_none() { + init(); + } + // After init(), DISTANCE_TABLE is guaranteed to be set. 
+ DISTANCE_TABLE + .get() + .expect("distance table initialized by init()") } #[cfg(test)] diff --git a/src/vector/hnsw/search.rs b/src/vector/hnsw/search.rs index 796eae53..5f69eca2 100644 --- a/src/vector/hnsw/search.rs +++ b/src/vector/hnsw/search.rs @@ -281,13 +281,13 @@ pub fn hnsw_search_filtered( let original_dim = query.len(); let padded_dim = q_rotated.len(); let _active_code_bytes = original_dim / 2; // nibble-packed bytes for original dim - let entries_per_coord: usize = if use_subcent { 32 } else { 16 }; - let sub_table = collection.sub_centroid_table.as_ref(); + // Guard use_subcent on sub_table availability to avoid panic + let use_subcent = use_subcent && sub_table.is_some(); + let entries_per_coord: usize = if use_subcent { 32 } else { 16 }; let mut adc_lut = Vec::with_capacity(padded_dim * entries_per_coord); - if use_subcent { - let st = sub_table.unwrap(); + if let Some(st) = sub_table.filter(|_| use_subcent) { for j in 0..padded_dim { let q = q_rotated[j]; for e in 0..32 { diff --git a/src/vector/turbo_quant/codebook.rs b/src/vector/turbo_quant/codebook.rs index a1b4173c..c4690c89 100644 --- a/src/vector/turbo_quant/codebook.rs +++ b/src/vector/turbo_quant/codebook.rs @@ -128,32 +128,36 @@ pub const RAW_BOUNDARIES_3BIT: [f32; 7] = [-1.7480, -1.0500, -0.5006, 0.0, 0.500 /// /// Returns a Vec because the size varies by bit width. /// sigma = 1/sqrt(padded_dim), matching FWHT normalization. -pub fn scaled_centroids_n(padded_dim: u32, bits: u8) -> Vec { +/// +/// Returns `Err` for unsupported bit widths (anything outside 1-4). 
+pub fn scaled_centroids_n(padded_dim: u32, bits: u8) -> Result, &'static str> { let sigma = 1.0 / (padded_dim as f32).sqrt(); match bits { - 1 => RAW_CENTROIDS_1BIT.iter().map(|&c| c * sigma).collect(), - 2 => RAW_CENTROIDS_2BIT.iter().map(|&c| c * sigma).collect(), - 3 => RAW_CENTROIDS_3BIT.iter().map(|&c| c * sigma).collect(), + 1 => Ok(RAW_CENTROIDS_1BIT.iter().map(|&c| c * sigma).collect()), + 2 => Ok(RAW_CENTROIDS_2BIT.iter().map(|&c| c * sigma).collect()), + 3 => Ok(RAW_CENTROIDS_3BIT.iter().map(|&c| c * sigma).collect()), 4 => { let sc = scaled_centroids(padded_dim); - sc.to_vec() + Ok(sc.to_vec()) } - _ => panic!("unsupported bit width: {bits}"), + _ => Err("unsupported bit width"), } } /// Compute dimension-scaled boundaries for any bit width (1-4). -pub fn scaled_boundaries_n(padded_dim: u32, bits: u8) -> Vec { +/// +/// Returns `Err` for unsupported bit widths (anything outside 1-4). +pub fn scaled_boundaries_n(padded_dim: u32, bits: u8) -> Result, &'static str> { let sigma = 1.0 / (padded_dim as f32).sqrt(); match bits { - 1 => RAW_BOUNDARIES_1BIT.iter().map(|&b| b * sigma).collect(), - 2 => RAW_BOUNDARIES_2BIT.iter().map(|&b| b * sigma).collect(), - 3 => RAW_BOUNDARIES_3BIT.iter().map(|&b| b * sigma).collect(), + 1 => Ok(RAW_BOUNDARIES_1BIT.iter().map(|&b| b * sigma).collect()), + 2 => Ok(RAW_BOUNDARIES_2BIT.iter().map(|&b| b * sigma).collect()), + 3 => Ok(RAW_BOUNDARIES_3BIT.iter().map(|&b| b * sigma).collect()), 4 => { let sb = scaled_boundaries(padded_dim); - sb.to_vec() + Ok(sb.to_vec()) } - _ => panic!("unsupported bit width: {bits}"), + _ => Err("unsupported bit width"), } } @@ -186,7 +190,10 @@ pub fn code_bytes_per_vector(padded_dim: u32, bits: u8) -> usize { 2 => pd / 4, 3 => (pd * 3 + 7) / 8, 4 => pd / 2, - _ => panic!("unsupported bit width: {bits}"), + _ => { + tracing::error!("unsupported bit width {bits} for code_bytes_per_vector"); + 0 + } } } @@ -377,19 +384,20 @@ mod tests { #[test] fn test_scaled_centroids_n_sizes() { let pdim = 
1024u32; - assert_eq!(scaled_centroids_n(pdim, 1).len(), 2); - assert_eq!(scaled_centroids_n(pdim, 2).len(), 4); - assert_eq!(scaled_centroids_n(pdim, 3).len(), 8); - assert_eq!(scaled_centroids_n(pdim, 4).len(), 16); + assert_eq!(scaled_centroids_n(pdim, 1).unwrap().len(), 2); + assert_eq!(scaled_centroids_n(pdim, 2).unwrap().len(), 4); + assert_eq!(scaled_centroids_n(pdim, 3).unwrap().len(), 8); + assert_eq!(scaled_centroids_n(pdim, 4).unwrap().len(), 16); + assert!(scaled_centroids_n(pdim, 5).is_err()); } #[test] fn test_scaled_centroids_n_values() { let pdim = 1024u32; let sigma = 1.0 / (pdim as f32).sqrt(); - let c1 = scaled_centroids_n(pdim, 1); + let c1 = scaled_centroids_n(pdim, 1).unwrap(); assert!((c1[1] - 0.7979 * sigma).abs() < 1e-6); - let c2 = scaled_centroids_n(pdim, 2); + let c2 = scaled_centroids_n(pdim, 2).unwrap(); assert!((c2[3] - 1.5104 * sigma).abs() < 1e-5); } diff --git a/src/vector/turbo_quant/collection.rs b/src/vector/turbo_quant/collection.rs index 39450c84..f1c91e58 100644 --- a/src/vector/turbo_quant/collection.rs +++ b/src/vector/turbo_quant/collection.rs @@ -218,13 +218,17 @@ impl CollectionMetadata { fwht_sign_flips: sign_flips, codebook_version: CODEBOOK_VERSION, codebook: if quantization.is_turbo_quant() { + // Fail fast on invalid bit width — this is a programming invariant, + // not user input. Valid bit widths (1-4) are guaranteed by QuantizationConfig. 
scaled_centroids_n(padded, quantization.bits()) + .expect("codebook centroids: invalid bit width is a programming bug") } else { // SQ8 doesn't use codebooks -- store empty Vec Vec::new() }, codebook_boundaries: if quantization.is_turbo_quant() { scaled_boundaries_n(padded, quantization.bits()) + .expect("codebook boundaries: invalid bit width is a programming bug") } else { Vec::new() }, @@ -276,32 +280,56 @@ impl CollectionMetadata { code_bytes_per_vector(self.padded_dimension, self.quantization.bits()) } - /// Convenience accessor: returns the codebook boundaries as a `&[f32; 15]` reference. + /// Returns the codebook boundaries as a `&[f32; 15]` reference. /// - /// Panics if quantization is not 4-bit (only valid for TurboQuant4 / TurboQuantProd4). - /// Used by legacy `encode_tq_mse_scaled` which requires fixed-size array. + /// Only valid for 4-bit quantization (TurboQuant4 / TurboQuantProd4). + /// The codebook is guaranteed to have exactly 15 boundaries at construction + /// for 4-bit configs. If the invariant is violated (programming bug), logs + /// an error and returns a zeroed fallback to avoid panicking in production. pub fn codebook_boundaries_15(&self) -> &[f32; 15] { - assert_eq!( + debug_assert_eq!( self.codebook_boundaries.len(), 15, - "codebook_boundaries_15 requires 4-bit quantization (15 boundaries), got {}", + "codebook_boundaries_15 called on non-4-bit quantization (len={})", self.codebook_boundaries.len() ); - self.codebook_boundaries[..15].try_into().unwrap() + match self.codebook_boundaries.as_slice().try_into() { + Ok(arr) => arr, + Err(_) => { + tracing::error!( + "codebook_boundaries has {} entries, expected 15 — construction invariant violated", + self.codebook_boundaries.len() + ); + static ZERO: [f32; 15] = [0.0; 15]; + &ZERO + } + } } - /// Convenience accessor: returns the codebook as a `&[f32; 16]` reference. + /// Returns the codebook as a `&[f32; 16]` reference. 
/// - /// Panics if quantization is not 4-bit (only valid for TurboQuant4 / TurboQuantProd4). - /// Used by legacy `tq_l2_adc_scaled` which requires fixed-size array. + /// Only valid for 4-bit quantization (TurboQuant4 / TurboQuantProd4). + /// The codebook is guaranteed to have exactly 16 centroids at construction + /// for 4-bit configs. If the invariant is violated (programming bug), logs + /// an error and returns a zeroed fallback to avoid panicking in production. pub fn codebook_16(&self) -> &[f32; 16] { - assert_eq!( + debug_assert_eq!( self.codebook.len(), 16, - "codebook_16 requires 4-bit quantization (16 centroids), got {}", + "codebook_16 called on non-4-bit quantization (len={})", self.codebook.len() ); - self.codebook[..16].try_into().unwrap() + match self.codebook.as_slice().try_into() { + Ok(arr) => arr, + Err(_) => { + tracing::error!( + "codebook has {} entries, expected 16 — construction invariant violated", + self.codebook.len() + ); + static ZERO: [f32; 16] = [0.0; 16]; + &ZERO + } + } } /// Verify metadata integrity. Returns Err if checksum mismatch. 
diff --git a/src/vector/turbo_quant/encoder.rs b/src/vector/turbo_quant/encoder.rs index 1d982fa4..69bdbe4d 100644 --- a/src/vector/turbo_quant/encoder.rs +++ b/src/vector/turbo_quant/encoder.rs @@ -782,7 +782,7 @@ mod tests { normalize_to_unit(&mut v); for bits in [1u8, 2, 3, 4] { - let boundaries = scaled_boundaries_n(padded, bits); + let boundaries = scaled_boundaries_n(padded, bits).unwrap(); let code = encode_tq_mse_multibit(&v, &signs, &boundaries, bits, &mut work); let expected = code_bytes_per_vector(padded, bits); assert_eq!( @@ -794,15 +794,15 @@ mod tests { } // Specific sizes for 768d (padded=1024) - let b1 = scaled_boundaries_n(padded, 1); + let b1 = scaled_boundaries_n(padded, 1).unwrap(); let c1 = encode_tq_mse_multibit(&v, &signs, &b1, 1, &mut work); assert_eq!(c1.codes.len(), 128); // 1024/8 - let b2 = scaled_boundaries_n(padded, 2); + let b2 = scaled_boundaries_n(padded, 2).unwrap(); let c2 = encode_tq_mse_multibit(&v, &signs, &b2, 2, &mut work); assert_eq!(c2.codes.len(), 256); // 1024/4 - let b3 = scaled_boundaries_n(padded, 3); + let b3 = scaled_boundaries_n(padded, 3).unwrap(); let c3 = encode_tq_mse_multibit(&v, &signs, &b3, 3, &mut work); assert_eq!(c3.codes.len(), 384); // 1024*3/8 } @@ -813,8 +813,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32); let signs = test_sign_flips(padded as usize, 12345); - let boundaries = scaled_boundaries_n(padded, 1); - let centroids = scaled_centroids_n(padded, 1); + let boundaries = scaled_boundaries_n(padded, 1).unwrap(); + let centroids = scaled_centroids_n(padded, 1).unwrap(); let mut work_enc = vec![0.0f32; padded as usize]; let mut work_dec = vec![0.0f32; padded as usize]; @@ -839,8 +839,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32); let signs = test_sign_flips(padded as usize, 12345); - let boundaries = scaled_boundaries_n(padded, 2); - let centroids = scaled_centroids_n(padded, 2); + let boundaries = scaled_boundaries_n(padded, 2).unwrap(); + let 
centroids = scaled_centroids_n(padded, 2).unwrap(); let mut work_enc = vec![0.0f32; padded as usize]; let mut work_dec = vec![0.0f32; padded as usize]; @@ -864,8 +864,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32); let signs = test_sign_flips(padded as usize, 12345); - let boundaries = scaled_boundaries_n(padded, 3); - let centroids = scaled_centroids_n(padded, 3); + let boundaries = scaled_boundaries_n(padded, 3).unwrap(); + let centroids = scaled_centroids_n(padded, 3).unwrap(); let mut work_enc = vec![0.0f32; padded as usize]; let mut work_dec = vec![0.0f32; padded as usize]; diff --git a/src/vector/turbo_quant/inner_product.rs b/src/vector/turbo_quant/inner_product.rs index 5504de0e..49b99821 100644 --- a/src/vector/turbo_quant/inner_product.rs +++ b/src/vector/turbo_quant/inner_product.rs @@ -608,9 +608,9 @@ mod tests { // v2: 3-bit MSE + QJL signs (paper-correct) let boundaries_3 = - crate::vector::turbo_quant::codebook::scaled_boundaries_n(padded as u32, 3); + crate::vector::turbo_quant::codebook::scaled_boundaries_n(padded as u32, 3).unwrap(); let centroids_3 = - crate::vector::turbo_quant::codebook::scaled_centroids_n(padded as u32, 3); + crate::vector::turbo_quant::codebook::scaled_centroids_n(padded as u32, 3).unwrap(); let code_v2 = encode_tq_prod_v2( &vec, &sign_flips, diff --git a/src/vector/turbo_quant/tq_adc.rs b/src/vector/turbo_quant/tq_adc.rs index acb124c9..83c4de38 100644 --- a/src/vector/turbo_quant/tq_adc.rs +++ b/src/vector/turbo_quant/tq_adc.rs @@ -317,15 +317,21 @@ pub fn tq_l2_adc_multibit( 4 => { // Delegate to existing optimized 4-bit path debug_assert_eq!(centroids.len(), 16); - let c: &[f32; 16] = centroids.try_into().unwrap_or_else(|_| { - panic!( + if let Ok(c) = centroids.try_into() { + tq_l2_adc_scaled(q_rotated, code, norm, c) + } else { + // Invariant violated — return max distance rather than panic + tracing::error!( "4-bit ADC requires exactly 16 centroids, got {}", centroids.len() - ) - }); - 
tq_l2_adc_scaled(q_rotated, code, norm, c) + ); + f32::MAX + } + } + _ => { + tracing::error!("unsupported bit width: {bits}"); + f32::MAX } - _ => panic!("unsupported bit width: {bits}"), } } @@ -593,7 +599,7 @@ pub fn brute_force_tq_adc_multibit( } let mut results: Vec<(f32, u32)> = heap.into_iter().map(|(d, id)| (d.0, id)).collect(); - results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + results.sort_by(|a, b| a.0.total_cmp(&b.0)); results .into_iter() @@ -673,7 +679,7 @@ pub fn brute_force_tq_adc( // Extract sorted results let mut results: Vec<(f32, u32)> = heap.into_iter().map(|(d, id)| (d.0, id)).collect(); - results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + results.sort_by(|a, b| a.0.total_cmp(&b.0)); results .into_iter() @@ -1033,8 +1039,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32) as usize; let signs = test_sign_flips(padded, 42); - let boundaries = scaled_boundaries_n(padded as u32, 1); - let centroids = scaled_centroids_n(padded as u32, 1); + let boundaries = scaled_boundaries_n(padded as u32, 1).unwrap(); + let centroids = scaled_centroids_n(padded as u32, 1).unwrap(); let mut work = vec![0.0f32; padded]; let mut v = lcg_f32(dim, 99); @@ -1062,8 +1068,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32) as usize; let signs = test_sign_flips(padded, 42); - let boundaries = scaled_boundaries_n(padded as u32, 2); - let centroids = scaled_centroids_n(padded as u32, 2); + let boundaries = scaled_boundaries_n(padded as u32, 2).unwrap(); + let centroids = scaled_centroids_n(padded as u32, 2).unwrap(); let mut work = vec![0.0f32; padded]; let mut v = lcg_f32(dim, 99); @@ -1090,8 +1096,8 @@ mod tests { let dim = 768; let padded = padded_dimension(dim as u32) as usize; let signs = test_sign_flips(padded, 42); - let boundaries = scaled_boundaries_n(padded as u32, 3); - let centroids = scaled_centroids_n(padded as u32, 3); + let boundaries = scaled_boundaries_n(padded as u32, 3).unwrap(); + let 
centroids = scaled_centroids_n(padded as u32, 3).unwrap(); let mut work = vec![0.0f32; padded]; let mut v = lcg_f32(dim, 99); @@ -1120,8 +1126,8 @@ mod tests { let signs = test_sign_flips(padded, 42); for bits in [1u8, 2, 3] { - let boundaries = scaled_boundaries_n(padded as u32, bits); - let centroids = scaled_centroids_n(padded as u32, bits); + let boundaries = scaled_boundaries_n(padded as u32, bits).unwrap(); + let centroids = scaled_centroids_n(padded as u32, bits).unwrap(); let mut work_enc = vec![0.0f32; padded]; let mut work_dec = vec![0.0f32; padded]; @@ -1197,8 +1203,8 @@ mod tests { let signs = test_sign_flips(padded, 42); for bits in [1u8, 2, 3] { - let boundaries = scaled_boundaries_n(padded as u32, bits); - let centroids = scaled_centroids_n(padded as u32, bits); + let boundaries = scaled_boundaries_n(padded as u32, bits).unwrap(); + let centroids = scaled_centroids_n(padded as u32, bits).unwrap(); let mut work = vec![0.0f32; padded]; let mut v = lcg_f32(dim, 99);