Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 151 additions & 15 deletions rust/lance-index/src/vector/sq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use lance_core::deepsize::DeepSizeOf;
use lance_core::{Error, ROW_ID, Result};
use lance_file::previous::reader::FileReader as PreviousFileReader;
use lance_io::object_store::ObjectStore;
use lance_linalg::distance::{DistanceType, dot_distance, l2_u8::l2_u8};
use lance_linalg::distance::{DistanceType, dot_u8::dot_u8, l2_u8::l2_u8};
use lance_table::format::SelfDescribingFileReader;
use object_store::path::Path;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -374,23 +374,56 @@ impl VectorStore for ScalarQuantizationStorage {
let (offset, chunk) = self.chunk(id);
let query_sq_code = chunk.sq_code_slice(id - offset);
let bounds = self.quantizer.bounds();
let query_sq_code_sum = sq_code_sum(query_sq_code);
SQDistCalculator {
query_sq_code: SQQueryCode::Borrowed(query_sq_code),
query_sq_code_sum,
scale: sq_distance_scale(&bounds),
step: sq_quantization_step(&bounds),
lower_bound: bounds.start as f32,
storage: self,
}
}
}

/// Width of one quantization level: the span of `bounds` divided by 255,
/// the largest value a `u8` code can take.
///
/// The subtraction is done in `f64`; the result is narrowed to `f32` before
/// the division.
#[inline]
fn sq_quantization_step(bounds: &Range<f64>) -> f32 {
    let span = (bounds.end - bounds.start) as f32;
    span / 255.0_f32
}

#[inline]
fn sq_distance_scale(bounds: &Range<f64>) -> f32 {
let range = (bounds.end - bounds.start) as f32;
(range * range) / (255.0_f32 * 255.0_f32)
let step = sq_quantization_step(bounds);
step * step
}

/// Adds up every byte of an SQ code, widening each one to `u32` first.
///
/// An empty slice sums to zero.
#[inline]
fn sq_code_sum(sq_code: &[u8]) -> u32 {
    sq_code.iter().copied().map(u32::from).sum()
}

#[inline]
fn sq_dequantized_dot_distance(
sq_code: &[u8],
sq_code_sum: u32,
query_sq_code: &[u8],
query_sq_code_sum: u32,
step: f32,
lower_bound: f32,
) -> f32 {
let code_dot = dot_u8(sq_code, query_sq_code) as f32;
let dot = step * step * code_dot

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This expanded affine dot calculation is performed in f32, and the large offset terms can cancel in high-dimensional near-zero vectors enough to flip SQ Dot rankings.

+ lower_bound * step * (sq_code_sum + query_sq_code_sum) as f32
+ sq_code.len() as f32 * lower_bound * lower_bound;
1.0 - dot
}

/// Distance calculator that compares one fixed query code against the rows
/// of a [`ScalarQuantizationStorage`].
pub struct SQDistCalculator<'a> {
/// The query's SQ code; either borrowed from the storage or owned.
query_sq_code: SQQueryCode<'a>,
/// `sq_code_sum` of the query code, computed once at construction.
query_sq_code_sum: u32,
/// `sq_distance_scale` of the quantizer bounds.
scale: f32,
/// `sq_quantization_step` of the quantizer bounds.
step: f32,
/// Start of the quantizer bounds, narrowed to `f32`.
lower_bound: f32,
/// Storage whose rows the query is compared against.
storage: &'a ScalarQuantizationStorage,
}

Expand Down Expand Up @@ -429,9 +462,13 @@ impl<'a> SQDistCalculator<'a> {
panic!("Unsupported data type for ScalarQuantizationStorage");
}
};
let query_sq_code_sum = sq_code_sum(&query_sq_code);
Self {
query_sq_code: SQQueryCode::Owned(query_sq_code),
query_sq_code_sum,
scale: sq_distance_scale(&bounds),
step: sq_quantization_step(&bounds),
lower_bound: bounds.start as f32,
storage,
}
}
Expand All @@ -440,14 +477,23 @@ impl<'a> SQDistCalculator<'a> {
impl DistCalculator for SQDistCalculator<'_> {
fn distance(&self, id: u32) -> f32 {
let (offset, chunk) = self.storage.chunk(id);
let sq_code = chunk.sq_code_slice(id - offset);
let chunk_id = id - offset;
let sq_code = chunk.sq_code_slice(chunk_id);
let query_sq_code = self.query_sq_code.as_slice();
let dist = match self.storage.distance_type {
DistanceType::L2 | DistanceType::Cosine => l2_u8(sq_code, query_sq_code) as f32,
DistanceType::Dot => dot_distance(sq_code, query_sq_code),
match self.storage.distance_type {
DistanceType::L2 | DistanceType::Cosine => {
l2_u8(sq_code, query_sq_code) as f32 * self.scale
}
DistanceType::Dot => sq_dequantized_dot_distance(
sq_code,
sq_code_sum(sq_code),
query_sq_code,
self.query_sq_code_sum,
self.step,
self.lower_bound,
),
_ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
};
dist * self.scale
}
}

fn distance_all(&self, _k_hint: usize) -> Vec<f32> {
Expand All @@ -470,12 +516,17 @@ impl DistCalculator for SQDistCalculator<'_> {
.chunks
.iter()
.flat_map(|c| {
c.sq_codes
.values()
.chunks_exact(c.dim())
.map(|sq_codes| dot_distance(sq_codes, query_sq_code))
c.sq_codes.values().chunks_exact(c.dim()).map(|sq_codes| {
sq_dequantized_dot_distance(
sq_codes,
sq_code_sum(sq_codes),
query_sq_code,
self.query_sq_code_sum,
self.step,
self.lower_bound,
)
})
})
.map(|dist| dist * self.scale)
.collect(),
_ => panic!("We should not reach here: sq distance can only be L2 or Dot"),
}
Expand Down Expand Up @@ -511,7 +562,7 @@ mod tests {
use std::iter::repeat_with;
use std::sync::Arc;

use arrow_array::FixedSizeListArray;
use arrow_array::{FixedSizeListArray, Float32Array};
use arrow_schema::{DataType, Field, Schema};
use lance_arrow::FixedSizeListArrayExt;
use lance_testing::datagen::generate_random_array;
Expand Down Expand Up @@ -541,6 +592,24 @@ mod tests {
RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(code_arr)]).unwrap()
}

/// Builds a two-column batch (`ROW_ID`, `SQ_CODE_COLUMN`) from explicit SQ codes.
///
/// `codes` holds one `dim`-wide code vector per entry of `row_ids`, laid out
/// back to back. Panics if the two lengths disagree.
fn create_record_batch_from_codes(row_ids: Vec<u64>, dim: i32, codes: Vec<u8>) -> RecordBatch {
    assert_eq!(codes.len(), row_ids.len() * dim as usize);

    let code_item = Arc::new(Field::new("item", DataType::UInt8, true));
    let fields = vec![
        Field::new(ROW_ID, DataType::UInt64, false),
        Field::new(
            SQ_CODE_COLUMN,
            DataType::FixedSizeList(code_item, dim),
            false,
        ),
    ];

    let id_column = UInt64Array::from(row_ids);
    let code_column =
        FixedSizeListArray::try_new_from_values(UInt8Array::from(codes), dim).unwrap();

    RecordBatch::try_new(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(id_column), Arc::new(code_column)],
    )
    .unwrap()
}

#[test]
fn test_get_chunks() {
const DIM: usize = 64;
Expand Down Expand Up @@ -592,4 +661,71 @@ mod tests {
assert_eq!(offset, 400);
assert_eq!(chunk.row_id(5), 105);
}

#[test]
fn test_dot_distance_uses_dequantized_values() {
let batch = create_record_batch_from_codes(
vec![10, 11],
2,
vec![
255, 255, // dequantized to [245, 245]
100, 0, // dequantized to [90, -10]
],
);
let storage =
ScalarQuantizationStorage::try_new(8, DistanceType::Dot, -10.0..245.0, [batch], None)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new coverage only exercises unit-step and constant bounds, so regressions in the step and step * step terms can pass while arbitrary-range SQ Dot distances remain wrong.

.unwrap();

let query = Arc::new(Float32Array::from(vec![0.0, -10.0]));
let calc = storage.dist_calculator(query, 0.0);

// Code-space dot would rank row 10 first:
// [255, 255] . [10, 0] > [100, 0] . [10, 0].
// Dequantized dot ranks row 11 first:
// [245, 245] . [0, -10] = -2450
// [90, -10] . [0, -10] = 100
assert!((calc.distance(0) - 2451.0).abs() < 1e-5);
assert!((calc.distance(1) - -99.0).abs() < 1e-5);
assert!(calc.distance(1) < calc.distance(0));

let all_distances = calc.distance_all(2);
assert_eq!(all_distances.len(), 2);
assert!((all_distances[0] - calc.distance(0)).abs() < 1e-5);
assert!((all_distances[1] - calc.distance(1)).abs() < 1e-5);
}

/// Dot distance taken from a stored row must use dequantized values, for both
/// a unit and a non-unit quantization step.
#[test]
fn test_dot_distance_from_id_uses_dequantized_values() {
    // Two rows of dimension 2; row 1 is used as the query in both cases.
    let codes: Vec<u8> = vec![255, 255, 100, 0];

    // Unit step: bounds -10..245 give step = 255 / 255 = 1, so the rows
    // dequantize to [245, 245] and [90, -10].
    let batch = create_record_batch_from_codes(vec![10, 11], 2, codes.clone());
    let storage =
        ScalarQuantizationStorage::try_new(8, DistanceType::Dot, -10.0..245.0, [batch], None)
            .unwrap();
    let calc = storage.dist_calculator_from_id(1);

    // [90, -10] . [245, 245] = 19600, so dot distance is 1 - 19600.
    assert!((calc.distance(0) - -19599.0).abs() < 1e-5);

    // Non-unit step: bounds -10..500 give step = 510 / 255 = 2, so the same
    // codes dequantize to [500, 500] and [190, -10]. This exercises the `step`
    // and `step * step` terms, which a unit step cannot distinguish from 1.
    let batch = create_record_batch_from_codes(vec![10, 11], 2, codes);
    let storage =
        ScalarQuantizationStorage::try_new(8, DistanceType::Dot, -10.0..500.0, [batch], None)
            .unwrap();
    let calc = storage.dist_calculator_from_id(1);

    // [190, -10] . [500, 500] = 90000, so dot distance is 1 - 90000.
    assert!((calc.distance(0) - -89999.0).abs() < 1e-5);
}

/// With a zero-width range the step is zero, so the stored codes must not
/// influence the dot distance.
#[test]
fn test_dot_distance_with_constant_bounds() {
    let storage = ScalarQuantizationStorage::try_new(
        8,
        DistanceType::Dot,
        3.0..3.0,
        [create_record_batch_from_codes(vec![10], 2, vec![255, 1])],
        None,
    )
    .unwrap();

    let query = Arc::new(Float32Array::from(vec![5.0, 6.0]));
    let distance = storage.dist_calculator(query, 0.0).distance(0);

    // Every value dequantizes to the lower bound, whatever its code:
    // [3, 3] . [3, 3] = 18, so the dot distance is 1 - 18.
    let expected = 1.0 - 18.0;
    assert!((distance - expected).abs() < 1e-5);
}
}
Loading