diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..efd058ffb 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -63,6 +63,9 @@ scalar-quantization = [] # Enable minmax-quantization based algorithms minmax-quantization = [] +# Enable multi-vector distance benchmarks (Chamfer / MaxSim) +multi-vector = [] + # Enable Disk Index benchmarks disk-index = [ "diskann-disk/perf_test", diff --git a/diskann-benchmark/example/multi-vector-test.json b/diskann-benchmark/example/multi-vector-test.json new file mode 100644 index 000000000..28e9b9d64 --- /dev/null +++ b/diskann-benchmark/example/multi-vector-test.json @@ -0,0 +1,47 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "implementation": "optimized", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 2, "num_measurements": 1 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 2, "num_measurements": 1 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float16", + "implementation": "optimized", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 2, "num_measurements": 1 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "implementation": "reference", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 2, "num_measurements": 1 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 2, "num_measurements": 1 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float16", + "implementation": "reference", + "runs": [ + { "operation": "max_sim", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 2, "num_measurements": 1 } + ] + } + } + ] +} diff --git a/diskann-benchmark/example/multi-vector.json b/diskann-benchmark/example/multi-vector.json new file mode 100644 index 000000000..553a6a9d8 --- /dev/null +++ b/diskann-benchmark/example/multi-vector.json @@ -0,0 +1,117 @@ +{ + "search_directories": [], + "jobs": [ + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "implementation": "optimized", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 }, + + { "operation": "max_sim", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float16", + "implementation": "optimized", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 }, + + { "operation": "max_sim", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float32", + "implementation": "reference", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 }, + + { "operation": "max_sim", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + }, + { + "type": "multi-vector-op", + "content": { + "element_type": "float16", + "implementation": "reference", + "runs": [ + { "operation": "chamfer", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "chamfer", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 }, + + { "operation": "max_sim", "num_query_vectors": 8, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 500, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 16, "num_doc_vectors": 64, "dim": 256, "loops_per_measurement": 100, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 128, "dim": 384, "loops_per_measurement": 20, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 16, "dim": 256, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 264, "loops_per_measurement": 50, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 1250, "dim": 128, "loops_per_measurement": 10, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 1250, "dim": 512, "loops_per_measurement": 2, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 64, "num_doc_vectors": 32, "dim": 128, "loops_per_measurement": 200, "num_measurements": 50 }, + { "operation": "max_sim", "num_query_vectors": 32, "num_doc_vectors": 32, "dim": 512, "loops_per_measurement": 50, "num_measurements": 50 } + ] + } + } + ] +} diff --git a/diskann-benchmark/perf_test_inputs/multi-vector-tolerance.json b/diskann-benchmark/perf_test_inputs/multi-vector-tolerance.json new file mode 100644 index 000000000..8d5997199 --- /dev/null +++ b/diskann-benchmark/perf_test_inputs/multi-vector-tolerance.json @@ -0,0 +1,16 @@ +{ + "checks": [ + { + "input": { + "type": "multi-vector-op", + "content": {} + }, + "tolerance": { + "type": "multi-vector-tolerance", + "content": { + "min_time_regression": 0.05 + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..0d1c61345 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -7,10 +7,12 @@ mod disk_index; mod exhaustive; mod filters; mod index; +mod multi_vector; pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::registry::Benchmarks) { exhaustive::register_benchmarks(registry); disk_index::register_benchmarks(registry); index::register_benchmarks(registry); filters::register_benchmarks(registry); + multi_vector::register_benchmarks(registry); } diff --git a/diskann-benchmark/src/backend/multi_vector.rs b/diskann-benchmark/src/backend/multi_vector.rs new file mode 100644 index 000000000..cfdb77f33 --- /dev/null +++ b/diskann-benchmark/src/backend/multi_vector.rs @@ -0,0 +1,806 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Multi-vector distance benchmarks (Chamfer / MaxSim) with regression detection. + +use diskann_benchmark_runner::registry::Benchmarks; + +// Create a stub-module if the "multi-vector" feature is disabled. +crate::utils::stub_impl!("multi-vector", inputs::multi_vector::MultiVectorOp); + +pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { + #[cfg(feature = "multi-vector")] + { + use half::f16; + + // Optimized (architecture-dispatched QueryComputer). + benchmarks.register_regression( + "multi-vector-op-f32-optimized", + imp::Kernel::::new(), + ); + benchmarks.register_regression( + "multi-vector-op-f16-optimized", + imp::Kernel::::new(), + ); + + // Reference (Chamfer / MaxSim fallback path). + benchmarks.register_regression( + "multi-vector-op-f32-reference", + imp::Kernel::::new(), + ); + benchmarks.register_regression( + "multi-vector-op-f16-reference", + imp::Kernel::::new(), + ); + } + + // Stub implementation + #[cfg(not(feature = "multi-vector"))] + imp::register("multi-vector-op", benchmarks); +} + +#[cfg(feature = "multi-vector")] +mod imp { + use std::io::Write; + + use diskann_benchmark_runner::{ + benchmark::{PassFail, Regression}, + dispatcher::{DispatchRule, FailureScore, MatchScore}, + utils::{datatype, num::relative_change, percentiles, MicroSeconds}, + Benchmark, + }; + use diskann_quantization::multi_vector::{ + Chamfer, Init, Mat, MatRef, MaxSim, QueryComputer, Standard, + }; + use diskann_vector::distance::InnerProduct; + use diskann_vector::{DistanceFunctionMut, PureDistanceFunction}; + use half::f16; + use rand::{ + distr::{Distribution, StandardUniform}, + rngs::StdRng, + SeedableRng, + }; + use serde::{Deserialize, Serialize}; + + use crate::inputs::multi_vector::{ + Implementation, MultiVectorOp, MultiVectorTolerance, Operation, Run, + }; + + /////////// + // Utils // + /////////// + + #[derive(Debug, Clone, Copy)] + pub(super) struct DisplayWrapper<'a, T: ?Sized>(pub(super) &'a T); + + impl std::ops::Deref for DisplayWrapper<'_, T> { + type Target = T; + fn deref(&self) -> &T { + self.0 + } + } + + ////////////// + // Dispatch // + ////////////// + + /// Dispatch marker for the [`QueryComputer`] implementation. + #[derive(Debug)] + pub(super) struct Optimized; + + /// Dispatch marker for the [`Chamfer`] / [`MaxSim`] fallback. + #[derive(Debug)] + pub(super) struct Reference; + + /// A multi-vector benchmark. + pub(super) struct Kernel { + _type: std::marker::PhantomData<(I, T)>, + } + + impl Kernel { + pub(super) fn new() -> Self { + Self { + _type: std::marker::PhantomData, + } + } + } + + /// Pairs the standard `TryFrom` conversion with the static + /// description info needed for friendly diagnostics in `Benchmark::description`. + pub(super) trait ImplementationMatcher: + TryFrom + 'static + { + /// Human-readable description of which implementation this marker handles. + const DESCRIPTION: &'static str; + /// The implementation variant this marker expects (for mismatch diagnostics). + const EXPECTED: Implementation; + } + + impl TryFrom for Optimized { + type Error = FailureScore; + fn try_from(i: Implementation) -> Result { + match i { + Implementation::Optimized => Ok(Self), + _ => Err(FailureScore(1)), + } + } + } + + impl ImplementationMatcher for Optimized { + const DESCRIPTION: &'static str = "QueryComputer (architecture-dispatched)"; + const EXPECTED: Implementation = Implementation::Optimized; + } + + impl TryFrom for Reference { + type Error = FailureScore; + fn try_from(i: Implementation) -> Result { + match i { + Implementation::Reference => Ok(Self), + _ => Err(FailureScore(1)), + } + } + } + + impl ImplementationMatcher for Reference { + const DESCRIPTION: &'static str = "Chamfer / MaxSim fallback"; + const EXPECTED: Implementation = Implementation::Reference; + } + + impl Benchmark for Kernel + where + datatype::Type: DispatchRule, + I: ImplementationMatcher, + Kernel: RunBenchmark, + T: 'static, + { + type Input = MultiVectorOp; + type Output = Vec; + + fn try_match(&self, from: &MultiVectorOp) -> Result { + let mut failscore: Option = None; + if datatype::Type::::try_match(&from.element_type).is_err() { + *failscore.get_or_insert(0) += 10; + } + if let Err(FailureScore(score)) = I::try_from(from.implementation) { + *failscore.get_or_insert(0) += 2 + score; + } + + match failscore { + None => Ok(MatchScore(0)), + Some(score) => Err(FailureScore(score)), + } + } + + fn run( + &self, + input: &MultiVectorOp, + _: diskann_benchmark_runner::Checkpoint<'_>, + mut output: &mut dyn diskann_benchmark_runner::Output, + ) -> anyhow::Result { + // The dispatcher only invokes `run` after `try_match` has already accepted + // the input, so a failure here would indicate a dispatcher bug. + I::try_from(input.implementation).expect("try_match accepted the input"); + writeln!(output, "{}", input)?; + let results = self.run_benchmark(input)?; + writeln!(output, "\n\n{}", DisplayWrapper(&*results))?; + Ok(results) + } + + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + input: Option<&MultiVectorOp>, + ) -> std::fmt::Result { + match input { + None => { + writeln!( + f, + "- Element Type: {}", + diskann_benchmark_runner::dispatcher::Description::< + datatype::DataType, + datatype::Type, + >::new() + )?; + writeln!(f, "- Implementation: {}", I::DESCRIPTION)?; + } + Some(input) => { + if let Err(err) = datatype::Type::::try_match_verbose(&input.element_type) { + writeln!(f, "\n - Mismatched element type: {}", err)?; + } + if I::try_from(input.implementation).is_err() { + writeln!( + f, + "\n - Mismatched implementation: expected {}, got {}", + I::EXPECTED, + input.implementation + )?; + } + } + } + Ok(()) + } + } + + impl Regression for Kernel + where + datatype::Type: DispatchRule, + I: ImplementationMatcher, + Kernel: RunBenchmark, + T: 'static, + { + type Tolerances = MultiVectorTolerance; + type Pass = CheckResult; + type Fail = CheckResult; + + fn check( + &self, + tolerance: &MultiVectorTolerance, + _input: &MultiVectorOp, + before: &Vec, + after: &Vec, + ) -> anyhow::Result> { + anyhow::ensure!( + before.len() == after.len(), + "before has {} runs but after has {}", + before.len(), + after.len(), + ); + + let mut passed = true; + let checks: Vec = std::iter::zip(before.iter(), after.iter()) + .enumerate() + .map(|(i, (b, a))| { + anyhow::ensure!(b.run == a.run, "run {i} mismatched"); + + let computations_per_latency = b.computations_per_latency() as f64; + + let before_min = + b.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency; + let after_min = + a.percentiles.minimum.as_f64() * 1000.0 / computations_per_latency; + + let comparison = Comparison { + run: b.run.clone(), + tolerance: *tolerance, + before_min, + after_min, + }; + + match relative_change(before_min, after_min) { + Ok(change) => { + if change > tolerance.min_time_regression.get() { + passed = false; + } + } + Err(_) => passed = false, + }; + + Ok(comparison) + }) + .collect::>>()?; + + let check = CheckResult { checks }; + + if passed { + Ok(PassFail::Pass(check)) + } else { + Ok(PassFail::Fail(check)) + } + } + } + + ////////////////////// + // Regression Check // + ////////////////////// + + /// Per-run comparison result showing before/after percentile differences. + #[derive(Debug, Serialize)] + pub(super) struct Comparison { + run: Run, + tolerance: MultiVectorTolerance, + before_min: f64, + after_min: f64, + } + + /// Aggregated result of the regression check across all runs. + #[derive(Debug, Serialize)] + pub(super) struct CheckResult { + checks: Vec, + } + + impl std::fmt::Display for CheckResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let header = [ + "Operation", + "Q", + "D", + "Dim", + "Min Before (ns/IP @ Dim)", + "Min After (ns/IP @ Dim)", + "Change (%)", + "Remark", + ]; + + let mut table = + diskann_benchmark_runner::utils::fmt::Table::new(header, self.checks.len()); + + for (i, c) in self.checks.iter().enumerate() { + let mut row = table.row(i); + let change = relative_change(c.before_min, c.after_min); + + row.insert(c.run.operation, 0); + row.insert(c.run.num_query_vectors, 1); + row.insert(c.run.num_doc_vectors, 2); + row.insert(c.run.dim, 3); + row.insert(format!("{:.3}", c.before_min), 4); + row.insert(format!("{:.3}", c.after_min), 5); + match change { + Ok(change) => { + row.insert(format!("{:.3} %", change * 100.0), 6); + if change > c.tolerance.min_time_regression.get() { + row.insert("FAIL", 7); + } + } + Err(err) => { + row.insert("invalid", 6); + row.insert(err, 7); + } + } + } + + table.fmt(f) + } + } + + /////////////// + // Benchmark // + /////////////// + + pub(super) trait RunBenchmark { + fn run_benchmark(&self, input: &MultiVectorOp) -> Result, anyhow::Error>; + } + + #[derive(Debug, Serialize, Deserialize)] + pub(super) struct RunResult { + /// The configuration for this run. + run: Run, + /// Per-measurement latencies (over `loops_per_measurement` calls). + latencies: Vec, + /// Latency percentiles. + percentiles: percentiles::Percentiles, + } + + impl RunResult { + fn computations_per_latency(&self) -> usize { + self.run.num_query_vectors.get() + * self.run.num_doc_vectors.get() + * self.run.loops_per_measurement.get() + } + } + + impl std::fmt::Display for DisplayWrapper<'_, [RunResult]> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_empty() { + return Ok(()); + } + + // ns/IP is normalized as `min_latency_us * 1000 / (Q * D * loops)` and is + // approximately linear in `dim`. Compare across rows with the same `Dim`; + // divide further by `Dim` to recover ns per scalar multiply. + writeln!( + f, + "ns/IP = time per (query, doc) inner-product call (~ linear in Dim)" + )?; + + let header = [ + "Operation", + "Q", + "D", + "Dim", + "Min Time (ns/IP @ Dim)", + "Mean Time (ns/IP @ Dim)", + "Loops", + "Measurements", + ]; + + let mut table = diskann_benchmark_runner::utils::fmt::Table::new(header, self.len()); + + self.iter().enumerate().for_each(|(row, r)| { + let mut row = table.row(row); + + let min_latency = r + .latencies + .iter() + .min() + .copied() + .unwrap_or(MicroSeconds::new(u64::MAX)); + let mean_latency = r.percentiles.mean; + + let computations_per_latency = r.computations_per_latency() as f64; + + // Convert time from micro-seconds to nano-seconds per inner-product call + // (one (query, doc) pair, ~ linear in dim). + let min_time = min_latency.as_f64() / computations_per_latency * 1000.0; + let mean_time = mean_latency / computations_per_latency * 1000.0; + + row.insert(r.run.operation, 0); + row.insert(r.run.num_query_vectors, 1); + row.insert(r.run.num_doc_vectors, 2); + row.insert(r.run.dim, 3); + row.insert(format!("{:.3}", min_time), 4); + row.insert(format!("{:.3}", mean_time), 5); + row.insert(r.run.loops_per_measurement, 6); + row.insert(r.run.num_measurements, 7); + }); + + table.fmt(f) + } + } + + fn run_loops(run: &Run, mut body: F) -> RunResult + where + F: FnMut(), + { + let mut latencies = Vec::with_capacity(run.num_measurements.get()); + + for _ in 0..run.num_measurements.get() { + let start = std::time::Instant::now(); + for _ in 0..run.loops_per_measurement.get() { + body(); + } + latencies.push(start.elapsed().into()); + } + + let percentiles = percentiles::compute_percentiles(&mut latencies).unwrap(); + RunResult { + run: run.clone(), + latencies, + percentiles, + } + } + + /////////////////// + // Data fixtures // + /////////////////// + + const RNG_SEED: u64 = 0x12345; + + struct Data { + queries: Mat>, + docs: Mat>, + } + + impl Data + where + StandardUniform: Distribution, + { + fn new(run: &Run) -> Self { + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let queries = Mat::new( + Standard::new(run.num_query_vectors.get(), run.dim.get()).unwrap(), + Init(|| StandardUniform.sample(&mut rng)), + ) + .unwrap(); + let docs = Mat::new( + Standard::new(run.num_doc_vectors.get(), run.dim.get()).unwrap(), + Init(|| StandardUniform.sample(&mut rng)), + ) + .unwrap(); + Self { queries, docs } + } + } + + ////////////////////// + // Distance kernels // + ////////////////////// + + /// Object-safe abstraction over a per-shape distance executor. + /// + /// The two implementations ([`OptimizedDistance`] and [`ReferenceDistance`]) share the + /// same hot-loop nest in [`run_with_distance`]; dispatching through `&dyn Distance` + /// keeps `run_loops` from being monomorphised over the implementation axis. + trait Distance { + fn chamfer(&self, doc: MatRef<'_, Standard>) -> f32; + fn max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]); + } + + /// Distance executor that drives [`QueryComputer`] (architecture-dispatched SIMD). + struct OptimizedDistance(QueryComputer); + + impl Distance for OptimizedDistance { + fn chamfer(&self, doc: MatRef<'_, Standard>) -> f32 { + self.0.chamfer(doc) + } + fn max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]) { + self.0.max_sim(doc, scores); + } + } + + /// Distance executor that drives the [`Chamfer`] / [`MaxSim`] fallback path. + struct ReferenceDistance<'a, T: Copy>( + diskann_quantization::multi_vector::distance::QueryMatRef<'a, Standard>, + ); + + impl Distance for ReferenceDistance<'_, T> + where + InnerProduct: for<'q, 'd> PureDistanceFunction<&'q [T], &'d [T], f32>, + { + fn chamfer(&self, doc: MatRef<'_, Standard>) -> f32 { + Chamfer::evaluate(self.0, doc) + } + fn max_sim(&self, doc: MatRef<'_, Standard>, scores: &mut [f32]) { + // `MaxSim::new` is a non-empty check + pointer wrap, so constructing it per + // iteration is free — no need to hoist it out of the loop. + let mut max_sim = MaxSim::new(scores).unwrap(); + let _ = max_sim.evaluate(self.0, doc); + } + } + + ///////////////////// + // Implementations // + ///////////////////// + + /// Shared loop nest. The trait-object dispatch happens once per outer iteration of + /// `run_loops`; the work inside each `chamfer` / `max_sim` call is O(Q*D*dim), so the + /// vtable hop is in the noise. + fn run_with_distance( + run: &Run, + doc: MatRef<'_, Standard>, + dist: &dyn Distance, + ) -> RunResult { + match run.operation { + Operation::Chamfer => run_loops(run, || { + let v = dist.chamfer(doc); + std::hint::black_box(v); + }), + Operation::MaxSim => { + let mut scores = vec![0.0f32; run.num_query_vectors.get()]; + run_loops(run, || { + dist.max_sim(doc, &mut scores); + std::hint::black_box(&mut scores); + }) + } + } + } + + fn run_optimized(input: &MultiVectorOp) -> anyhow::Result> + where + T: Copy, + StandardUniform: Distribution, + QueryComputer: NewFromMatRef, + OptimizedDistance: Distance, + { + let mut results = Vec::with_capacity(input.runs.len()); + for run in input.runs.iter() { + let data = Data::::new(run); + // `QueryComputer` performs query-side precomputation that is intentionally + // amortized across many `chamfer` / `max_sim` calls; construct it once per + // shape, outside the timed loop. + let dist = OptimizedDistance( as NewFromMatRef>::new_from( + data.queries.as_view(), + )); + results.push(run_with_distance(run, data.docs.as_view(), &dist)); + } + Ok(results) + } + + /// Drive the [`Chamfer`] / [`MaxSim`] fallback path. + fn run_reference(input: &MultiVectorOp) -> anyhow::Result> + where + T: Copy, + StandardUniform: Distribution, + InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>, + for<'a> ReferenceDistance<'a, T>: Distance, + { + let mut results = Vec::with_capacity(input.runs.len()); + for run in input.runs.iter() { + let data = Data::::new(run); + let dist = ReferenceDistance(data.queries.as_view().into()); + results.push(run_with_distance(run, data.docs.as_view(), &dist)); + } + Ok(results) + } + + /// Element-type-erasing constructor for [`QueryComputer`]. + /// + /// `QueryComputer::::new` is defined as an inherent method on the concrete + /// `QueryComputer` / `QueryComputer` types (not a generic), so we need + /// this shim trait to let generic code (e.g. `run_optimized`) call it. + trait NewFromMatRef { + fn new_from(query: MatRef<'_, Standard>) -> QueryComputer; + } + + impl NewFromMatRef for QueryComputer { + fn new_from(query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer::::new(query) + } + } + + impl NewFromMatRef for QueryComputer { + fn new_from(query: MatRef<'_, Standard>) -> QueryComputer { + QueryComputer::::new(query) + } + } + + impl RunBenchmark for Kernel + where + T: Copy + 'static, + StandardUniform: Distribution, + QueryComputer: NewFromMatRef, + OptimizedDistance: Distance, + { + fn run_benchmark(&self, input: &MultiVectorOp) -> anyhow::Result> { + run_optimized::(input) + } + } + + impl RunBenchmark for Kernel + where + T: Copy + 'static, + StandardUniform: Distribution, + InnerProduct: for<'a, 'b> PureDistanceFunction<&'a [T], &'b [T], f32>, + for<'a> ReferenceDistance<'a, T>: Distance, + { + fn run_benchmark(&self, input: &MultiVectorOp) -> anyhow::Result> { + run_reference::(input) + } + } + + /////////// + // Tests // + /////////// + + #[cfg(test)] + mod tests { + use std::num::NonZeroUsize; + + use diskann_benchmark_runner::{ + benchmark::{PassFail, Regression}, + utils::{datatype::DataType, num::NonNegativeFinite, percentiles::compute_percentiles}, + }; + + use super::*; + + fn tiny_run(operation: Operation) -> Run { + Run { + operation, + num_query_vectors: NonZeroUsize::new(2).unwrap(), + num_doc_vectors: NonZeroUsize::new(2).unwrap(), + dim: NonZeroUsize::new(4).unwrap(), + loops_per_measurement: NonZeroUsize::new(1).unwrap(), + num_measurements: NonZeroUsize::new(1).unwrap(), + } + } + + fn tiny_op() -> MultiVectorOp { + MultiVectorOp { + element_type: DataType::Float32, + implementation: Implementation::Optimized, + runs: vec![tiny_run(Operation::Chamfer)], + } + } + + fn tiny_result(operation: Operation, minimum: u64) -> RunResult { + let run = tiny_run(operation); + let minimum = MicroSeconds::new(minimum); + let mut latencies = vec![minimum]; + let percentiles = compute_percentiles(&mut latencies).unwrap(); + RunResult { + run, + latencies, + percentiles, + } + } + + fn tolerance(limit: f64) -> MultiVectorTolerance { + MultiVectorTolerance { + min_time_regression: NonNegativeFinite::new(limit).unwrap(), + } + } + + #[test] + fn check_rejects_mismatched_runs() { + let kernel = Kernel::::new(); + + let err = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(Operation::Chamfer, 100)], + &vec![tiny_result(Operation::MaxSim, 100)], + ) + .unwrap_err(); + + assert_eq!(err.to_string(), "run 0 mismatched"); + } + + #[test] + fn check_allows_negative_relative_change() { + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.0), + &tiny_op(), + &vec![tiny_result(Operation::Chamfer, 100)], + &vec![tiny_result(Operation::Chamfer, 95)], + ) + .unwrap(); + + assert!(matches!(result, PassFail::Pass(_))); + } + + #[test] + fn check_passes_on_tolerance_boundary() { + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(Operation::Chamfer, 100)], + &vec![tiny_result(Operation::Chamfer, 105)], + ) + .unwrap(); + + assert!(matches!(result, PassFail::Pass(_))); + } + + #[test] + fn check_fails_above_tolerance_boundary() { + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(Operation::Chamfer, 100)], + &vec![tiny_result(Operation::Chamfer, 106)], + ) + .unwrap(); + + assert!(matches!(result, PassFail::Fail(_))); + } + + #[test] + fn check_result_display_includes_failure_details() { + let check = CheckResult { + checks: vec![Comparison { + run: tiny_run(Operation::Chamfer), + tolerance: tolerance(0.05), + before_min: 100.0, + after_min: 106.0, + }], + }; + + let rendered = check.to_string(); + assert!(rendered.contains("Operation"), "rendered = {rendered}"); + assert!(rendered.contains("chamfer"), "rendered = {rendered}"); + assert!(rendered.contains("100.000"), "rendered = {rendered}"); + assert!(rendered.contains("106.000"), "rendered = {rendered}"); + assert!(rendered.contains("6.000 %"), "rendered = {rendered}"); + assert!(rendered.contains("FAIL"), "rendered = {rendered}"); + } + + /// A "before" value of 0 means the measurement was too fast to obtain a + /// reliable signal, so we *could* be letting a regression through. We + /// require at least a non-zero value. + #[test] + fn zero_values_rejected() { + let kernel = Kernel::::new(); + + let result = kernel + .check( + &tolerance(0.05), + &tiny_op(), + &vec![tiny_result(Operation::Chamfer, 0)], + &vec![tiny_result(Operation::Chamfer, 0)], + ) + .unwrap(); + + assert!(matches!(result, PassFail::Fail(_))); + } + } +} diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 856412e2a..414a0b52e 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod disk; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod graph_index; +pub(crate) mod multi_vector; pub(crate) mod save_and_load; pub(crate) fn register_inputs( @@ -16,6 +17,7 @@ pub(crate) fn register_inputs( exhaustive::register_inputs(registry)?; disk::register_inputs(registry)?; filters::register_inputs(registry)?; + multi_vector::register_inputs(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/inputs/multi_vector.rs b/diskann-benchmark/src/inputs/multi_vector.rs new file mode 100644 index 000000000..8010162d6 --- /dev/null +++ b/diskann-benchmark/src/inputs/multi_vector.rs @@ -0,0 +1,190 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::num::NonZeroUsize; + +use diskann_benchmark_runner::{ + utils::{datatype::DataType, num::NonNegativeFinite}, + CheckDeserialization, Checker, +}; +use serde::{Deserialize, Serialize}; + +use crate::inputs::{as_input, Example}; + +////////////// +// Registry // +////////////// + +as_input!(MultiVectorOp); +as_input!(MultiVectorTolerance); + +pub(super) fn register_inputs( + registry: &mut diskann_benchmark_runner::registry::Inputs, +) -> anyhow::Result<()> { + registry.register::()?; + registry.register::()?; + Ok(()) +} + +//////////////// +// Enum types // +//////////////// + +/// The two distance operations exposed by `QueryComputer`. +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub(crate) enum Operation { + Chamfer, + MaxSim, +} + +impl std::fmt::Display for Operation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let st = match self { + Self::Chamfer => "chamfer", + Self::MaxSim => "max_sim", + }; + write!(f, "{}", st) + } +} + +/// Which implementation tier to benchmark. +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(crate) enum Implementation { + Optimized, + Reference, +} + +impl std::fmt::Display for Implementation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let st = match self { + Self::Optimized => "optimized", + Self::Reference => "reference", + }; + write!(f, "{}", st) + } +} + +/// One benchmark configuration: a single (operation, shape) measurement. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub(crate) struct Run { + pub(crate) operation: Operation, + pub(crate) num_query_vectors: NonZeroUsize, + pub(crate) num_doc_vectors: NonZeroUsize, + pub(crate) dim: NonZeroUsize, + pub(crate) loops_per_measurement: NonZeroUsize, + pub(crate) num_measurements: NonZeroUsize, +} + +/////////////////////// +// Multi-Vector Op // +/////////////////////// + +/// A complete multi-vector benchmark job. +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct MultiVectorOp { + pub(crate) element_type: DataType, + pub(crate) implementation: Implementation, + pub(crate) runs: Vec, +} + +impl MultiVectorOp { + pub(crate) const fn tag() -> &'static str { + "multi-vector-op" + } +} + +impl CheckDeserialization for MultiVectorOp { + fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { + Ok(()) + } +} + +impl Example for MultiVectorOp { + fn example() -> Self { + const NUM_QUERY_VECTORS: NonZeroUsize = NonZeroUsize::new(32).unwrap(); + const NUM_DOC_VECTORS: NonZeroUsize = NonZeroUsize::new(64).unwrap(); + const DIM: NonZeroUsize = NonZeroUsize::new(128).unwrap(); + const LOOPS_PER_MEASUREMENT: NonZeroUsize = NonZeroUsize::new(200).unwrap(); + const NUM_MEASUREMENTS: NonZeroUsize = NonZeroUsize::new(100).unwrap(); + + let runs = vec![ + Run { + operation: Operation::Chamfer, + num_query_vectors: NUM_QUERY_VECTORS, + num_doc_vectors: NUM_DOC_VECTORS, + dim: DIM, + loops_per_measurement: LOOPS_PER_MEASUREMENT, + num_measurements: NUM_MEASUREMENTS, + }, + Run { + operation: Operation::MaxSim, + num_query_vectors: NUM_QUERY_VECTORS, + num_doc_vectors: NUM_DOC_VECTORS, + dim: DIM, + loops_per_measurement: LOOPS_PER_MEASUREMENT, + num_measurements: NUM_MEASUREMENTS, + }, + ]; + + Self { + element_type: DataType::Float32, + implementation: Implementation::Optimized, + runs, + } + } +} + +macro_rules! write_field { + ($f:ident, $field:tt, $($expr:tt)*) => { + writeln!($f, "{:>18}: {}", $field, $($expr)*) + } +} + +impl std::fmt::Display for MultiVectorOp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Multi-Vector Operation\n")?; + write_field!(f, "tag", Self::tag())?; + write_field!(f, "element type", self.element_type)?; + write_field!(f, "implementation", self.implementation)?; + write_field!(f, "number of runs", self.runs.len())?; + Ok(()) + } +} + +///////////////////////////// +// Multi-Vector Tolerance // +///////////////////////////// + +/// Tolerance thresholds for multi-vector benchmark regression detection. +/// +/// Each field specifies the maximum allowed relative increase in the corresponding metric. +/// For example, a value of `0.05` means a 5% increase is tolerated. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub(crate) struct MultiVectorTolerance { + pub(crate) min_time_regression: NonNegativeFinite, +} + +impl MultiVectorTolerance { + pub(crate) const fn tag() -> &'static str { + "multi-vector-tolerance" + } +} + +impl CheckDeserialization for MultiVectorTolerance { + fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { + Ok(()) + } +} + +impl Example for MultiVectorTolerance { + fn example() -> Self { + Self { + min_time_regression: NonNegativeFinite::new(0.05) + .expect("0.05 is a valid non-negative finite"), + } + } +} diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index 424e63bb7..c7276f2e1 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -776,6 +776,92 @@ mod tests { assert!(!output_path.exists()); } + /////////////////// + // Multi-Vector // + /////////////////// + + #[test] + fn multi_vector_integration() { + let path = example_directory().join("multi-vector-test.json"); + let tempdir = tempfile::tempdir().unwrap(); + let output_path = tempdir.path().join("output.json"); + assert!(!output_path.exists()); + + let modified_input_path = tempdir.path().join("input.json"); + + let mut raw = value_from_file(&path); + prefix_search_directories(&mut raw, &root_directory()); + save_to_file(&modified_input_path, &raw); + + run_multi_vector_integration(&modified_input_path, &output_path) + } + + #[cfg(feature = "multi-vector")] + fn run_multi_vector_integration(input_path: &std::path::Path, output_path: &std::path::Path) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + allow_debug: true, + }; + + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + cli.run(&mut output).unwrap(); + println!( + "output = {}", + String::from_utf8(output.into_inner()).unwrap() + ); + + // Check that the results file is generated. + assert!(output_path.exists()); + } + + #[cfg(not(feature = "multi-vector"))] + fn run_multi_vector_integration(input_path: &std::path::Path, output_path: &std::path::Path) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + allow_debug: true, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + let err = cli.run(&mut output).unwrap_err(); + println!("err = {:?}", err); + + let output = String::from_utf8(output.into_inner()).unwrap(); + assert!(output.contains("\"multi-vector\" feature")); + println!("output = {}", output); + + // The output file should not have been created because we failed the test. + assert!(!output_path.exists()); + } + + #[test] + #[cfg(feature = "multi-vector")] + fn multi_vector_check_verify() { + let input_path = example_directory().join("multi-vector-test.json"); + let tolerance_path = project_directory() + .join("perf_test_inputs") + .join("multi-vector-tolerance.json"); + + let command = Commands::Check(diskann_benchmark_runner::app::Check::Verify { + tolerances: tolerance_path, + input_file: input_path, + }); + + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + cli.run(&mut output).unwrap(); + println!( + "output = {}", + String::from_utf8(output.into_inner()).unwrap() + ); + } + #[test] fn quiet_suppresses_check_target_warning() { let cli = Cli::from_commands(Commands::Skeleton, true); diff --git a/diskann-quantization/src/multi_vector/matrix.rs b/diskann-quantization/src/multi_vector/matrix.rs index 70629d44c..bcbafaaa3 100644 --- a/diskann-quantization/src/multi_vector/matrix.rs +++ b/diskann-quantization/src/multi_vector/matrix.rs @@ -244,6 +244,18 @@ pub unsafe trait NewOwned: ReprOwned { #[derive(Debug, Clone, Copy)] pub struct Defaulted; +/// An initializer argument to [`NewOwned`] that invokes the wrapped closure for each +/// element. +/// +/// # Example +/// ``` +/// use diskann_quantization::multi_vector::{Init, Mat, Standard}; +/// let mut n = 0; +/// let mat = Mat::new(Standard::::new(1, 4).unwrap(), Init(|| { n += 1; n })).unwrap(); +/// assert_eq!(mat.as_slice(), &[1, 2, 3, 4]); +/// ``` +pub struct Init(pub F); + /// Create a new [`Mat`] cloned from a view. pub trait NewCloned: ReprOwned { /// Clone the contents behind `v`, returning a new owning [`Mat`]. @@ -514,6 +526,22 @@ where } } +// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer +// initialized by it is non-null and properly aligned to the underlying type. +unsafe impl NewOwned> for Standard +where + T: Copy, + F: FnMut() -> T, +{ + type Error = crate::error::Infallible; + fn new_owned(self, mut init: Init) -> Result, Self::Error> { + let b: Box<[T]> = (0..self.num_elements()).map(|_| (init.0)()).collect(); + + // SAFETY: By construction, `b` has length `self.num_elements()`. + Ok(unsafe { self.box_to_mat(b) }) + } +} + // SAFETY: This checks that the slice has the correct length, which is all that is // required for [`Repr`]. unsafe impl NewRef for Standard @@ -1767,6 +1795,22 @@ mod tests { } } + #[test] + fn test_standard_new_owned_with_init() { + let mut counter: i32 = 0; + let m = Mat::new( + Standard::::new(2, 3).unwrap(), + Init(|| { + let v = counter; + counter += 1; + v + }), + ) + .unwrap(); + + assert_eq!(m.as_slice(), &[0, 1, 2, 3, 4, 5]); + } + #[test] fn matref_new_slice_length_error() { let repr = Standard::::new(3, 4).unwrap(); diff --git a/diskann-quantization/src/multi_vector/mod.rs b/diskann-quantization/src/multi_vector/mod.rs index 3670b1aaf..1d765bacc 100644 --- a/diskann-quantization/src/multi_vector/mod.rs +++ b/diskann-quantization/src/multi_vector/mod.rs @@ -74,6 +74,6 @@ pub(crate) mod matrix; pub use block_transposed::{BlockTransposed, BlockTransposedMut, BlockTransposedRef}; pub use distance::{Chamfer, MaxSim, MaxSimError, QueryComputer, QueryMatRef}; pub use matrix::{ - Defaulted, LayoutError, Mat, MatMut, MatRef, NewCloned, NewMut, NewOwned, NewRef, Overflow, - Repr, ReprMut, ReprOwned, SliceError, Standard, + Defaulted, Init, LayoutError, Mat, MatMut, MatRef, NewCloned, NewMut, NewOwned, NewRef, + Overflow, Repr, ReprMut, ReprOwned, SliceError, Standard, };