Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/lance-context-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ pub struct CreateContextRequest {
pub id_index_type: Option<String>,
#[serde(default)]
pub blob_columns: Option<Vec<String>>,
#[serde(default)]
pub embedding_dim: Option<i32>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
197 changes: 180 additions & 17 deletions crates/lance-context-core/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,17 @@ pub struct ContextStore {
pub compaction_config: CompactionConfig,
blob_columns: HashSet<String>,
id_index_type: IdIndexType,
embedding_dim: i32,
}

/// Additional configuration when opening a [`ContextStore`].
#[derive(Debug, Clone, Default)]
pub struct ContextStoreOptions {
pub storage_options: Option<HashMap<String, String>>,
pub compaction: CompactionConfig,
/// Width of the fixed-size embedding vector for newly-created datasets.
/// Existing datasets always use the dimension persisted in their schema.
pub embedding_dim: Option<i32>,
/// Column names that should use Lance V1 blob encoding.
/// Valid values: `"text_payload"`, `"binary_payload"`.
pub blob_columns: HashSet<String>,
Expand Down Expand Up @@ -212,15 +216,37 @@ impl ContextStore {
}
}

let requested_embedding_dim = match options.embedding_dim {
Some(dim) => {
validate_embedding_dim(dim)?;
dim
}
None => DEFAULT_EMBEDDING_DIM,
};
let storage_options = options.storage_options();
let blob_columns = options.blob_columns.clone();
let dataset = match Self::load_with_options(uri, storage_options.clone()).await {
Ok(dataset) => dataset,
let (dataset, created) = match Self::load_with_options(uri, storage_options.clone()).await {
Ok(dataset) => (dataset, false),
Err(LanceError::DatasetNotFound { .. }) => {
Self::create_with_options(uri, storage_options, &blob_columns).await?
let dataset = Self::create_with_options(
uri,
storage_options,
&blob_columns,
requested_embedding_dim,
)
.await?;
(dataset, true)
}
Err(err) => return Err(err),
};
let arrow_schema: Schema = dataset.schema().into();
let embedding_dim = embedding_dim_from_schema(&arrow_schema)?;
if !created && options.embedding_dim.is_some() && embedding_dim != requested_embedding_dim {
return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
"existing context embedding dimension {} does not match requested dimension {}",
embedding_dim, requested_embedding_dim
))));
}

let mut store = Self {
dataset,
Expand All @@ -234,6 +260,7 @@ impl ContextStore {
compaction_config: options.compaction,
blob_columns,
id_index_type: options.id_index_type,
embedding_dim,
};

// Ensure id index if configured
Expand All @@ -245,6 +272,12 @@ impl ContextStore {
Ok(store)
}

/// Embedding vector width persisted in this context dataset schema.
#[must_use]
pub fn embedding_dim(&self) -> i32 {
self.embedding_dim
}

/// Append context records to the store and return the new dataset version.
pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult<u64> {
if entries.is_empty() {
Expand Down Expand Up @@ -648,7 +681,7 @@ impl ContextStore {
filters: Option<&RecordFilters>,
options: LifecycleQueryOptions,
) -> LanceResult<Vec<SearchResult>> {
validate_query_dimension(query)?;
validate_query_dimension(query, self.embedding_dim)?;

let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
if top_k == 0 {
Expand Down Expand Up @@ -689,7 +722,7 @@ impl ContextStore {
}

if let Some(query) = vector {
validate_query_dimension(query)?;
validate_query_dimension(query, self.embedding_dim)?;
}

let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT);
Expand Down Expand Up @@ -1013,7 +1046,12 @@ impl ContextStore {
/// Lance V1 blob encoding (out-of-line binary buffers). For `text_payload`,
/// this also changes the Arrow type from `LargeUtf8` to `LargeBinary`.
pub fn schema(blob_columns: &HashSet<String>) -> Schema {
Self::schema_with_options(blob_columns, true, true, true, true)
Self::schema_with_embedding_dim(blob_columns, DEFAULT_EMBEDDING_DIM)
}

/// Lance schema for a context store using a caller-selected embedding width.
pub fn schema_with_embedding_dim(blob_columns: &HashSet<String>, embedding_dim: i32) -> Schema {
Self::schema_with_options(blob_columns, true, true, true, true, embedding_dim)
}

fn schema_with_options(
Expand All @@ -1022,6 +1060,7 @@ impl ContextStore {
include_metadata: bool,
include_relationships: bool,
include_lifecycle: bool,
embedding_dim: i32,
) -> Schema {
let mut id_metadata = HashMap::new();
id_metadata.insert(
Expand Down Expand Up @@ -1110,7 +1149,7 @@ impl ContextStore {
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
DEFAULT_EMBEDDING_DIM,
embedding_dim,
),
true,
),
Expand All @@ -1137,8 +1176,9 @@ impl ContextStore {
uri: &str,
storage_options: Option<HashMap<String, String>>,
blob_columns: &HashSet<String>,
embedding_dim: i32,
) -> LanceResult<Dataset> {
let schema = Arc::new(Self::schema(blob_columns));
let schema = Arc::new(Self::schema_with_embedding_dim(blob_columns, embedding_dim));
let empty_batch = RecordBatch::new_empty(schema.clone());
let batches = RecordBatchIterator::new(
vec![Ok::<RecordBatch, ArrowError>(empty_batch)].into_iter(),
Expand Down Expand Up @@ -1259,7 +1299,7 @@ impl ContextStore {
);

let mut embedding_builder =
FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM);
FixedSizeListBuilder::new(Float32Builder::new(), self.embedding_dim);

for entry in entries {
id_builder.append_value(&entry.id);
Expand Down Expand Up @@ -1360,11 +1400,11 @@ impl ContextStore {
}

if let Some(embedding) = &entry.embedding {
if embedding.len() != DEFAULT_EMBEDDING_DIM as usize {
if embedding.len() != self.embedding_dim as usize {
return Err(ArrowError::InvalidArgumentError(format!(
"embedding length {} does not match expected dimension {}",
embedding.len(),
DEFAULT_EMBEDDING_DIM
self.embedding_dim
))
.into());
}
Expand All @@ -1378,7 +1418,7 @@ impl ContextStore {
} else {
// FixedSizeListBuilder requires padding values for null slots.
let values_builder = embedding_builder.values();
for _ in 0..DEFAULT_EMBEDDING_DIM {
for _ in 0..self.embedding_dim {
values_builder.append_null();
}
embedding_builder.append(false);
Expand Down Expand Up @@ -1844,12 +1884,21 @@ fn l2_distance(left: &[f32], right: &[f32]) -> f32 {
.sqrt()
}

fn validate_query_dimension(query: &[f32]) -> LanceResult<()> {
if query.len() != DEFAULT_EMBEDDING_DIM as usize {
fn validate_embedding_dim(embedding_dim: i32) -> LanceResult<()> {
if embedding_dim <= 0 {
return Err(LanceError::from(ArrowError::InvalidArgumentError(format!(
"embedding_dim must be positive, got {embedding_dim}"
))));
}
Ok(())
}

fn validate_query_dimension(query: &[f32], embedding_dim: i32) -> LanceResult<()> {
if query.len() != embedding_dim as usize {
return Err(ArrowError::InvalidArgumentError(format!(
"query length {} does not match embedding dimension {}",
query.len(),
DEFAULT_EMBEDDING_DIM
embedding_dim
))
.into());
}
Expand Down Expand Up @@ -1965,6 +2014,24 @@ fn compare_optional_score(left: Option<f32>, right: Option<f32>) -> Ordering {
}
}

fn embedding_dim_from_schema(schema: &Schema) -> LanceResult<i32> {
let field = schema
.field_with_name("embedding")
.map_err(LanceError::from)?;
let DataType::FixedSizeList(item_field, embedding_dim) = field.data_type() else {
return Err(LanceError::from(ArrowError::InvalidArgumentError(
"embedding column must be a FixedSizeList<Float32>".to_string(),
)));
};
if item_field.data_type() != &DataType::Float32 {
return Err(LanceError::from(ArrowError::InvalidArgumentError(
"embedding column must contain Float32 values".to_string(),
)));
}
validate_embedding_dim(*embedding_dim)?;
Ok(*embedding_dim)
}

fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A>
where
A: Array + 'static,
Expand Down Expand Up @@ -1997,14 +2064,18 @@ mod tests {
use chrono::{Duration as ChronoDuration, Utc};
use tempfile::TempDir;

fn make_embedding(pivot: f32) -> Vec<f32> {
let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize];
fn make_embedding_with_dim(dim: usize, pivot: f32) -> Vec<f32> {
let mut values = vec![0.0; dim];
if !values.is_empty() {
values[0] = pivot;
}
values
}

fn make_embedding(pivot: f32) -> Vec<f32> {
make_embedding_with_dim(DEFAULT_EMBEDDING_DIM as usize, pivot)
}

fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord {
ContextRecord {
id: id.to_string(),
Expand Down Expand Up @@ -2114,6 +2185,97 @@ mod tests {
});
}

#[test]
fn custom_embedding_dimension_round_trips_add_search_and_reopen() {
let dir = TempDir::new().unwrap();
let uri = dir.path().to_string_lossy().to_string();
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let options = ContextStoreOptions {
embedding_dim: Some(3),
..Default::default()
};
let mut store = ContextStore::open_with_options(&uri, options)
.await
.unwrap();
assert_eq!(store.embedding_dim(), 3);

let mut first = text_record("custom-a", 0.0);
first.embedding = Some(make_embedding_with_dim(3, 0.0));
let mut second = text_record("custom-b", 0.0);
second.embedding = Some(make_embedding_with_dim(3, 1.0));
store.add(&[first.clone(), second.clone()]).await.unwrap();

let query = make_embedding_with_dim(3, 1.0);
let results = store.search(&query, Some(2)).await.unwrap();
assert_eq!(results[0].record.id, second.id);

let reopened = ContextStore::open(&uri).await.unwrap();
assert_eq!(reopened.embedding_dim(), 3);
let results = reopened.search(&query, Some(1)).await.unwrap();
assert_eq!(results[0].record.id, second.id);

let err = reopened
.search(&make_embedding(1.0), None)
.await
.unwrap_err();
assert!(
err.to_string().contains("embedding dimension 3"),
"unexpected error message: {err}"
);
});
}

#[test]
fn existing_default_dimension_dataset_opens_without_options() {
let dir = TempDir::new().unwrap();
let uri = dir.path().to_string_lossy().to_string();
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let mut store = ContextStore::open(&uri).await.unwrap();
assert_eq!(store.embedding_dim(), DEFAULT_EMBEDDING_DIM);
store.add(&[text_record("default-dim", 0.0)]).await.unwrap();
drop(store);

let reopened = ContextStore::open(&uri).await.unwrap();
assert_eq!(reopened.embedding_dim(), DEFAULT_EMBEDDING_DIM);
reopened
.search(&make_embedding(0.0), Some(1))
.await
.unwrap();
});
}

#[test]
fn opening_existing_dataset_rejects_mismatched_requested_dimension() {
let dir = TempDir::new().unwrap();
let uri = dir.path().to_string_lossy().to_string();
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(async {
let options = ContextStoreOptions {
embedding_dim: Some(3),
..Default::default()
};
ContextStore::open_with_options(&uri, options)
.await
.unwrap();

let mismatched = ContextStoreOptions {
embedding_dim: Some(4),
..Default::default()
};
let err = match ContextStore::open_with_options(&uri, mismatched).await {
Ok(_) => panic!("expected mismatched embedding dimension to fail"),
Err(err) => err,
};
assert!(
err.to_string()
.contains("does not match requested dimension 4"),
"unexpected error message: {err}"
);
});
}

#[test]
fn list_hides_expired_and_retired_records_by_default() {
let dir = TempDir::new().unwrap();
Expand Down Expand Up @@ -2311,6 +2473,7 @@ mod tests {
true,
false,
true,
DEFAULT_EMBEDDING_DIM,
));
let empty_batch = RecordBatch::new_empty(schema.clone());
let batches = RecordBatchIterator::new(
Expand Down
1 change: 1 addition & 0 deletions crates/lance-context-server/src/routes/contexts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub async fn create_context(
let uri = state.context_uri(&req.name);
let options = ContextStoreOptions {
storage_options: req.storage_options,
embedding_dim: req.embedding_dim,
blob_columns,
id_index_type,
..Default::default()
Expand Down
Loading
Loading