diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 2149842a8..089c9b12c 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -22,7 +22,7 @@ use crate::{ }; /// A built-in helper for benchmarking the K-nearest neighbors method -/// [`graph::DiskANNIndex::search`]. +/// [`graph::DiskANNIndex::search`] with optional post-processing support. /// /// This is intended to be used in conjunction with [`search::search`] or /// [`search::search_all`] and provides some basic additional metrics for @@ -31,21 +31,29 @@ use crate::{ /// /// The provided implementation of [`Search`] accepts [`graph::search::Knn`] /// and returns [`Metrics`] as additional output. +/// +/// # Type Parameters +/// +/// - `DP`: The data provider type +/// - `T`: The query element type +/// - `S`: The search strategy type +/// - `PP`: Optional post-processor type (defaults to `()` for no post-processing) #[derive(Debug)] -pub struct KNN +pub struct KNN where DP: provider::DataProvider, { index: Arc>, queries: Arc>, strategy: Strategy, + post_processor: Option, } -impl KNN +impl KNN where DP: provider::DataProvider, { - /// Construct a new [`KNN`] searcher. + /// Construct a new [`KNN`] searcher without post-processing. /// /// If `strategy` is one of the container variants of [`Strategy`], its length /// must match the number of rows in `queries`. If this is the case, then the @@ -67,10 +75,58 @@ where index, queries, strategy, + post_processor: None, })) } } +impl KNN +where + DP: provider::DataProvider, +{ + /// Construct a new [`KNN`] searcher with post-processing. + /// + /// # Errors + /// + /// Returns an error if the number of elements in `strategy` is not compatible with + /// the number of rows in `queries`. + pub fn with_postprocessor( + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: PP, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor: Some(post_processor), + })) + } + + /// Access the index. + pub fn index(&self) -> &Arc> { + &self.index + } + + /// Access the queries. + pub fn queries(&self) -> &Arc> { + &self.queries + } + + /// Access the strategy. + pub fn strategy(&self) -> &Strategy { + &self.strategy + } + + /// Access the post-processor, if present. + pub fn post_processor(&self) -> &Option { + &self.post_processor + } +} + /// Additional metrics collected during [`KNN`] search. /// /// # Note @@ -85,7 +141,14 @@ pub struct Metrics { pub hops: u32, } -impl Search for KNN +impl Metrics { + /// Construct a new metrics value. + pub fn new(comparisons: u32, hops: u32) -> Self { + Self { comparisons, hops } + } +} + +impl Search for KNN where DP: provider::DataProvider, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, diff --git a/diskann-benchmark/example/async-determinant-diversity.json b/diskann-benchmark/example/async-determinant-diversity.json new file mode 100644 index 000000000..acb4260ea --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity.json @@ -0,0 +1,51 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "graph-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 0.01 + }, + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/example/disk-index-determinant-diversity.json b/diskann-benchmark/example/disk-index-determinant-diversity.json new file mode 100644 index 000000000..2962c1d97 --- /dev/null +++ b/diskann-benchmark/example/disk-index-determinant-diversity.json @@ -0,0 +1,42 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_full_det_div" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 1.0 + } + } + } + } + ] +} diff --git a/diskann-benchmark/src/backend/disk_index/search.rs b/diskann-benchmark/src/backend/disk_index/search.rs index 487432598..1e64583c5 100644 --- a/diskann-benchmark/src/backend/disk_index/search.rs +++ b/diskann-benchmark/src/backend/disk_index/search.rs @@ -14,7 +14,8 @@ use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, + disk_vertex_provider_factory::DiskVertexProviderFactory, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -32,7 +33,10 @@ use serde::{Deserialize, Serialize}; use crate::{ backend::disk_index::json_spancollector::JsonSpanCollector, - inputs::disk::{DiskIndexLoad, DiskSearchPhase}, + inputs::{ + disk::{DiskIndexLoad, DiskSearchPhase}, + post_processor::TopkPostProcessor, + }, utils::{datafiles, SimilarityMeasure}, }; @@ -264,6 +268,14 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { + let post_processor = search_params.post_processor.as_ref().map( + |TopkPostProcessor::DeterminantDiversity { power, eta }| { + SearchPostProcessorKind::DeterminantDiversity { + power: *power, + eta: *eta, + } + }, + ); let vector_filter = if search_params.vector_filters_file.is_none() { None } else { @@ -277,20 +289,21 @@ where l, Some(search_params.beam_width), vector_filter, + post_processor, search_params.is_flat_search, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; - *rc = search_result.results.len() as u32; - let actual_results = search_result - .results - .len() - .min(search_params.recall_at as usize); - for (i, result_item) in search_result - .results - .iter() - .take(actual_results) - .enumerate() + let base_count = (search_result.stats.result_count as usize) + .min(search_params.recall_at as usize) + .min(search_result.results.len()); + + *rc = base_count as u32; + id_chunk.fill(0); + dist_chunk.fill(0.0); + + for (i, result_item) in + search_result.results.iter().take(base_count).enumerate() { id_chunk[i] = result_item.vertex_id; dist_chunk[i] = result_item.distance; diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 57aafc8eb..63170e446 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -41,6 +41,7 @@ use super::{ }; use crate::{ backend::index::{ + post_processor, result::{AggregatedSearchResults, BuildResult}, search::plugins, streaming::{self, managed, stats::StreamStats, FullPrecisionStream, Managed}, @@ -73,6 +74,7 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg benchmarks.register( "graph-index-full-precision-f32", FullPrecision::::new() + .search(plugins::DeterminantDiversity) .search(plugins::Topk) .search(plugins::Range) .search(plugins::TopkBetaFilter) @@ -447,17 +449,133 @@ impl Strategy { // Topk // //------// +struct DeterminantDiversityKnn { + index: Arc>>, + queries: Arc>, + strategy: benchmark_core::search::graph::Strategy, + post_processor: post_processor::DeterminantDiversity, +} + +impl DeterminantDiversityKnn { + fn new( + index: Arc>>, + queries: Arc>, + strategy: benchmark_core::search::graph::Strategy, + post_processor: post_processor::DeterminantDiversity, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor, + })) + } +} + +impl benchmark_core::search::Search for DeterminantDiversityKnn +where + common::FullPrecision: for<'a, 'b> glue::SearchStrategy< + FullPrecisionProvider, + &'a [f32], + SearchAccessor<'b>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, + >, +{ + type Id = u32; + type Parameters = diskann::graph::search::Knn; + type Output = benchmark_core::search::graph::knn::Metrics; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Self::Parameters) -> benchmark_core::search::IdCount { + benchmark_core::search::IdCount::Fixed(parameters.k_value()) + } + + async fn search( + &self, + parameters: &Self::Parameters, + buffer: &mut O, + index: usize, + ) -> diskann::ANNResult + where + O: diskann::graph::SearchOutputBuffer + Send, + { + let context = DefaultContext; + let stats = self + .index + .search_with( + *parameters, + self.strategy.get(index)?, + self.post_processor, + &context, + self.queries.row(index), + buffer, + ) + .await?; + + Ok(benchmark_core::search::graph::knn::Metrics::new( + stats.cmps, stats.hops, + )) + } +} + +impl search::Plugin, SearchPhase, Strategy> + for plugins::DeterminantDiversity +where + common::FullPrecision: for<'a, 'b> glue::SearchStrategy< + FullPrecisionProvider, + &'a [f32], + SearchAccessor<'b>: post_processor::determinant_diversity::FullPrecisionVectorAccessor, + >, +{ + fn is_match(&self, phase: &SearchPhase) -> bool { + plugins::DeterminantDiversity::is_match(phase) + } + + fn kind(&self) -> &'static str { + plugins::DeterminantDiversity::as_str() + } + + fn run( + &self, + index: Arc>>, + phase: &SearchPhase, + _strategy: &Strategy, + ) -> anyhow::Result { + let (topk, params) = plugins::DeterminantDiversity::get(phase)?; + + let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( + &topk.queries, + ))?); + let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + + let knn = DeterminantDiversityKnn::new( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(common::FullPrecision), + post_processor::DeterminantDiversity::new(params.power, params.eta), + )?; + + let steps = search::knn::SearchSteps::new(topk.reps, &topk.num_threads, &topk.runs); + let results = search::knn::run(&knn, &groundtruth, steps)?; + + Ok(AggregatedSearchResults::Topk(results)) + } +} + impl search::Plugin> for plugins::Topk where DP: DataProvider + QueryType, S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::Topk::as_str() } fn run( @@ -496,11 +614,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::Range::as_str() } fn run( @@ -540,11 +658,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::TopkBetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::TopkBetaFilter::as_str() } fn run( @@ -599,11 +717,11 @@ where S: for<'a> glue::DefaultSearchStrategy + Clone + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::TopkMultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::TopkMultihopFilter::as_str() } fn run( diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..f762d41ad 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -4,6 +4,7 @@ */ mod build; +pub(crate) mod post_processor; mod search; mod streaming; diff --git a/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs new file mode 100644 index 000000000..aa915f0fd --- /dev/null +++ b/diskann-benchmark/src/backend/index/post_processor/determinant_diversity.rs @@ -0,0 +1,88 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::future::Future; + +use diskann::graph::search_output_buffer::SearchOutputBuffer; +use diskann::{error::ANNError, graph::glue, neighbor::Neighbor, provider::Accessor}; +use diskann_providers::model::graph::provider::async_::{ + determinant_diversity_post_process, inmem, +}; +use diskann_utils::future::AsyncFriendly; + +pub(crate) trait FullPrecisionVectorAccessor: Accessor + Send { + fn get_full_precision_vector( + &mut self, + id: Self::Id, + ) -> impl Future, ANNError>> + Send; +} + +impl FullPrecisionVectorAccessor for inmem::FullAccessor<'_, f32, Q, D, Ctx> +where + Q: AsyncFriendly, + D: AsyncFriendly, + Ctx: diskann::provider::ExecutionContext, +{ + async fn get_full_precision_vector(&mut self, id: Self::Id) -> Result, ANNError> { + self.get_element(id) + .await + .map(|vector| vector.to_vec()) + .map_err(Into::into) + } +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversity { + power: f32, + eta: f32, +} + +impl DeterminantDiversity { + pub(crate) const fn new(power: f32, eta: f32) -> Self { + Self { power, eta } + } +} + +impl glue::SearchPostProcess for DeterminantDiversity +where + A: FullPrecisionVectorAccessor + diskann::provider::BuildQueryComputer + Send, + T: AsRef<[f32]> + Send + Sync, +{ + type Error = ANNError; + + async fn post_process( + &self, + accessor: &mut A, + query: T, + _computer: &>::QueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let candidates: Vec> = candidates.collect(); + let mut embedded = Vec::with_capacity(candidates.len()); + + for candidate in &candidates { + embedded.push(( + candidate.id, + candidate.distance, + accessor.get_full_precision_vector(candidate.id).await?, + )); + } + + let reranked = determinant_diversity_post_process( + embedded, + query.as_ref(), + candidates.len(), + self.eta, + self.power, + ); + + Ok(output.extend(reranked)) + } +} diff --git a/diskann-benchmark/src/backend/index/post_processor/mod.rs b/diskann-benchmark/src/backend/index/post_processor/mod.rs new file mode 100644 index 000000000..4afaab925 --- /dev/null +++ b/diskann-benchmark/src/backend/index/post_processor/mod.rs @@ -0,0 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +pub(crate) mod determinant_diversity; + +pub(crate) use determinant_diversity::DeterminantDiversity; diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index b50e69010..a4485933c 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -78,45 +78,19 @@ pub(crate) trait Knn { // Impls // /////////// -impl Knn for Arc> +impl Knn for Arc where - DP: diskann::provider::DataProvider, - core_search::graph::KNN: core_search::Search< - Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, - Output = core_search::graph::knn::Metrics, - >, + I: benchmark_core::recall::RecallCompatible, + R: core_search::Search< + Id = I, + Parameters = diskann::graph::search::Knn, + Output = core_search::graph::knn::Metrics, + > + 'static, { fn search_all( &self, parameters: Vec>, - groundtruth: &dyn benchmark_core::recall::Rows, - recall_k: usize, - recall_n: usize, - ) -> anyhow::Result> { - let results = core_search::search_all( - self.clone(), - parameters.into_iter(), - core_search::graph::knn::Aggregator::new(groundtruth, recall_k, recall_n), - )?; - - Ok(results.into_iter().map(SearchResults::new).collect()) - } -} - -impl Knn for Arc> -where - DP: diskann::provider::DataProvider, - core_search::graph::MultiHop: core_search::Search< - Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, - Output = core_search::graph::knn::Metrics, - >, -{ - fn search_all( - &self, - parameters: Vec>, - groundtruth: &dyn benchmark_core::recall::Rows, + groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, ) -> anyhow::Result> { diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 43b8ba3e8..bb7a9a7ae 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -36,9 +36,14 @@ use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; +use diskann_providers::post_processor::DeterminantDiversityParams; use crate::{ - backend::index::result::AggregatedSearchResults, inputs::graph_index::SearchPhaseKind, + backend::index::result::AggregatedSearchResults, + inputs::{ + graph_index::{SearchPhase, TopkSearchPhase}, + post_processor::TopkPostProcessor, + }, }; /// A dyn-compatible search plugin for `DP`. @@ -145,9 +150,49 @@ where pub(crate) struct Topk; impl Topk { - /// Returns [`SearchPhaseKind::Topk`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Topk + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .is_some_and(|topk| topk.post_processor.is_none()) + } + + pub(crate) const fn as_str() -> &'static str { + "topk" + } +} + +/// A search plugin for determinant-diversity top-k post-processing. +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversity; + +impl DeterminantDiversity { + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .and_then(|topk| topk.post_processor.as_ref()) + .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) + } + + pub(crate) const fn as_str() -> &'static str { + "topk + determinant-diversity" + } + + pub(crate) fn get( + phase: &SearchPhase, + ) -> anyhow::Result<(&TopkSearchPhase, DeterminantDiversityParams)> { + let topk = phase.as_topk()?; + match topk.post_processor.as_ref() { + Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { + let params = DeterminantDiversityParams::new(*power, *eta) + .map_err(|e| anyhow::anyhow!("{}", e))?; + Ok((topk, params)) + } + _ => Err(anyhow::anyhow!( + "determinant-diversity plugin selected for non determinant-diversity input", + )), + } } } @@ -156,9 +201,12 @@ impl Topk { pub(crate) struct Range; impl Range { - /// Returns [`SearchPhaseKind::Range`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Range + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_range().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "range" } } @@ -167,9 +215,12 @@ impl Range { pub(crate) struct TopkBetaFilter; impl TopkBetaFilter { - /// Returns [`SearchPhaseKind::TopkBetaFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkBetaFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_beta_filter().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk + beta filter" } } @@ -178,8 +229,11 @@ impl TopkBetaFilter { pub(crate) struct TopkMultihopFilter; impl TopkMultihopFilter { - /// Returns [`SearchPhaseKind::TopkMultihopFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkMultihopFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_multihop_filter().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk + multihop filter" } } diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 20e9c0e29..7590f3321 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -86,7 +86,7 @@ mod imp { }, inputs::{ exhaustive, - graph_index::{SearchPhase, SphericalQuantBuild}, + graph_index::{SearchPhase, SearchPhaseKind, SphericalQuantBuild}, }, utils::{ self, datafiles, @@ -363,11 +363,11 @@ mod imp { for search::plugins::Topk { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Topk.as_str() } fn run( @@ -402,11 +402,11 @@ mod imp { for search::plugins::Range { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Range.as_str() } fn run( @@ -444,11 +444,11 @@ mod imp { for search::plugins::TopkBetaFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkBetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkBetaFilter.as_str() } fn run( @@ -498,11 +498,11 @@ mod imp { for search::plugins::TopkMultihopFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkMultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkMultihopFilter.as_str() } fn run( diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 2951d1fe4..339fb12c0 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -15,7 +15,7 @@ use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, ge use serde::{Deserialize, Serialize}; use crate::{ - inputs::{as_input, Example}, + inputs::{as_input, post_processor::TopkPostProcessor, Example}, utils::SimilarityMeasure, }; @@ -85,6 +85,7 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, + pub(crate) post_processor: Option, } ///////// @@ -234,6 +235,12 @@ impl CheckDeserialization for DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } + + if let Some(pp) = self.post_processor.as_mut() { + pp.check_deserialization(checker) + .context("invalid disk search post processor")?; + } + Ok(()) } } @@ -272,6 +279,7 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, + post_processor: None, }; Self { @@ -397,6 +405,14 @@ impl DiskSearchPhase { Some(lim) => write_field!(f, "Search IO Limit", format!("{lim}"))?, None => write_field!(f, "Search IO Limit", "none (defaults to `usize::MAX`)")?, } + match &self.post_processor { + Some(TopkPostProcessor::DeterminantDiversity { power, eta }) => { + write_field!(f, "Post Processor", "determinant-diversity")?; + write_field!(f, "DetDiv Power", power)?; + write_field!(f, "DetDiv Eta", eta)?; + } + None => write_field!(f, "Post Processor", "none")?, + } Ok(()) } } diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 849b1a381..7d7fc7040 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::{ - inputs::{self, as_input, save_and_load, Example}, + inputs::{self, as_input, post_processor::TopkPostProcessor, save_and_load, Example}, utils::SimilarityMeasure, }; @@ -126,6 +126,7 @@ pub(crate) struct TopkSearchPhase { // Enable sweeping threads pub(crate) num_threads: Vec, pub(crate) runs: Vec, + pub(crate) post_processor: Option, } impl CheckDeserialization for TopkSearchPhase { @@ -139,6 +140,12 @@ impl CheckDeserialization for TopkSearchPhase { .with_context(|| format!("search run {}", i))?; } + if let Some(post_processor) = self.post_processor.as_mut() { + post_processor + .check_deserialization(checker) + .context("invalid topk post processor")?; + } + Ok(()) } } @@ -166,6 +173,7 @@ impl Example for TopkSearchPhase { reps: REPS, num_threads: THREAD_COUNTS.to_vec(), runs, + post_processor: None, } } } @@ -416,6 +424,14 @@ impl CheckDeserialization for SearchPhase { } } +fn has_topk_determinant_diversity(phase: &SearchPhase) -> bool { + phase + .as_topk() + .ok() + .and_then(|topk| topk.post_processor.as_ref()) + .is_some_and(|pp| matches!(pp, TopkPostProcessor::DeterminantDiversity { .. })) +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum SearchPhaseKind { Topk, @@ -769,6 +785,14 @@ impl CheckDeserialization for IndexOperation { self.source.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + if has_topk_determinant_diversity(&self.search_phase) + && *self.source.data_type() != DataType::Float32 + { + anyhow::bail!( + "determinant-diversity post-processor requires graph-index full precision float32 input" + ); + } + Ok(()) } } @@ -847,7 +871,15 @@ impl IndexPQOperation { impl CheckDeserialization for IndexPQOperation { fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.index_operation.check_deserialization(checker) + self.index_operation.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.index_operation.search_phase) { + anyhow::bail!( + "determinant-diversity post-processor is only supported on graph-index full precision float32 topk" + ); + } + + Ok(()) } } @@ -933,7 +965,15 @@ impl CheckDeserialization for IndexSQOperation { )); } - self.index_operation.check_deserialization(checker) + self.index_operation.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.index_operation.search_phase) { + anyhow::bail!( + "determinant-diversity post-processor is only supported on graph-index full precision float32 topk" + ); + } + + Ok(()) } } @@ -1012,6 +1052,12 @@ impl CheckDeserialization for SphericalQuantBuild { self.build.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + if has_topk_determinant_diversity(&self.search_phase) { + anyhow::bail!( + "determinant-diversity post-processor is only supported on graph-index full precision float32 topk" + ); + } + if self.build.save_path.is_some() { return Err(anyhow::anyhow!( "Spherical quantization does not support saving the index" @@ -1287,6 +1333,13 @@ impl CheckDeserialization for DynamicIndexRun { self.build.check_deserialization(checker)?; self.runbook_params.check_deserialization(checker)?; self.search_phase.check_deserialization(checker)?; + + if has_topk_determinant_diversity(&self.search_phase) { + anyhow::bail!( + "determinant-diversity post-processor is only supported on graph-index full precision float32 topk" + ); + } + Ok(()) } } diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 856412e2a..e9a3a1775 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -7,6 +7,7 @@ pub(crate) mod disk; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod graph_index; +pub(crate) mod post_processor; pub(crate) mod save_and_load; pub(crate) fn register_inputs( diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs new file mode 100644 index 000000000..5ff739321 --- /dev/null +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use diskann_benchmark_runner::{CheckDeserialization, Checker}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub(crate) enum TopkPostProcessor { + DeterminantDiversity { power: f32, eta: f32 }, +} + +impl CheckDeserialization for TopkPostProcessor { + fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { + match self { + TopkPostProcessor::DeterminantDiversity { power, eta } => { + if *power <= 0.0 { + anyhow::bail!("determinant-diversity power must be > 0.0, got: {}", power); + } + if *eta < 0.0 { + anyhow::bail!("determinant-diversity eta must be >= 0.0, got: {}", eta); + } + Ok(()) + } + } + } +} diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index efb9bf697..cb38cbd1f 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1092,6 +1092,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, + None, &|_| true, false, ); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 1344605f4..039499505 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -36,7 +36,10 @@ use diskann::{ }; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ - model::{compute_pq_distance, compute_pq_distance_for_pq_coordinates}, + model::{ + compute_pq_distance, compute_pq_distance_for_pq_coordinates, + graph::provider::async_::determinant_diversity_post_process, + }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; @@ -273,12 +276,37 @@ pub struct RerankAndFilter<'a> { filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), } +#[derive(Clone, Copy)] +pub struct DeterminantDiversityAndFilter<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + power: f32, + eta: f32, +} + +#[derive(Clone, Copy)] +pub enum SearchPostProcessorKind { + RerankAndFilter, + DeterminantDiversity { power: f32, eta: f32 }, +} + +#[derive(Clone, Copy)] +pub enum DiskSearchPostProcessor<'a> { + RerankAndFilter(RerankAndFilter<'a>), + DeterminantDiversity(DeterminantDiversityAndFilter<'a>), +} + impl<'a> RerankAndFilter<'a> { - fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { + pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { Self { filter } } } +impl<'a> DeterminantDiversityAndFilter<'a> { + pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), power: f32, eta: f32) -> Self { + Self { filter, power, eta } + } +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -340,6 +368,123 @@ where } } +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DeterminantDiversityAndFilter<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + _computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; + + let candidate_ids: Vec = candidates + .map(|candidate| candidate.id) + .filter(|id| (self.filter)(id)) + .collect(); + + if candidate_ids.is_empty() { + return Ok(0); + } + + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &candidate_ids)?; + + let mut candidate_vectors = Vec::with_capacity(candidate_ids.len()); + let mut associated_data = HashMap::with_capacity(candidate_ids.len()); + + for id in candidate_ids { + let vector = accessor.scratch.vertex_provider.get_vector(&id)?; + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); + let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; + let data = accessor.scratch.vertex_provider.get_associated_data(&id)?; + + candidate_vectors.push((id, distance, vector_f32.to_vec())); + associated_data.insert(id, *data); + } + + let reranked = determinant_diversity_post_process( + candidate_vectors, + &query_f32, + usize::MAX, + self.eta, + self.power, + ); + + Ok( + output.extend(reranked.into_iter().filter_map(|(id, distance)| { + associated_data + .get(&id) + .copied() + .map(|data| ((id, data), distance)) + })), + ) + } +} + +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DiskSearchPostProcessor<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + computer: &DiskQueryComputer, + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + match self { + DiskSearchPostProcessor::RerankAndFilter(pp) => { + pp.post_process(accessor, query, computer, candidates, output) + .await + } + DiskSearchPostProcessor::DeterminantDiversity(pp) => { + pp.post_process(accessor, query, computer, candidates, output) + .await + } + } + } +} + impl<'this, Data, ProviderFactory> SearchStrategy, &[Data::VectorDataType]> for DiskSearchStrategy<'this, Data, ProviderFactory> where @@ -917,6 +1062,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -924,6 +1070,7 @@ where search_list_size: u32, beam_width: Option, vector_filter: Option>, + post_processor: Option, is_flat_search: bool, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); @@ -932,6 +1079,20 @@ where let mut associated_data = vec![Data::AssociatedDataType::default(); return_list_size as usize]; + let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); + let post_processor = post_processor.map(|processor| match processor { + SearchPostProcessorKind::RerankAndFilter => DiskSearchPostProcessor::RerankAndFilter( + RerankAndFilter::new(vector_filter.as_ref()), + ), + SearchPostProcessorKind::DeterminantDiversity { power, eta } => { + DiskSearchPostProcessor::DeterminantDiversity(DeterminantDiversityAndFilter::new( + vector_filter.as_ref(), + power, + eta, + )) + } + }); + let stats = self.search_internal( query, return_list_size as usize, @@ -941,7 +1102,8 @@ where &mut indices, &mut distances, &mut associated_data, - &vector_filter.unwrap_or(default_vector_filter::()), + post_processor, + vector_filter.as_ref(), is_flat_search, )?; @@ -968,7 +1130,7 @@ where /// Perform a raw search on the disk index. /// This is a lower-level API that allows more control over the search parameters and output buffers. #[allow(clippy::too_many_arguments)] - pub(crate) fn search_internal( + pub fn search_internal( &self, query: &[Data::VectorDataType], k_value: usize, @@ -978,6 +1140,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], + post_processor: Option>, vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, ) -> ANNResult { @@ -1000,10 +1163,18 @@ where &Knn::new(k, l, beam_width)?, &mut result_output_buffer, ))? + } else if let Some(processor) = post_processor { + self.runtime.block_on(self.index.search_with( + Knn::new(k, l, beam_width)?, + &strategy, + processor, + &DefaultContext, + strategy.query, + &mut result_output_buffer, + ))? } else { - let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( - knn_search, + Knn::new(k, l, beam_width)?, &strategy, &DefaultContext, strategy.query, @@ -1400,6 +1571,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &(|_| true), false, ); @@ -1448,7 +1620,15 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + None, + false, + ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1558,6 +1738,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &|_| true, false, ); @@ -1628,6 +1809,7 @@ mod disk_provider_tests { search_list_size, Some(4), None, + None, false, ); assert!(result.is_ok(), "Expected search to succeed"); @@ -1966,6 +2148,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &vector_filter, is_flat_search, ); @@ -1988,6 +2171,7 @@ mod disk_provider_tests { 10, None, // beam_width Some(Box::new(vector_filter)), + None, is_flat_search, ); diff --git a/diskann-providers/src/lib.rs b/diskann-providers/src/lib.rs index 0edeb2625..8b0aa43ba 100644 --- a/diskann-providers/src/lib.rs +++ b/diskann-providers/src/lib.rs @@ -14,6 +14,8 @@ pub mod model; pub mod common; +pub mod post_processor; + pub mod index; pub mod storage; diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index cf719e730..5ad549563 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -8,9 +8,10 @@ pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; pub(crate) mod postprocess; - +// Re-export from parent module for backward compatibility. +// The algorithm is not async-specific and lives in provider::determinant_diversity. pub mod distances; - +pub use super::determinant_diversity_post_process; pub mod memory_vector_provider; pub use memory_vector_provider::MemoryVectorProviderAsync; diff --git a/diskann-providers/src/model/graph/provider/determinant_diversity.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs new file mode 100644 index 000000000..6563bc00e --- /dev/null +++ b/diskann-providers/src/model/graph/provider/determinant_diversity.rs @@ -0,0 +1,427 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Determinant-Diversity post-processing for search results. +//! +//! This module implements the Determinant-Diversity algorithm for diversity-promoting +//! reranking of approximate nearest neighbor search results. The algorithm takes +//! relevance-ranked candidates and reorders them to maximize geometric diversity +//! while maintaining relevance to the original query. +//! +//! # Algorithm Overview +//! +//! Determinant-Diversity selects a diverse subset from an initial set of candidates +//! by iteratively choosing points that maximize the determinant of the distance matrix. +//! This creates a diverse set that is both relevant to the query and geometrically spread out. +//! +//! # Parameters +//! +//! - **power**: Relevance weighting exponent (must be > 0.0). Controls the emphasis on +//! maintaining relevance scores from the initial search. Higher values prefer relevance +//! over diversity. +//! +//! - **eta**: Numerical stability parameter (must be >= 0.0). Used for ridge regularization: +//! - `eta = 0`: Exact determinant computation (can be numerically unstable for some inputs) +//! - `eta > 0`: Ridge-regularized computation for improved numerical stability +//! +//! # Variants +//! +//! The module provides two implementations: +//! +//! - `post_process_with_eta_f32()`: Uses ridge regularization for numerical stability +//! - `post_process_without_eta_f32()`: Computes exact determinants (faster but less stable) +//! +//! These are selected automatically based on the eta parameter value. +//! +//! # Time Complexity +//! +//! O(m³) where m is the number of candidates, due to determinant computation. +//! In practice, m is typically small (search returns hundreds of candidates, +//! but only top-k ≪ m are selected). +//! +//! # References +//! +//! The algorithm is based on diversity-promoting ranking methods for nearest neighbor search, +//! as used in approximate nearest neighbor indices like DiskANN. + +use diskann_utils::views::Matrix; +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; + +pub fn determinant_diversity_post_process( + candidates: Vec<(Id, f32, Vec)>, + query: &[f32], + k: usize, + determinant_diversity_eta: f32, + determinant_diversity_power: f32, +) -> Vec<(Id, f32)> { + if candidates.is_empty() || query.is_empty() { + return Vec::new(); + } + + let candidates: Vec<_> = candidates + .into_iter() + .filter(|(_, _, vector)| vector.len() == query.len()) + .collect(); + + if candidates.is_empty() { + return Vec::new(); + } + + let k = k.min(candidates.len()); + if k == 0 { + return Vec::new(); + } + + if candidates[0].2.is_empty() { + return Vec::new(); + } + + let distance_range = { + let mut min_distance = f32::INFINITY; + let mut max_distance = f32::NEG_INFINITY; + + for (_, distance, _) in &candidates { + min_distance = min_distance.min(*distance); + max_distance = max_distance.max(*distance); + } + + (min_distance, max_distance) + }; + + // For eta=0, the inv_sqrt_eta factor is 1.0 (greedy orthogonalization without regularization). + // For eta>0, the factor scales residuals for ridge-regularized determinant computation. + let inv_sqrt_eta = if determinant_diversity_eta > 0.0 { + 1.0 / determinant_diversity_eta.sqrt() + } else { + 1.0 + }; + + greedy_orthogonal_select( + candidates, + k, + determinant_diversity_power, + inv_sqrt_eta, + distance_range, + ) +} + +/// Core greedy selection algorithm for Determinant-Diversity. +/// +/// Iteratively selects the candidate with the largest residual norm after projecting +/// out previously selected candidates. The `inv_sqrt_eta` parameter controls the +/// ridge-regularization scaling: +/// +/// - `inv_sqrt_eta = 1.0`: exact greedy orthogonalization (eta=0 case) +/// - `inv_sqrt_eta = 1/sqrt(eta)`: ridge-regularized variant for numerical stability +/// +/// This unified implementation replaces two nearly-identical functions that only +/// differed in whether the scale factor included the eta term. +fn greedy_orthogonal_select( + candidates: Vec<(Id, f32, Vec)>, + k: usize, + power: f32, + inv_sqrt_eta: f32, + distance_range: (f32, f32), +) -> Vec<(Id, f32)> { + let n = candidates.len(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + let dim = candidates[0].2.len(); + + // Use a contiguous Matrix allocation for residuals instead of Vec>. + // This reduces the number of heap allocations from O(n) to O(1) and improves + // cache locality when accessing residuals sequentially during orthogonalization. + let mut residuals = Matrix::new(0.0f32, n, dim); + let mut norms_sq = Vec::with_capacity(n); + + for (i, (_, distance_to_query, v)) in candidates.iter().enumerate() { + let scale = + distance_to_similarity(*distance_to_query, distance_range).powf(power) * inv_sqrt_eta; + let row = residuals.row_mut(i); + for (r, &x) in row.iter_mut().zip(v.iter()) { + *r = x * scale; + } + let norm_sq = dot_product(residuals.row(i), residuals.row(i)); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(selected_index) = best_idx else { + break; + }; + + selected.push(selected_index); + available[selected_index] = false; + + if selected.len() == k { + break; + } + + let best_norm_sq = norms_sq[selected_index]; + if best_norm_sq <= 0.0 { + continue; + } + + let inv_norm_sq = 1.0 / best_norm_sq; + // Clone selected row before mutable iteration over remaining rows. + let r_star_copy: Vec = residuals.row(selected_index).to_vec(); + + for i in 0..n { + if !available[i] { + projections[i] = 0.0; + } else { + projections[i] = dot_product(residuals.row(i), &r_star_copy) * inv_norm_sq; + } + } + + for i in 0..n { + if !available[i] { + continue; + } + + let projection = projections[i]; + for (residual, &star) in residuals.row_mut(i).iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + norms_sq[i] = (norms_sq[i] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected + .iter() + .map(|&idx| { + let (id, dist, _) = &candidates[idx]; + (*id, *dist) + }) + .collect() +} + +fn distance_to_similarity(distance: f32, distance_range: (f32, f32)) -> f32 { + let (min_distance, max_distance) = distance_range; + let span = (max_distance - min_distance).max(f32::EPSILON); + + // Distances are lower-is-better in DiskANN distance semantics. + ((max_distance - distance) / span).max(0.0) + f32::EPSILON +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_candidates() { + let result = + determinant_diversity_post_process::(Vec::new(), &[1.0, 2.0], 5, 0.5, 1.0); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_empty_query() { + let candidates = vec![(0u32, 0.5, vec![1.0, 2.0])]; + let result = determinant_diversity_post_process(candidates, &[], 5, 0.5, 1.0); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_mismatched_dimensions() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 2.0]), + (1u32, 0.3, vec![1.0]), // Wrong dimension + ]; + let query = &[1.0, 2.0, 3.0]; + let result = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); + assert_eq!(result.len(), 0); // All candidates filtered due to dimension mismatch + } + + #[test] + fn test_single_candidate() { + let candidates = vec![(0u32, 0.5, vec![1.0, 2.0])]; + let query = &[1.0, 2.0]; + let result = determinant_diversity_post_process(candidates, query, 5, 0.5, 1.0); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, 0); + } + + #[test] + fn test_k_larger_than_candidates() { + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 10, 0.5, 1.0); + assert_eq!(result.len(), 2); // Should return min(k, candidates.len()) + } + + #[test] + fn test_with_eta_diversity() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 1.0, 1.0); + + assert_eq!(result.len(), 2); + // Should select based on diversity metric with eta > 0 + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_without_eta_greedy() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + assert_eq!(result.len(), 2); + // Should select based on greedy orthogonalization (eta == 0) + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_power_parameter() { + let candidates = vec![(0u32, 0.1, vec![1.0, 0.0]), (1u32, 0.2, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + + // Test with different power values - should still work without panicking + let result1 = determinant_diversity_post_process(candidates.clone(), query, 2, 0.0, 1.0); + let result2 = determinant_diversity_post_process(candidates, query, 2, 0.0, 2.0); + + assert_eq!(result1.len(), 2); + assert_eq!(result2.len(), 2); + } + + #[test] + fn test_distances_preserved() { + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Verify that distances are preserved from input + assert!(result.iter().all(|(_, dist)| *dist == 0.5 || *dist == 0.3)); + } + + /// Verify that diversity is actually promoted: when candidates lie along orthogonal + /// directions, a 2-element diverse subset should choose orthogonal pairs over similar ones. + /// + /// Using equal distances ensures pure diversity drives selection without relevance weighting. + #[test] + fn test_diversity_selects_orthogonal_candidates() { + // Three candidates with equal distance: two very similar (nearly parallel) and one orthogonal. + // Equal distances remove relevance weighting, so pure diversity drives selection. + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), // along x + (1u32, 0.1, vec![0.0, 1.0, 0.0]), // along y - orthogonal to 0 + (2u32, 0.1, vec![0.99, 0.01, 0.0]), // nearly parallel to 0 + ]; + let query = &[1.0, 1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Should select 2 candidates + assert_eq!(result.len(), 2); + // The diverse pair is (0, 1) - orthogonal. Candidate 2 is redundant with 0. + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected, not redundant candidate 2" + ); + } + + /// Verify eta variant selects the same k results. + #[test] + fn test_diversity_selects_orthogonal_candidates_with_eta() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), + (1u32, 0.1, vec![0.0, 1.0, 0.0]), + (2u32, 0.1, vec![0.99, 0.01, 0.0]), + ]; + let query = &[1.0, 1.0, 1.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.5, 1.0); + + assert_eq!(result.len(), 2); + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected" + ); + } + + /// Verify power=high weights nearby candidates (distance=0.1) more strongly than far ones. + #[test] + fn test_high_power_prefers_closer_candidates() { + // Two orthogonal candidates: one close, one far + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), // close to query + (1u32, 0.9, vec![0.0, 1.0]), // far from query + ]; + let query = &[1.0, 0.0]; + + // With high power, relevance is heavily weighted so the closest candidate dominates + let result = determinant_diversity_post_process(candidates.clone(), query, 1, 0.0, 10.0); + assert_eq!(result.len(), 1); + // Closest candidate should be preferred due to high power weighting + assert_eq!( + result[0].0, 0, + "Closest candidate should be selected with high power" + ); + } + + /// Verify that distance-to-similarity conversion handles equal distances gracefully. + #[test] + fn test_equal_distances() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 0.0]), + (1u32, 0.5, vec![0.0, 1.0]), // same distance as 0 + ]; + let query = &[1.0, 0.0]; + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + + // Should still return candidates without panicking + assert_eq!(result.len(), 2); + } + + /// Test eta=0 exactly matches greedy orthogonalization path. + #[test] + fn test_eta_zero_is_greedy_path() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.0, 1.0]), + (2u32, 0.3, vec![0.5, 0.5]), + ]; + let query = &[1.0, 1.0]; + // eta=0.0 must invoke greedy path, not ridge-regularized + let result = determinant_diversity_post_process(candidates, query, 2, 0.0, 1.0); + assert_eq!(result.len(), 2); + } +} diff --git a/diskann-providers/src/model/graph/provider/mod.rs b/diskann-providers/src/model/graph/provider/mod.rs index 0e045bfb5..f0ac174dd 100644 --- a/diskann-providers/src/model/graph/provider/mod.rs +++ b/diskann-providers/src/model/graph/provider/mod.rs @@ -6,3 +6,10 @@ pub mod async_; // Layers for the async index. pub mod layers; + +/// Determinant-diversity post-processing algorithm. +/// +/// This module is not async-specific and is re-exported here for clarity. +/// It provides diversity-promoting reranking for nearest neighbor search results. +pub mod determinant_diversity; +pub use determinant_diversity::determinant_diversity_post_process; diff --git a/diskann-providers/src/post_processor.rs b/diskann-providers/src/post_processor.rs new file mode 100644 index 000000000..07ecd9b39 --- /dev/null +++ b/diskann-providers/src/post_processor.rs @@ -0,0 +1,137 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Unified post-processor parameter types with validation. +//! +//! This module provides centralized definitions and validation for post-processor +//! parameters like Determinant-Diversity, ensuring consistent validation across +//! different search contexts (in-memory, disk, benchmarking). + +use std::fmt; + +/// Parameters for Determinant-Diversity post-processor with validation. +/// +/// Determinant-Diversity is a diversity-promoting reranking algorithm that takes +/// relevance-ranked neighbors and reorders them to maximize geometric diversity +/// while maintaining relevance. +/// +/// # Parameters +/// +/// - `power`: Relevance weighting exponent. Controls the emphasis on maintaining +/// relevance scores from the original search. Must be > 0.0. +/// +/// - `eta`: Numerical stability parameter for ridge-regularization. Controls the +/// trade-off between exact determinant computation (eta=0) and numerical robustness +/// (eta>0). Must be >= 0.0. +/// +/// # Errors +/// +/// Construction fails if: +/// - `power <= 0.0` (invalid power weighting) +/// - `eta < 0.0` (negative stability parameter) +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversityParams { + /// Relevance weighting exponent. Must be > 0.0. + pub power: f32, + /// Numerical stability parameter. Must be >= 0.0. + pub eta: f32, +} + +impl DeterminantDiversityParams { + /// Create and validate new Determinant-Diversity parameters. + /// + /// # Errors + /// + /// Returns an error if validation fails: + /// - `power <= 0.0`: invalid relevance weighting + /// - `eta < 0.0`: invalid numerical stability parameter + pub fn new(power: f32, eta: f32) -> Result { + if power <= 0.0 { + return Err(DeterminantDiversityError::InvalidPower(power)); + } + if eta < 0.0 { + return Err(DeterminantDiversityError::InvalidEta(eta)); + } + Ok(Self { power, eta }) + } + + /// Get power parameter. + #[inline] + pub fn power(&self) -> f32 { + self.power + } + + /// Get eta parameter. + #[inline] + pub fn eta(&self) -> f32 { + self.eta + } +} + +impl fmt::Display for DeterminantDiversityParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "DeterminantDiversity(power={}, eta={})", + self.power, self.eta + ) + } +} + +/// Validation error for Determinant-Diversity parameters. +#[derive(Debug, Clone)] +pub enum DeterminantDiversityError { + /// Power parameter <= 0.0 + InvalidPower(f32), + /// Eta parameter < 0.0 + InvalidEta(f32), +} + +impl fmt::Display for DeterminantDiversityError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidPower(p) => { + write!(f, "determinant-diversity power must be > 0.0, got: {}", p) + } + Self::InvalidEta(e) => { + write!(f, "determinant-diversity eta must be >= 0.0, got: {}", e) + } + } + } +} + +impl std::error::Error for DeterminantDiversityError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_params() { + assert!(DeterminantDiversityParams::new(1.0, 0.0).is_ok()); + assert!(DeterminantDiversityParams::new(0.5, 1.5).is_ok()); + assert!(DeterminantDiversityParams::new(2.0, 0.1).is_ok()); + } + + #[test] + fn test_invalid_power() { + assert!(DeterminantDiversityParams::new(0.0, 1.0).is_err()); + assert!(DeterminantDiversityParams::new(-1.0, 1.0).is_err()); + } + + #[test] + fn test_invalid_eta() { + assert!(DeterminantDiversityParams::new(1.0, -0.1).is_err()); + } + + #[test] + fn test_display() { + let params = DeterminantDiversityParams::new(1.5, 0.5).unwrap(); + assert_eq!( + params.to_string(), + "DeterminantDiversity(power=1.5, eta=0.5)" + ); + } +} diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index a0a91fde2..81aabb902 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -259,6 +259,7 @@ where l, Some(parameters.beam_width as usize), Some(vector_filter_function), + None, parameters.is_flat_search, );