From 3538dfb2e59e917b78e92dc7f9a3ec80a94fb474 Mon Sep 17 00:00:00 2001 From: roenchen Date: Thu, 18 Jun 2026 16:19:30 +0800 Subject: [PATCH] fix: compute SQ dot distance from dequantized values --- rust/lance-index/src/vector/sq/storage.rs | 166 ++++++++++++++++++++-- 1 file changed, 151 insertions(+), 15 deletions(-) diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index 1e5eebda0d9..a1deb3b4454 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -16,7 +16,7 @@ use lance_core::deepsize::DeepSizeOf; use lance_core::{Error, ROW_ID, Result}; use lance_file::previous::reader::FileReader as PreviousFileReader; use lance_io::object_store::ObjectStore; -use lance_linalg::distance::{DistanceType, dot_distance, l2_u8::l2_u8}; +use lance_linalg::distance::{DistanceType, dot_u8::dot_u8, l2_u8::l2_u8}; use lance_table::format::SelfDescribingFileReader; use object_store::path::Path; use serde::{Deserialize, Serialize}; @@ -374,23 +374,56 @@ impl VectorStore for ScalarQuantizationStorage { let (offset, chunk) = self.chunk(id); let query_sq_code = chunk.sq_code_slice(id - offset); let bounds = self.quantizer.bounds(); + let query_sq_code_sum = sq_code_sum(query_sq_code); SQDistCalculator { query_sq_code: SQQueryCode::Borrowed(query_sq_code), + query_sq_code_sum, scale: sq_distance_scale(&bounds), + step: sq_quantization_step(&bounds), + lower_bound: bounds.start as f32, storage: self, } } } +#[inline] +fn sq_quantization_step(bounds: &Range) -> f32 { + (bounds.end - bounds.start) as f32 / 255.0_f32 +} + #[inline] fn sq_distance_scale(bounds: &Range) -> f32 { - let range = (bounds.end - bounds.start) as f32; - (range * range) / (255.0_f32 * 255.0_f32) + let step = sq_quantization_step(bounds); + step * step +} + +#[inline] +fn sq_code_sum(sq_code: &[u8]) -> u32 { + sq_code.iter().map(|&value| value as u32).sum() +} + +#[inline] +fn sq_dequantized_dot_distance( + sq_code: &[u8], + sq_code_sum: u32, + query_sq_code: &[u8], + query_sq_code_sum: u32, + step: f32, + lower_bound: f32, +) -> f32 { + let code_dot = dot_u8(sq_code, query_sq_code) as f32; + let dot = step * step * code_dot + + lower_bound * step * (sq_code_sum + query_sq_code_sum) as f32 + + sq_code.len() as f32 * lower_bound * lower_bound; + 1.0 - dot } pub struct SQDistCalculator<'a> { query_sq_code: SQQueryCode<'a>, + query_sq_code_sum: u32, scale: f32, + step: f32, + lower_bound: f32, storage: &'a ScalarQuantizationStorage, } @@ -429,9 +462,13 @@ impl<'a> SQDistCalculator<'a> { panic!("Unsupported data type for ScalarQuantizationStorage"); } }; + let query_sq_code_sum = sq_code_sum(&query_sq_code); Self { query_sq_code: SQQueryCode::Owned(query_sq_code), + query_sq_code_sum, scale: sq_distance_scale(&bounds), + step: sq_quantization_step(&bounds), + lower_bound: bounds.start as f32, storage, } } @@ -440,14 +477,23 @@ impl<'a> SQDistCalculator<'a> { impl DistCalculator for SQDistCalculator<'_> { fn distance(&self, id: u32) -> f32 { let (offset, chunk) = self.storage.chunk(id); - let sq_code = chunk.sq_code_slice(id - offset); + let chunk_id = id - offset; + let sq_code = chunk.sq_code_slice(chunk_id); let query_sq_code = self.query_sq_code.as_slice(); - let dist = match self.storage.distance_type { - DistanceType::L2 | DistanceType::Cosine => l2_u8(sq_code, query_sq_code) as f32, - DistanceType::Dot => dot_distance(sq_code, query_sq_code), + match self.storage.distance_type { + DistanceType::L2 | DistanceType::Cosine => { + l2_u8(sq_code, query_sq_code) as f32 * self.scale + } + DistanceType::Dot => sq_dequantized_dot_distance( + sq_code, + sq_code_sum(sq_code), + query_sq_code, + self.query_sq_code_sum, + self.step, + self.lower_bound, + ), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), - }; - dist * self.scale + } } fn distance_all(&self, _k_hint: usize) -> Vec { @@ -470,12 +516,17 @@ impl DistCalculator for SQDistCalculator<'_> { .chunks .iter() .flat_map(|c| { - c.sq_codes - .values() - .chunks_exact(c.dim()) - .map(|sq_codes| dot_distance(sq_codes, query_sq_code)) + c.sq_codes.values().chunks_exact(c.dim()).map(|sq_codes| { + sq_dequantized_dot_distance( + sq_codes, + sq_code_sum(sq_codes), + query_sq_code, + self.query_sq_code_sum, + self.step, + self.lower_bound, + ) + }) }) - .map(|dist| dist * self.scale) .collect(), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), } @@ -511,7 +562,7 @@ mod tests { use std::iter::repeat_with; use std::sync::Arc; - use arrow_array::FixedSizeListArray; + use arrow_array::{FixedSizeListArray, Float32Array}; use arrow_schema::{DataType, Field, Schema}; use lance_arrow::FixedSizeListArrayExt; use lance_testing::datagen::generate_random_array; @@ -541,6 +592,24 @@ mod tests { RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(code_arr)]).unwrap() } + fn create_record_batch_from_codes(row_ids: Vec, dim: i32, codes: Vec) -> RecordBatch { + assert_eq!(codes.len(), row_ids.len() * dim as usize); + + let row_ids = UInt64Array::from(row_ids); + let sq_code = UInt8Array::from(codes); + let code_arr = FixedSizeListArray::try_new_from_values(sq_code, dim).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new(ROW_ID, DataType::UInt64, false), + Field::new( + SQ_CODE_COLUMN, + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::UInt8, true)), dim), + false, + ), + ])); + RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(code_arr)]).unwrap() + } + #[test] fn test_get_chunks() { const DIM: usize = 64; @@ -592,4 +661,71 @@ mod tests { assert_eq!(offset, 400); assert_eq!(chunk.row_id(5), 105); } + + #[test] + fn test_dot_distance_uses_dequantized_values() { + let batch = create_record_batch_from_codes( + vec![10, 11], + 2, + vec![ + 255, 255, // dequantized to [245, 245] + 100, 0, // dequantized to [90, -10] + ], + ); + let storage = + ScalarQuantizationStorage::try_new(8, DistanceType::Dot, -10.0..245.0, [batch], None) + .unwrap(); + + let query = Arc::new(Float32Array::from(vec![0.0, -10.0])); + let calc = storage.dist_calculator(query, 0.0); + + // Code-space dot would rank row 10 first: + // [255, 255] . [10, 0] > [100, 0] . [10, 0]. + // Dequantized dot ranks row 11 first: + // [245, 245] . [0, -10] = -2450 + // [90, -10] . [0, -10] = 100 + assert!((calc.distance(0) - 2451.0).abs() < 1e-5); + assert!((calc.distance(1) - -99.0).abs() < 1e-5); + assert!(calc.distance(1) < calc.distance(0)); + + let all_distances = calc.distance_all(2); + assert_eq!(all_distances.len(), 2); + assert!((all_distances[0] - calc.distance(0)).abs() < 1e-5); + assert!((all_distances[1] - calc.distance(1)).abs() < 1e-5); + } + + #[test] + fn test_dot_distance_from_id_uses_dequantized_values() { + let batch = create_record_batch_from_codes( + vec![10, 11], + 2, + vec![ + 255, 255, // dequantized to [245, 245] + 100, 0, // dequantized to [90, -10] + ], + ); + let storage = + ScalarQuantizationStorage::try_new(8, DistanceType::Dot, -10.0..245.0, [batch], None) + .unwrap(); + + let calc = storage.dist_calculator_from_id(1); + + // [90, -10] . [245, 245] = 19600, so dot distance is 1 - 19600. + assert!((calc.distance(0) - -19599.0).abs() < 1e-5); + } + + #[test] + fn test_dot_distance_with_constant_bounds() { + let batch = create_record_batch_from_codes(vec![10], 2, vec![255, 1]); + let storage = + ScalarQuantizationStorage::try_new(8, DistanceType::Dot, 3.0..3.0, [batch], None) + .unwrap(); + + let query = Arc::new(Float32Array::from(vec![5.0, 6.0])); + let calc = storage.dist_calculator(query, 0.0); + + // Constant bounds dequantize every value to lower_bound, regardless of code. + // [3, 3] . [3, 3] = 18, so dot distance is 1 - 18. + assert!((calc.distance(0) - -17.0).abs() < 1e-5); + } }