From c32d2c764bd23fffb9569974c3f901fa62198056 Mon Sep 17 00:00:00 2001 From: dcfocus Date: Fri, 12 Jun 2026 02:45:28 +0000 Subject: [PATCH] feat: make embedding dimension configurable --- crates/lance-context-api/src/lib.rs | 2 + crates/lance-context-core/src/store.rs | 197 ++++++++++++++++-- .../src/routes/contexts.rs | 1 + python/python/lance_context/api.py | 11 +- python/src/lib.rs | 5 +- python/tests/test_persistence.py | 26 +++ 6 files changed, 223 insertions(+), 19 deletions(-) diff --git a/crates/lance-context-api/src/lib.rs b/crates/lance-context-api/src/lib.rs index 89ae788..2092bee 100644 --- a/crates/lance-context-api/src/lib.rs +++ b/crates/lance-context-api/src/lib.rs @@ -79,6 +79,8 @@ pub struct CreateContextRequest { pub id_index_type: Option, #[serde(default)] pub blob_columns: Option>, + #[serde(default)] + pub embedding_dim: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/lance-context-core/src/store.rs b/crates/lance-context-core/src/store.rs index 527d8ba..c46902c 100644 --- a/crates/lance-context-core/src/store.rs +++ b/crates/lance-context-core/src/store.rs @@ -134,6 +134,7 @@ pub struct ContextStore { pub compaction_config: CompactionConfig, blob_columns: HashSet, id_index_type: IdIndexType, + embedding_dim: i32, } /// Additional configuration when opening a [`ContextStore`]. @@ -141,6 +142,9 @@ pub struct ContextStore { pub struct ContextStoreOptions { pub storage_options: Option>, pub compaction: CompactionConfig, + /// Width of the fixed-size embedding vector for newly-created datasets. + /// Existing datasets always use the dimension persisted in their schema. + pub embedding_dim: Option, /// Column names that should use Lance V1 blob encoding. /// Valid values: `"text_payload"`, `"binary_payload"`. pub blob_columns: HashSet, @@ -212,15 +216,37 @@ impl ContextStore { } } + let requested_embedding_dim = match options.embedding_dim { + Some(dim) => { + validate_embedding_dim(dim)?; + dim + } + None => DEFAULT_EMBEDDING_DIM, + }; let storage_options = options.storage_options(); let blob_columns = options.blob_columns.clone(); - let dataset = match Self::load_with_options(uri, storage_options.clone()).await { - Ok(dataset) => dataset, + let (dataset, created) = match Self::load_with_options(uri, storage_options.clone()).await { + Ok(dataset) => (dataset, false), Err(LanceError::DatasetNotFound { .. }) => { - Self::create_with_options(uri, storage_options, &blob_columns).await? + let dataset = Self::create_with_options( + uri, + storage_options, + &blob_columns, + requested_embedding_dim, + ) + .await?; + (dataset, true) } Err(err) => return Err(err), }; + let arrow_schema: Schema = dataset.schema().into(); + let embedding_dim = embedding_dim_from_schema(&arrow_schema)?; + if !created && options.embedding_dim.is_some() && embedding_dim != requested_embedding_dim { + return Err(LanceError::from(ArrowError::InvalidArgumentError(format!( + "existing context embedding dimension {} does not match requested dimension {}", + embedding_dim, requested_embedding_dim + )))); + } let mut store = Self { dataset, @@ -234,6 +260,7 @@ impl ContextStore { compaction_config: options.compaction, blob_columns, id_index_type: options.id_index_type, + embedding_dim, }; // Ensure id index if configured @@ -245,6 +272,12 @@ impl ContextStore { Ok(store) } + /// Embedding vector width persisted in this context dataset schema. + #[must_use] + pub fn embedding_dim(&self) -> i32 { + self.embedding_dim + } + /// Append context records to the store and return the new dataset version. pub async fn add(&mut self, entries: &[ContextRecord]) -> LanceResult { if entries.is_empty() { @@ -648,7 +681,7 @@ impl ContextStore { filters: Option<&RecordFilters>, options: LifecycleQueryOptions, ) -> LanceResult> { - validate_query_dimension(query)?; + validate_query_dimension(query, self.embedding_dim)?; let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT); if top_k == 0 { @@ -689,7 +722,7 @@ impl ContextStore { } if let Some(query) = vector { - validate_query_dimension(query)?; + validate_query_dimension(query, self.embedding_dim)?; } let top_k = limit.unwrap_or(DEFAULT_SEARCH_LIMIT); @@ -1013,7 +1046,12 @@ impl ContextStore { /// Lance V1 blob encoding (out-of-line binary buffers). For `text_payload`, /// this also changes the Arrow type from `LargeUtf8` to `LargeBinary`. pub fn schema(blob_columns: &HashSet) -> Schema { - Self::schema_with_options(blob_columns, true, true, true, true) + Self::schema_with_embedding_dim(blob_columns, DEFAULT_EMBEDDING_DIM) + } + + /// Lance schema for a context store using a caller-selected embedding width. + pub fn schema_with_embedding_dim(blob_columns: &HashSet, embedding_dim: i32) -> Schema { + Self::schema_with_options(blob_columns, true, true, true, true, embedding_dim) } fn schema_with_options( @@ -1022,6 +1060,7 @@ impl ContextStore { include_metadata: bool, include_relationships: bool, include_lifecycle: bool, + embedding_dim: i32, ) -> Schema { let mut id_metadata = HashMap::new(); id_metadata.insert( @@ -1110,7 +1149,7 @@ impl ContextStore { "embedding", DataType::FixedSizeList( Arc::new(Field::new("item", DataType::Float32, true)), - DEFAULT_EMBEDDING_DIM, + embedding_dim, ), true, ), @@ -1137,8 +1176,9 @@ impl ContextStore { uri: &str, storage_options: Option>, blob_columns: &HashSet, + embedding_dim: i32, ) -> LanceResult { - let schema = Arc::new(Self::schema(blob_columns)); + let schema = Arc::new(Self::schema_with_embedding_dim(blob_columns, embedding_dim)); let empty_batch = RecordBatch::new_empty(schema.clone()); let batches = RecordBatchIterator::new( vec![Ok::(empty_batch)].into_iter(), @@ -1259,7 +1299,7 @@ impl ContextStore { ); let mut embedding_builder = - FixedSizeListBuilder::new(Float32Builder::new(), DEFAULT_EMBEDDING_DIM); + FixedSizeListBuilder::new(Float32Builder::new(), self.embedding_dim); for entry in entries { id_builder.append_value(&entry.id); @@ -1360,11 +1400,11 @@ impl ContextStore { } if let Some(embedding) = &entry.embedding { - if embedding.len() != DEFAULT_EMBEDDING_DIM as usize { + if embedding.len() != self.embedding_dim as usize { return Err(ArrowError::InvalidArgumentError(format!( "embedding length {} does not match expected dimension {}", embedding.len(), - DEFAULT_EMBEDDING_DIM + self.embedding_dim )) .into()); } @@ -1378,7 +1418,7 @@ impl ContextStore { } else { // FixedSizeListBuilder requires padding values for null slots. let values_builder = embedding_builder.values(); - for _ in 0..DEFAULT_EMBEDDING_DIM { + for _ in 0..self.embedding_dim { values_builder.append_null(); } embedding_builder.append(false); @@ -1844,12 +1884,21 @@ fn l2_distance(left: &[f32], right: &[f32]) -> f32 { .sqrt() } -fn validate_query_dimension(query: &[f32]) -> LanceResult<()> { - if query.len() != DEFAULT_EMBEDDING_DIM as usize { +fn validate_embedding_dim(embedding_dim: i32) -> LanceResult<()> { + if embedding_dim <= 0 { + return Err(LanceError::from(ArrowError::InvalidArgumentError(format!( + "embedding_dim must be positive, got {embedding_dim}" + )))); + } + Ok(()) +} + +fn validate_query_dimension(query: &[f32], embedding_dim: i32) -> LanceResult<()> { + if query.len() != embedding_dim as usize { return Err(ArrowError::InvalidArgumentError(format!( "query length {} does not match embedding dimension {}", query.len(), - DEFAULT_EMBEDDING_DIM + embedding_dim )) .into()); } @@ -1965,6 +2014,24 @@ fn compare_optional_score(left: Option, right: Option) -> Ordering { } } +fn embedding_dim_from_schema(schema: &Schema) -> LanceResult { + let field = schema + .field_with_name("embedding") + .map_err(LanceError::from)?; + let DataType::FixedSizeList(item_field, embedding_dim) = field.data_type() else { + return Err(LanceError::from(ArrowError::InvalidArgumentError( + "embedding column must be a FixedSizeList".to_string(), + ))); + }; + if item_field.data_type() != &DataType::Float32 { + return Err(LanceError::from(ArrowError::InvalidArgumentError( + "embedding column must contain Float32 values".to_string(), + ))); + } + validate_embedding_dim(*embedding_dim)?; + Ok(*embedding_dim) +} + fn column_as<'a, A>(batch: &'a RecordBatch, name: &str) -> LanceResult<&'a A> where A: Array + 'static, @@ -1997,14 +2064,18 @@ mod tests { use chrono::{Duration as ChronoDuration, Utc}; use tempfile::TempDir; - fn make_embedding(pivot: f32) -> Vec { - let mut values = vec![0.0; DEFAULT_EMBEDDING_DIM as usize]; + fn make_embedding_with_dim(dim: usize, pivot: f32) -> Vec { + let mut values = vec![0.0; dim]; if !values.is_empty() { values[0] = pivot; } values } + fn make_embedding(pivot: f32) -> Vec { + make_embedding_with_dim(DEFAULT_EMBEDDING_DIM as usize, pivot) + } + fn text_record(id: &str, embedding_pivot: f32) -> ContextRecord { ContextRecord { id: id.to_string(), @@ -2114,6 +2185,97 @@ mod tests { }); } + #[test] + fn custom_embedding_dimension_round_trips_add_search_and_reopen() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let options = ContextStoreOptions { + embedding_dim: Some(3), + ..Default::default() + }; + let mut store = ContextStore::open_with_options(&uri, options) + .await + .unwrap(); + assert_eq!(store.embedding_dim(), 3); + + let mut first = text_record("custom-a", 0.0); + first.embedding = Some(make_embedding_with_dim(3, 0.0)); + let mut second = text_record("custom-b", 0.0); + second.embedding = Some(make_embedding_with_dim(3, 1.0)); + store.add(&[first.clone(), second.clone()]).await.unwrap(); + + let query = make_embedding_with_dim(3, 1.0); + let results = store.search(&query, Some(2)).await.unwrap(); + assert_eq!(results[0].record.id, second.id); + + let reopened = ContextStore::open(&uri).await.unwrap(); + assert_eq!(reopened.embedding_dim(), 3); + let results = reopened.search(&query, Some(1)).await.unwrap(); + assert_eq!(results[0].record.id, second.id); + + let err = reopened + .search(&make_embedding(1.0), None) + .await + .unwrap_err(); + assert!( + err.to_string().contains("embedding dimension 3"), + "unexpected error message: {err}" + ); + }); + } + + #[test] + fn existing_default_dimension_dataset_opens_without_options() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let mut store = ContextStore::open(&uri).await.unwrap(); + assert_eq!(store.embedding_dim(), DEFAULT_EMBEDDING_DIM); + store.add(&[text_record("default-dim", 0.0)]).await.unwrap(); + drop(store); + + let reopened = ContextStore::open(&uri).await.unwrap(); + assert_eq!(reopened.embedding_dim(), DEFAULT_EMBEDDING_DIM); + reopened + .search(&make_embedding(0.0), Some(1)) + .await + .unwrap(); + }); + } + + #[test] + fn opening_existing_dataset_rejects_mismatched_requested_dimension() { + let dir = TempDir::new().unwrap(); + let uri = dir.path().to_string_lossy().to_string(); + let runtime = tokio::runtime::Runtime::new().unwrap(); + runtime.block_on(async { + let options = ContextStoreOptions { + embedding_dim: Some(3), + ..Default::default() + }; + ContextStore::open_with_options(&uri, options) + .await + .unwrap(); + + let mismatched = ContextStoreOptions { + embedding_dim: Some(4), + ..Default::default() + }; + let err = match ContextStore::open_with_options(&uri, mismatched).await { + Ok(_) => panic!("expected mismatched embedding dimension to fail"), + Err(err) => err, + }; + assert!( + err.to_string() + .contains("does not match requested dimension 4"), + "unexpected error message: {err}" + ); + }); + } + #[test] fn list_hides_expired_and_retired_records_by_default() { let dir = TempDir::new().unwrap(); @@ -2311,6 +2473,7 @@ mod tests { true, false, true, + DEFAULT_EMBEDDING_DIM, )); let empty_batch = RecordBatch::new_empty(schema.clone()); let batches = RecordBatchIterator::new( diff --git a/crates/lance-context-server/src/routes/contexts.rs b/crates/lance-context-server/src/routes/contexts.rs index bc42fe1..f15791e 100644 --- a/crates/lance-context-server/src/routes/contexts.rs +++ b/crates/lance-context-server/src/routes/contexts.rs @@ -40,6 +40,7 @@ pub async fn create_context( let uri = state.context_uri(&req.name); let options = ContextStoreOptions { storage_options: req.storage_options, + embedding_dim: req.embedding_dim, blob_columns, id_index_type, ..Default::default() diff --git a/python/python/lance_context/api.py b/python/python/lance_context/api.py index 0995eef..5d3fe4d 100644 --- a/python/python/lance_context/api.py +++ b/python/python/lance_context/api.py @@ -311,6 +311,7 @@ def __init__( compaction_target_rows: int = 1_000_000, quiet_hours: list[tuple[int, int]] | None = None, id_index_type: str | None = None, + embedding_dim: int | None = None, ) -> None: options = _merge_storage_options( storage_options, @@ -330,12 +331,18 @@ def __init__( "quiet_hours": quiet_hours or [], } - if options or compaction_config["enabled"] or id_index_type: + if ( + options + or compaction_config["enabled"] + or id_index_type + or embedding_dim is not None + ): self._inner = _Context.create( uri, storage_options=options or None, compaction_config=compaction_config, id_index_type=id_index_type, + embedding_dim=embedding_dim, ) else: self._inner = _Context.create(uri) @@ -358,6 +365,7 @@ def create( compaction_target_rows: int = 1_000_000, quiet_hours: list[tuple[int, int]] | None = None, id_index_type: str | None = None, + embedding_dim: int | None = None, ) -> Context: return cls( uri, @@ -374,6 +382,7 @@ def create( compaction_target_rows=compaction_target_rows, quiet_hours=quiet_hours, id_index_type=id_index_type, + embedding_dim=embedding_dim, ) def uri(self) -> str: diff --git a/python/src/lib.rs b/python/src/lib.rs index 134e2d8..8b6f0e5 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -167,7 +167,8 @@ fn filters_from_json(filters_json: Option) -> PyResult, py: Python<'_>, @@ -176,6 +177,7 @@ impl Context { compaction_config: Option<&Bound<'_, PyDict>>, blob_columns: Option>, id_index_type: Option, + embedding_dim: Option, ) -> PyResult { let runtime = Arc::new(Runtime::new().map_err(to_py_err)?); @@ -196,6 +198,7 @@ impl Context { let options = ContextStoreOptions { storage_options: storage_options_from_dict(storage_options)?, compaction: compaction_config_from_dict(compaction_config)?, + embedding_dim, blob_columns: blob_set, id_index_type: id_idx, }; diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index 988c1bf..badcdc1 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -31,6 +31,12 @@ def _embedding(pivot: float) -> list[float]: return values +def _embedding_with_dim(dim: int, pivot: float) -> list[float]: + values = [0.0] * dim + values[0] = pivot + return values + + def _free_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("127.0.0.1", 0)) @@ -349,6 +355,26 @@ def test_retrieve_supports_text_only(tmp_path: Path) -> None: assert hits[0]["text_score"] == 1.0 +def test_custom_embedding_dimension_round_trips(tmp_path: Path) -> None: + uri = tmp_path / "context.lance" + ctx = Context.create(str(uri), embedding_dim=3) + near = _embedding_with_dim(3, 0.0) + far = _embedding_with_dim(3, 1.0) + + ctx.add("assistant", "small vector near", embedding=near) + ctx.add("assistant", "small vector far", embedding=far) + + hits = ctx.search(far, limit=1) + assert hits[0]["text"] == "small vector far" + + reopened = Context.create(str(uri)) + hits = reopened.search(far, limit=1) + assert hits[0]["text"] == "small vector far" + + with pytest.raises(RuntimeError, match="embedding dimension 3"): + reopened.search(_embedding(1.0), limit=1) + + def test_lifecycle_fields_round_trip_and_default_filtering(tmp_path: Path) -> None: uri = tmp_path / "context.lance" ctx = Context.create(str(uri))