Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/lance-context-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ pub struct CreateContextRequest {
pub blob_columns: Option<Vec<String>>,
#[serde(default)]
pub embedding_dim: Option<i32>,
#[serde(default)]
pub distance_metric: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
3 changes: 2 additions & 1 deletion crates/lance-context-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ pub use record::{
RetrieveResult, SearchResult, StateMetadata, LIFECYCLE_ACTIVE, LIFECYCLE_CONTRADICTED,
};
pub use store::{
CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, IdIndexType,
CompactionConfig, CompactionStats, ContextStore, ContextStoreOptions, DistanceMetric,
IdIndexType,
};

// Re-export CompactionMetrics from lance for Python bindings
Expand Down
4 changes: 4 additions & 0 deletions crates/lance-context-core/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ impl LifecycleQueryOptions {
#[derive(Debug, Clone)]
pub struct SearchResult {
pub record: ContextRecord,
/// Distance score under the store's configured distance metric, always
/// ordered "smaller is better". Its scale is metric-dependent: L2 distance,
/// cosine distance (`1 - cosine_similarity`, in `0..=2`), or the negated dot
/// product for maximum-inner-product search.
pub distance: f32,
}

Expand Down
193 changes: 191 additions & 2 deletions crates/lance-context-core/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,52 @@ pub enum IdIndexType {
BTree,
}

/// Distance metric used to rank candidates during vector search.
///
/// All variants are normalized so that a **smaller** value means a closer
/// match, keeping the search ranking ascending regardless of metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistanceMetric {
/// Euclidean (L2) distance. Default for backward compatibility.
#[default]
L2,
/// Cosine distance (`1 - cosine_similarity`). Common for normalized
/// embeddings from most modern models.
Cosine,
/// Negated dot product (maximum inner product search).
Dot,
}

impl DistanceMetric {
/// Parse a metric from its string identifier (`"l2"`, `"cosine"`, `"dot"`).
/// Matching is case-insensitive.
///
/// # Errors
/// Returns an error if the identifier is not a recognized metric.
pub fn parse(value: &str) -> LanceResult<Self> {
match value.to_ascii_lowercase().as_str() {
"l2" | "euclidean" => Ok(Self::L2),
"cosine" => Ok(Self::Cosine),
"dot" | "dot_product" => Ok(Self::Dot),
other => Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
"invalid distance metric '{other}': valid values are 'l2', 'cosine', 'dot'"
)))),
}
}

/// Compute the metric between a query and a candidate vector.
///
/// The returned value is always "smaller is better".
#[must_use]
pub fn distance(self, query: &[f32], candidate: &[f32]) -> f32 {
match self {
Self::L2 => l2_distance(query, candidate),
Self::Cosine => cosine_distance(query, candidate),
Self::Dot => dot_distance(query, candidate),
}
}
}

/// Statistics about compaction status and history.
#[derive(Debug, Clone)]
pub struct CompactionStats {
Expand Down Expand Up @@ -135,6 +181,7 @@ pub struct ContextStore {
blob_columns: HashSet<String>,
id_index_type: IdIndexType,
embedding_dim: i32,
distance_metric: DistanceMetric,
}

/// Additional configuration when opening a [`ContextStore`].
Expand All @@ -150,6 +197,8 @@ pub struct ContextStoreOptions {
pub blob_columns: HashSet<String>,
/// Type of scalar index to create on the `id` column.
pub id_index_type: IdIndexType,
/// Distance metric used to rank vector-search results.
pub distance_metric: DistanceMetric,
}

impl ContextStoreOptions {
Expand Down Expand Up @@ -261,6 +310,7 @@ impl ContextStore {
blob_columns,
id_index_type: options.id_index_type,
embedding_dim,
distance_metric: options.distance_metric,
};

// Ensure id index if configured
Expand Down Expand Up @@ -693,7 +743,9 @@ impl ContextStore {
.await?
.into_iter()
.filter_map(|record| {
let distance = l2_distance(query, record.embedding.as_ref()?);
let distance = self
.distance_metric
.distance(query, record.embedding.as_ref()?);
Some(SearchResult { record, distance })
})
.collect();
Expand Down Expand Up @@ -740,7 +792,9 @@ impl ContextStore {
.iter()
.enumerate()
.filter_map(|(index, record)| {
let distance = l2_distance(query, record.embedding.as_ref()?);
let distance = self
.distance_metric
.distance(query, record.embedding.as_ref()?);
Some((index, distance))
})
.collect();
Expand Down Expand Up @@ -2032,6 +2086,34 @@ fn embedding_dim_from_schema(schema: &Schema) -> LanceResult<i32> {
Ok(*embedding_dim)
}

/// Dot product of two vectors.
fn dot_product(left: &[f32], right: &[f32]) -> f32 {
left.iter()
.zip(right)
.map(|(left, right)| left * right)
.sum::<f32>()
}

/// Cosine distance (`1 - cosine_similarity`), ranging from 0 (identical
/// direction) to 2 (opposite). If either vector has zero magnitude the
/// similarity is undefined, so we return the maximum distance (`1.0`) to keep
/// such records ranked last without producing `NaN`.
fn cosine_distance(left: &[f32], right: &[f32]) -> f32 {
let dot = dot_product(left, right);
let left_norm = dot_product(left, left).sqrt();
let right_norm = dot_product(right, right).sqrt();
if left_norm == 0.0 || right_norm == 0.0 {
return 1.0;
}
1.0 - (dot / (left_norm * right_norm))
}

/// Negated dot product, so that a larger inner product (a closer match for
/// maximum-inner-product search) sorts first under ascending ordering.
fn dot_distance(left: &[f32], right: &[f32]) -> f32 {
-dot_product(left, right)
}

fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
where
A: Array + 'static,
Expand Down Expand Up @@ -2147,6 +2229,113 @@ mod tests {
});
}

fn make_embedding2(x0: f32, x1: f32) -> Vec<f32> {
let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
values[0] = x0;
values[1] = x1;
values
}

fn text_record_with(id: &str, embedding: Vec<f32>) -> ContextRecord {
let mut record = text_record(id, 0.0);
record.embedding = Some(embedding);
record
}

#[test]
fn distance_metric_parse_and_math() {
assert_eq!(DistanceMetric::parse("l2").unwrap(), DistanceMetric::L2);
assert_eq!(DistanceMetric::parse("L2").unwrap(), DistanceMetric::L2);
assert_eq!(
DistanceMetric::parse("cosine").unwrap(),
DistanceMetric::Cosine
);
assert_eq!(DistanceMetric::parse("DOT").unwrap(), DistanceMetric::Dot);
assert!(DistanceMetric::parse("manhattan").is_err());
assert_eq!(DistanceMetric::default(), DistanceMetric::L2);

let a = [1.0_f32, 0.0];
let b = [1.0_f32, 1.0];
// L2: sqrt(0 + 1) = 1
assert!((DistanceMetric::L2.distance(&a, &b) - 1.0).abs() < 1e-6);
// Cosine distance: 1 - (1 / (1 * sqrt(2))) = 1 - 0.70710677
assert!((DistanceMetric::Cosine.distance(&a, &b) - (1.0 - 0.707_106_77)).abs() < 1e-5);
// Dot: -(1*1 + 0*1) = -1
assert!((DistanceMetric::Dot.distance(&a, &b) + 1.0).abs() < 1e-6);
// Zero-magnitude vectors yield max cosine distance, never NaN.
let zero = [0.0_f32, 0.0];
assert!((DistanceMetric::Cosine.distance(&a, &zero) - 1.0).abs() < 1e-6);
}

#[test]
fn search_metric_changes_ranking() {
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
// query direction matches "aligned" but "near" is closer in L2.
let query = make_embedding2(1.0, 0.0);
// aligned: same direction as query, larger magnitude -> far in L2,
// but cosine distance 0 and largest dot product.
let aligned = make_embedding2(10.0, 0.0);
// near: closest in L2, but off-axis -> larger cosine distance.
let near = make_embedding2(1.0, 1.0);

// Default (L2): `near` should rank first.
let l2_dir = TempDir::new().unwrap();
let mut l2_store = ContextStore::open(&l2_dir.path().to_string_lossy())
.await
.unwrap();
l2_store
.add(&[
text_record_with("aligned", aligned.clone()),
text_record_with("near", near.clone()),
])
.await
.unwrap();
let l2_results = l2_store.search(&query, Some(2)).await.unwrap();
assert_eq!(l2_results[0].record.id, "near");

// Cosine: `aligned` should rank first despite the larger L2 distance.
let cos_dir = TempDir::new().unwrap();
let cos_opts = ContextStoreOptions {
distance_metric: DistanceMetric::Cosine,
..Default::default()
};
let mut cos_store =
ContextStore::open_with_options(&cos_dir.path().to_string_lossy(), cos_opts)
.await
.unwrap();
cos_store
.add(&[
text_record_with("aligned", aligned.clone()),
text_record_with("near", near.clone()),
])
.await
.unwrap();
let cos_results = cos_store.search(&query, Some(2)).await.unwrap();
assert_eq!(cos_results[0].record.id, "aligned");

// Dot: `aligned` has the largest inner product -> first.
let dot_dir = TempDir::new().unwrap();
let dot_opts = ContextStoreOptions {
distance_metric: DistanceMetric::Dot,
..Default::default()
};
let mut dot_store =
ContextStore::open_with_options(&dot_dir.path().to_string_lossy(), dot_opts)
.await
.unwrap();
dot_store
.add(&[
text_record_with("aligned", aligned),
text_record_with("near", near),
])
.await
.unwrap();
let dot_results = dot_store.search(&query, Some(2)).await.unwrap();
assert_eq!(dot_results[0].record.id, "aligned");
});
}

#[test]
fn retrieve_fuses_text_and_vector_channels() {
let dir = TempDir::new().unwrap();
Expand Down
10 changes: 9 additions & 1 deletion crates/lance-context-server/src/routes/contexts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;
use axum::extract::{Path, State};
use axum::Json;
use lance_context_api::{ContextInfo, CreateContextRequest, ListContextsResponse};
use lance_context_core::{ContextStore, ContextStoreOptions, IdIndexType};
use lance_context_core::{ContextStore, ContextStoreOptions, DistanceMetric, IdIndexType};
use tokio::sync::RwLock;

use crate::error::AppError;
Expand Down Expand Up @@ -37,12 +37,20 @@ pub async fn create_context(

let blob_columns: HashSet<String> = req.blob_columns.unwrap_or_default().into_iter().collect();

let distance_metric = match req.distance_metric.as_deref() {
Some(value) => {
DistanceMetric::parse(value).map_err(|e| AppError::InvalidRequest(e.to_string()))?
}
None => DistanceMetric::default(),
};

let uri = state.context_uri(&req.name);
let options = ContextStoreOptions {
storage_options: req.storage_options,
embedding_dim: req.embedding_dim,
blob_columns,
id_index_type,
distance_metric,
..Default::default()
};

Expand Down
11 changes: 10 additions & 1 deletion crates/lance-context/src/unified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use lance_context_api::{
ContextError, ContextResult, ContextStoreApi, DeleteRecordResponse, RecordDto, RetrieveRequest,
RetrieveResultDto, SearchResultDto,
};
use lance_context_core::{ContextStore as LocalStore, ContextStoreOptions, IdIndexType};
use lance_context_core::{
ContextStore as LocalStore, ContextStoreOptions, DistanceMetric, IdIndexType,
};

#[cfg(feature = "remote")]
use lance_context_client::RemoteContextStore;
Expand All @@ -29,6 +31,7 @@ impl ContextStore {
storage_options: Option<std::collections::HashMap<String, String>>,
id_index_type: Option<&str>,
blob_columns: Option<Vec<String>>,
distance_metric: Option<&str>,
) -> Result<Self, ContextError> {
let id_idx = match id_index_type {
Some("btree") => IdIndexType::BTree,
Expand All @@ -40,13 +43,19 @@ impl ContextStore {
)));
}
};
let metric = match distance_metric {
Some(value) => DistanceMetric::parse(value)
.map_err(|e| ContextError::InvalidRequest(e.to_string()))?,
None => DistanceMetric::default(),
};
let options = ContextStoreOptions {
storage_options,
blob_columns: blob_columns
.unwrap_or_default()
.into_iter()
.collect::<HashSet<_>>(),
id_index_type: id_idx,
distance_metric: metric,
..Default::default()
};
let store = LocalStore::open_with_options(uri, options)
Expand Down
5 changes: 5 additions & 0 deletions python/python/lance_context/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def __init__(
quiet_hours: list[tuple[int, int]] | None = None,
id_index_type: str | None = None,
embedding_dim: int | None = None,
distance_metric: str | None = None,
) -> None:
options = _merge_storage_options(
storage_options,
Expand All @@ -336,13 +337,15 @@ def __init__(
or compaction_config["enabled"]
or id_index_type
or embedding_dim is not None
or distance_metric
):
self._inner = _Context.create(
uri,
storage_options=options or None,
compaction_config=compaction_config,
id_index_type=id_index_type,
embedding_dim=embedding_dim,
distance_metric=distance_metric,
)
else:
self._inner = _Context.create(uri)
Expand All @@ -366,6 +369,7 @@ def create(
quiet_hours: list[tuple[int, int]] | None = None,
id_index_type: str | None = None,
embedding_dim: int | None = None,
distance_metric: str | None = None,
) -> Context:
return cls(
uri,
Expand All @@ -383,6 +387,7 @@ def create(
quiet_hours=quiet_hours,
id_index_type=id_index_type,
embedding_dim=embedding_dim,
distance_metric=distance_metric,
)

def uri(self) -> str:
Expand Down
Loading
Loading