diff --git a/crates/lance-context-api/src/lib.rs b/crates/lance-context-api/src/lib.rs index b704b19..ebc595b 100644 --- a/crates/lance-context-api/src/lib.rs +++ b/crates/lance-context-api/src/lib.rs @@ -105,6 +105,8 @@ pub struct CreateContextRequest { pub blob_columns: Option>, #[serde(default)] pub embedding_dim: Option, + #[serde(default)] + pub distance_metric: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/lance-context-core/src/lib.rs b/crates/lance-context-core/src/lib.rs index a20bf1b..42b2438 100644 --- a/crates/lance-context-core/src/lib.rs +++ b/crates/lance-context-core/src/lib.rs @@ -13,7 +13,8 @@ pub use record::{ RetrieveResult, SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED, }; pub use store::{ - CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, IdIndexType, + CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, DistanceMetric, + IdIndexType, }; // Re-export CompactionMetrics from lance for Python bindings diff --git a/crates/lance-context-core/src/record.rs b/crates/lance-context-core/src/record.rs index 9eec2bc..d31bd22 100644 --- a/crates/lance-context-core/src/record.rs +++ b/crates/lance-context-core/src/record.rs @@ -126,6 +126,10 @@ impl LifecycleQueryOptions { #[derive(Debug, Clone)] pub struct SearchResult { pub record: ContextRecord, + /// Distance score under the store's configured distance metric, always + /// ordered "smaller is better". Its scale is metric-dependent: L2 distance, + /// cosine distance (`1 - cosine_similarity`, in `0..=2`), or the negated dot + /// product for maximum-inner-product search. pub distance: f32, } diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index c46902c..abe6cc6 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -99,6 +99,52 @@ pub enum IdIndexType { BTree, } +/// Distance metric used to rank candidates during vector search. 
+/// +/// All variants are normalized so that a **smaller** value means a closer +/// match, keeping the search ranking ascending regardless of metric. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum DistanceMetric { + /// Euclidean (L2) distance. Default for backward compatibility. + #[default] + L2, + /// Cosine distance (`1 - cosine_similarity`). Common for normalized + /// embeddings from most modern models. + Cosine, + /// Negated dot product (maximum inner product search). + Dot, +} + +impl DistanceMetric { + /// Parse a metric from its string identifier: `"l2"` (alias `"euclidean"`), + /// `"cosine"`, or `"dot"` (alias `"dot_product"`). Matching is ASCII case-insensitive. + /// + /// # Errors + /// Returns an error if the identifier is not a recognized metric. + pub fn parse(value: &str) -> LanceResult<Self> { + match value.to_ascii_lowercase().as_str() { + "l2" | "euclidean" => Ok(Self::L2), + "cosine" => Ok(Self::Cosine), + "dot" | "dot_product" => Ok(Self::Dot), + other => Err(LanceError::from(ArrowError::InvalidArgumentError(format!( + "invalid distance metric '{other}': valid values are 'l2', 'cosine', 'dot'" + )))), + } + } + + /// Compute the metric between a query and a candidate vector. + /// + /// The returned value is always "smaller is better". + #[must_use] + pub fn distance(self, query: &[f32], candidate: &[f32]) -> f32 { + match self { + Self::L2 => l2_distance(query, candidate), + Self::Cosine => cosine_distance(query, candidate), + Self::Dot => dot_distance(query, candidate), + } + } +} + /// Statistics about compaction status and history. #[derive(Debug, Clone)] pub struct CompactionStats { @@ -135,6 +181,7 @@ pub struct ContextStore { blob_columns: HashSet, id_index_type: IdIndexType, embedding_dim: i32, + distance_metric: DistanceMetric, } /// Additional configuration when opening a [`ContextStore`]. @@ -150,6 +197,8 @@ pub struct ContextStoreOptions { pub blob_columns: HashSet, /// Type of scalar index to create on the `id` column. 
pub id_index_type: IdIndexType, + /// Distance metric used to rank vector-search results. + pub distance_metric: DistanceMetric, } impl ContextStoreOptions { @@ -261,6 +310,7 @@ impl ContextStore { blob_columns, id_index_type: options.id_index_type, embedding_dim, + distance_metric: options.distance_metric, }; // Ensure id index if configured @@ -693,7 +743,9 @@ impl ContextStore { .await? .into_iter() .filter_map(|record| { - let distance = l2_distance(query, record.embedding.as_ref()?); + let distance = self + .distance_metric + .distance(query, record.embedding.as_ref()?); Some(SearchResult { record, distance }) }) .collect(); @@ -740,7 +792,9 @@ impl ContextStore { .iter() .enumerate() .filter_map(|(index, record)| { - let distance = l2_distance(query, record.embedding.as_ref()?); + let distance = self + .distance_metric + .distance(query, record.embedding.as_ref()?); Some((index, distance)) }) .collect(); @@ -2032,6 +2086,34 @@ fn embedding_dim_from_schema(schema: &Schema) -> LanceResult { Ok(*embedding_dim) } +/// Dot product of two vectors. +fn dot_product(left: &[f32], right: &[f32]) -> f32 { + left.iter() + .zip(right) + .map(|(left, right)| left * right) + .sum::<f32>() +} + +/// Cosine distance (`1 - cosine_similarity`), ranging from 0 (identical +/// direction) to 2 (opposite). If either vector has zero magnitude the +/// similarity is undefined, so we return the neutral distance `1.0` (the value +/// for orthogonal vectors) rather than producing `NaN`. +fn cosine_distance(left: &[f32], right: &[f32]) -> f32 { + let dot = dot_product(left, right); + let left_norm = dot_product(left, left).sqrt(); + let right_norm = dot_product(right, right).sqrt(); + if left_norm == 0.0 || right_norm == 0.0 { + return 1.0; + } + 1.0 - (dot / (left_norm * right_norm)) +} + +/// Negated dot product, so that a larger inner product (a closer match for +/// maximum-inner-product search) sorts first under ascending ordering. 
+fn dot_distance(left: &[f32], right: &[f32]) -> f32 { + -dot_product(left, right) +} + fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A> where A: Array + 'static, @@ -2147,6 +2229,113 @@ mod tests { }); } + fn make_embedding2(x0: f32, x1: f32) -> Vec<f32> { + let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize]; + values[0] = x0; + values[1] = x1; + values + } + + fn text_record_with(id: &str, embedding: Vec<f32>) -> ContextRecord { + let mut record = text_record(id, 0.0); + record.embedding = Some(embedding); + record + } + + #[test] + fn distance_metric_parse_and_math() { + assert_eq!(DistanceMetric::parse("l2").unwrap(), DistanceMetric::L2); + assert_eq!(DistanceMetric::parse("L2").unwrap(), DistanceMetric::L2); + assert_eq!( + DistanceMetric::parse("cosine").unwrap(), + DistanceMetric::Cosine + ); + assert_eq!(DistanceMetric::parse("DOT").unwrap(), DistanceMetric::Dot); + assert!(DistanceMetric::parse("manhattan").is_err()); + assert_eq!(DistanceMetric::default(), DistanceMetric::L2); + + let a = [1.0_f32, 0.0]; + let b = [1.0_f32, 1.0]; + // L2: sqrt(0 + 1) = 1 + assert!((DistanceMetric::L2.distance(&a, &b) - 1.0).abs() < 1e-6); + // Cosine distance: 1 - (1 / (1 * sqrt(2))) = 1 - 0.70710677 + assert!((DistanceMetric::Cosine.distance(&a, &b) - (1.0 - 0.707_106_77)).abs() < 1e-5); + // Dot: -(1*1 + 0*1) = -1 + assert!((DistanceMetric::Dot.distance(&a, &b) + 1.0).abs() < 1e-6); + // Zero-magnitude vectors yield the neutral cosine distance (1.0), never NaN. + let zero = [0.0_f32, 0.0]; + assert!((DistanceMetric::Cosine.distance(&a, &zero) - 1.0).abs() < 1e-6); + } + + #[test] + fn search_metric_changes_ranking() { + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + // query direction matches "aligned" but "near" is closer in L2. + let query = make_embedding2(1.0, 0.0); + // aligned: same direction as query, larger magnitude -> far in L2, + // but cosine distance 0 and largest dot product. 
+ let aligned = make_embedding2(10.0, 0.0); + // near: closest in L2, but off-axis -> larger cosine distance. + let near = make_embedding2(1.0, 1.0); + + // Default (L2): `near` should rank first. + let l2_dir = TempDir::new().unwrap(); + let mut l2_store = ContextStore::open(&l2_dir.path().to_string_lossy()) + .await + .unwrap(); + l2_store + .add(&[ + text_record_with("aligned", aligned.clone()), + text_record_with("near", near.clone()), + ]) + .await + .unwrap(); + let l2_results = l2_store.search(&query, Some(2)).await.unwrap(); + assert_eq!(l2_results[0].record.id, "near"); + + // Cosine: `aligned` should rank first despite the larger L2 distance. + let cos_dir = TempDir::new().unwrap(); + let cos_opts = ContextStoreOptions { + distance_metric: DistanceMetric::Cosine, + ..Default::default() + }; + let mut cos_store = + ContextStore::open_with_options(&cos_dir.path().to_string_lossy(), cos_opts) + .await + .unwrap(); + cos_store + .add(&[ + text_record_with("aligned", aligned.clone()), + text_record_with("near", near.clone()), + ]) + .await + .unwrap(); + let cos_results = cos_store.search(&query, Some(2)).await.unwrap(); + assert_eq!(cos_results[0].record.id, "aligned"); + + // Dot: `aligned` has the largest inner product -> first. 
+ let dot_dir = TempDir::new().unwrap(); + let dot_opts = ContextStoreOptions { + distance_metric: DistanceMetric::Dot, + ..Default::default() + }; + let mut dot_store = + ContextStore::open_with_options(&dot_dir.path().to_string_lossy(), dot_opts) + .await + .unwrap(); + dot_store + .add(&[ + text_record_with("aligned", aligned), + text_record_with("near", near), + ]) + .await + .unwrap(); + let dot_results = dot_store.search(&query, Some(2)).await.unwrap(); + assert_eq!(dot_results[0].record.id, "aligned"); + }); + } + #[test] fn retrieve_fuses_text_and_vector_channels() { let dir = TempDir::new().unwrap(); diff --git a/crates/lance-context-server/src/routes/contexts.rs b/crates/lance-context-server/src/routes/contexts.rs index f15791e..3be2330 100644 --- a/crates/lance-context-server/src/routes/contexts.rs +++ b/crates/lance-context-server/src/routes/contexts.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use axum::extract::{Path, State}; use axum::Json; use lance_context_api::{ContextInfo, CreateContextRequest, ListContextsResponse}; -use lance_context_core::{ContextStore, ContextStoreOptions, IdIndexType}; +use lance_context_core::{ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType}; use tokio::sync::RwLock; use crate::error::AppError; @@ -37,12 +37,20 @@ pub async fn create_context( let blob_columns: HashSet = req.blob_columns.unwrap_or_default().into_iter().collect(); + let distance_metric = match req.distance_metric.as_deref() { + Some(value) => { + DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))? 
+ } + None => DistanceMetric::default(), + }; + let uri = state.context_uri(&req.name); let options = ContextStoreOptions { storage_options: req.storage_options, embedding_dim: req.embedding_dim, blob_columns, id_index_type, + distance_metric, ..Default::default() }; diff --git a/crates/lance-context/src/unified.rs b/crates/lance-context/src/unified.rs index 9751f50..5b92374 100644 --- a/crates/lance-context/src/unified.rs +++ b/crates/lance-context/src/unified.rs @@ -5,7 +5,9 @@ use lance_context_api::{ ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest, RetrieveResultDto, SearchResultDto, }; -use lance_context_core::{ContextStore as LocalStore, ContextStoreOptions, IdIndexType}; +use lance_context_core::{ + ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType, +}; #[cfg(feature = "remote")] use lance_context_client::RemoteContextStore; @@ -29,6 +31,7 @@ impl ContextStore { storage_options: Option>, id_index_type: Option<&str>, blob_columns: Option>, + distance_metric: Option<&str>, ) -> Result { let id_idx = match id_index_type { Some("btree") => IdIndexType::BTree, @@ -40,6 +43,11 @@ impl ContextStore { ))); } }; + let metric = match distance_metric { + Some(value) => DistanceMetric::parse(value) + .map_err(|e| ContextError::InvalidRequest(e.to_string()))?, + None => DistanceMetric::default(), + }; let options = ContextStoreOptions { storage_options, blob_columns: blob_columns @@ -47,6 +55,7 @@ impl ContextStore { .into_iter() .collect::>(), id_index_type: id_idx, + distance_metric: metric, ..Default::default() }; let store = LocalStore::open_with_options(uri, options) diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 5d3fe4d..3db9f81 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -312,6 +312,7 @@ def __init__( quiet_hours: list[tuple[int, int]] | None = None, id_index_type: str | None = None, embedding_dim: 
int | None = None, + distance_metric: str | None = None, ) -> None: options = _merge_storage_options( storage_options, @@ -336,6 +337,7 @@ def __init__( or compaction_config["enabled"] or id_index_type or embedding_dim is not None + or distance_metric ): self._inner = _Context.create( uri, @@ -343,6 +345,7 @@ def __init__( compaction_config=compaction_config, id_index_type=id_index_type, embedding_dim=embedding_dim, + distance_metric=distance_metric, ) else: self._inner = _Context.create(uri) @@ -366,6 +369,7 @@ def create( quiet_hours: list[tuple[int, int]] | None = None, id_index_type: str | None = None, embedding_dim: int | None = None, + distance_metric: str | None = None, ) -> Context: return cls( uri, @@ -383,6 +387,7 @@ def create( quiet_hours=quiet_hours, id_index_type=id_index_type, embedding_dim=embedding_dim, + distance_metric=distance_metric, ) def uri(self) -> str: diff --git a/python/src/lib.rs b/python/src/lib.rs index 8b6f0e5..771ef8a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -14,8 +14,8 @@ use tokio::runtime::Runtime; use lance_context_core::serde::CONTENT_TYPE_TEXT; use lance_context_core::{ CompactionConfig, CompactionMetrics, CompactionStats, Context as RustContext, ContextRecord, - ContextStore, ContextStoreOptions, IdIndexType, LifecycleQueryOptions, RecordFilters, - Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, + ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType, LifecycleQueryOptions, + RecordFilters, Relationship, RetrieveResult, SearchResult, LIFECYCLE_ACTIVE, }; const DEFAULT_BINARY_CONTENT_TYPE: &str = "application/octet-stream"; @@ -168,7 +168,7 @@ fn filters_from_json(filters_json: Option) -> PyResult, py: Python<'_>, @@ -178,6 +178,7 @@ impl Context { blob_columns: Option>, id_index_type: Option, embedding_dim: Option, + distance_metric: Option, ) -> PyResult { let runtime = Arc::new(Runtime::new().map_err(to_py_err)?); @@ -195,12 +196,18 @@ impl Context { } }; + let metric = match 
distance_metric.as_deref() { + Some(value) => DistanceMetric::parse(value).map_err(to_py_err)?, + None => DistanceMetric::default(), + }; + let options = ContextStoreOptions { storage_options: storage_options_from_dict(storage_options)?, compaction: compaction_config_from_dict(compaction_config)?, embedding_dim, blob_columns: blob_set, id_index_type: id_idx, + distance_metric: metric, }; let store_res = diff --git a/python/tests/test_distance_metric.py b/python/tests/test_distance_metric.py new file mode 100644 index 0000000..ab38cba --- /dev/null +++ b/python/tests/test_distance_metric.py @@ -0,0 +1,60 @@ +"""Tests for the configurable vector-search distance metric.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + +from lance_context.api import Context + +DIM = 1536 + + +def _vec(x0: float, x1: float = 0.0) -> list[float]: + vector = [0.0] * DIM + vector[0] = x0 + vector[1] = x1 + return vector + + +# Query points along axis 0. ``aligned`` shares the query's direction (cosine +# distance 0, largest dot product) but is far away in L2; ``near`` is closest in +# L2 but off-axis, so cosine ranks it lower. 
+QUERY = _vec(1.0) +ALIGNED = _vec(10.0) +NEAR = _vec(1.0, 1.0) + + +def _make(uri: str, **kwargs: str) -> Context: + ctx = Context.create(uri, **kwargs) + ctx.add("user", "aligned", embedding=ALIGNED, external_id="aligned") + ctx.add("user", "near", embedding=NEAR, external_id="near") + return ctx + + +def test_default_metric_is_l2(tmp_path: Path) -> None: + ctx = _make(str(tmp_path / "l2.lance")) + hits = ctx.search(QUERY, limit=2) + assert [h["external_id"] for h in hits][0] == "near" + + +def test_cosine_metric_changes_ranking(tmp_path: Path) -> None: + ctx = _make(str(tmp_path / "cosine.lance"), distance_metric="cosine") + hits = ctx.search(QUERY, limit=2) + assert [h["external_id"] for h in hits][0] == "aligned" + + +def test_dot_metric_ranks_by_inner_product(tmp_path: Path) -> None: + ctx = _make(str(tmp_path / "dot.lance"), distance_metric="dot") + hits = ctx.search(QUERY, limit=2) + assert [h["external_id"] for h in hits][0] == "aligned" + + +def test_invalid_metric_rejected(tmp_path: Path) -> None: + uri = str(tmp_path / "bad.lance") + with pytest.raises(RuntimeError, match="invalid distance metric"): + Context.create(uri, distance_metric="manhattan")