From 6349dc5b0c50db6984d7fd5e322d3bf211f58a94 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 14:18:04 -0700 Subject: [PATCH 1/3] feat: accept TableProvider write inputs for merge_insert and insert Make `Arc<dyn TableProvider>` the canonical internal write input behind ergonomic wrappers. Re-readable sources are now replayed across retries without spilling to disk, and materialized sources report statistics that let DataFusion choose the merge-join build side. - merge_insert: add `execute_provider` (canonical) and `execute_batches` (multi-partition `MemTable`); `execute(stream)` becomes a wrapper. Retries re-scan the provider instead of the removed `new_source_iter`/`SpillStreamIter` replay layer. A one-shot stream spills only when retries are enabled; `spill_for_retry(false)` fails fast instead of buffering. - Plan against the provider directly so a MemTable/file source's statistics reach the join; stream sources keep the one-shot path (no stats lost, and the source's original error type is preserved). - InsertBuilder gains `execute_provider`/`execute_uncommitted_provider`. - Python routes materialized inputs (pa.Table, RecordBatch, DataFrame, ...) through the in-memory path; streams and scanners keep spilling. Re-scannable Python providers (pa.dataset.Dataset/Scanner) and parallel data-file writes over provider partitions remain follow-ups.
Issue: #4583 Co-Authored-By: Claude Opus 4.8 (1M context) --- python/python/lance/dataset.py | 11 +- python/python/lance/lance/__init__.pyi | 7 + python/python/lance/types.py | 28 + python/python/tests/test_dataset.py | 23 + python/src/dataset.rs | 54 ++ rust/lance-datafusion/src/exec.rs | 31 +- rust/lance-datafusion/src/spill.rs | 103 +++- rust/lance/src/dataset/write.rs | 109 +--- rust/lance/src/dataset/write/insert.rs | 47 ++ rust/lance/src/dataset/write/merge_insert.rs | 592 +++++++++++++++++-- 10 files changed, 856 insertions(+), 149 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 45dc1b253d3..f361122bb4f 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -70,7 +70,7 @@ from .lance import __version__ as __version__ from .lance import _Session as Session from .query import FullTextQuery -from .types import _coerce_reader +from .types import _coerce_reader, _is_materialized from .udf import BatchUDF, normalize_transform from .udf import BatchUDFCheckpoint as BatchUDFCheckpoint from .udf import batch_udf as batch_udf @@ -392,6 +392,12 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): """ reader = _coerce_reader(data_obj, schema) + # Materialized sources are wrapped in an in-memory table so retries never + # spill and the source's statistics can drive the join; everything else is + # treated as a one-shot stream. 
+ if _is_materialized(data_obj): + return super(MergeInsertBuilder, self).execute_batches(reader) + return super(MergeInsertBuilder, self).execute(reader) def execute_uncommitted( @@ -416,6 +422,9 @@ def execute_uncommitted( """ reader = _coerce_reader(data_obj, schema) + if _is_materialized(data_obj): + return super(MergeInsertBuilder, self).execute_uncommitted_batches(reader) + return super(MergeInsertBuilder, self).execute_uncommitted(reader) # These next three overrides exist only to document the methods diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 74db076db41..323e78879cb 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -471,6 +471,13 @@ class _MergeInsertBuilder: def when_not_matched_insert_all(self) -> Self: ... def when_not_matched_by_source_delete(self, expr: Optional[str] = None) -> Self: ... def execute(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ... + def execute_batches(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ... + def execute_uncommitted( + self, new_data: pa.RecordBatchReader + ) -> tuple[Transaction, ExecuteResult]: ... + def execute_uncommitted_batches( + self, new_data: pa.RecordBatchReader + ) -> tuple[Transaction, ExecuteResult]: ... class _Scanner: @property diff --git a/python/python/lance/types.py b/python/python/lance/types.py index 41cc191e4d6..0ed2ae9f95e 100644 --- a/python/python/lance/types.py +++ b/python/python/lance/types.py @@ -52,6 +52,34 @@ def _casting_recordbatch_iter( yield batch +def _is_materialized(data_obj: ReaderLike) -> bool: + """Whether ``data_obj`` is fully materialized in memory. + + Materialized sources (tables, in-memory frames) can be wrapped in an + in-memory table for replay without spilling and to expose exact statistics. + Streaming or re-readable sources (readers, scanners, datasets, generators) + are not considered materialized. 
+ """ + if _check_for_pandas(data_obj) and isinstance(data_obj, pd.DataFrame): + return True + if isinstance(data_obj, (pa.Table, pa.RecordBatch)): + return True + if ( + type(data_obj).__module__.startswith("polars") + and data_obj.__class__.__name__ == "DataFrame" + ): + return True + if isinstance(data_obj, dict): + return True + if ( + isinstance(data_obj, list) + and len(data_obj) > 0 + and isinstance(data_obj[0], dict) + ): + return True + return False + + def _coerce_reader( data_obj: ReaderLike, schema: Optional[pa.Schema] = None ) -> pa.RecordBatchReader: diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 45866f3c4da..4c373ad0180 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -2342,6 +2342,29 @@ def test_merge_insert(tmp_path: Path): check_merge_stats(merge_dict, (None, None, None)) +@pytest.mark.parametrize("materialized", [True, False]) +def test_merge_insert_input_kinds(tmp_path: Path, materialized: bool): + # A materialized pa.Table is routed through the in-memory (MemTable) path, + # while a RecordBatchReader is routed through the streaming path. Both must + # produce identical results. 
+ schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + base = pa.table({"id": range(5), "value": [0] * 5}, schema=schema) + new = pa.table({"id": [1, 2, 5, 6], "value": [10, 20, 50, 60]}, schema=schema) + + dataset = lance.write_dataset(base, tmp_path / "dataset", mode="create") + source = new if materialized else new.to_reader() + + dataset.merge_insert( + "id" + ).when_matched_update_all().when_not_matched_insert_all().execute(source) + + result = dataset.to_table().sort_by("id").to_pydict() + assert result == { + "id": [0, 1, 2, 3, 4, 5, 6], + "value": [0, 10, 20, 0, 0, 50, 60], + } + + def test_merge_insert_subcols(tmp_path: Path): initial_data = pa.table( { diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8bfa81aeae4..1d75d9bd128 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -434,6 +434,60 @@ impl MergeInsertBuilder { Ok((PyLance(transaction), stats)) } + /// Execute the merge insert from fully-materialized data. + /// + /// The data is read into memory and wrapped in an in-memory table, so retries + /// never spill to disk and the source's statistics drive the join. Callers + /// should only route in-memory inputs (e.g. a `pa.Table`) here. + pub fn execute_batches(&mut self, new_data: &Bound) -> PyResult> { + let py = new_data.py(); + let reader = convert_reader(new_data)?; + let batches = reader + .collect::, _>>() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let job = self + .builder + .try_build() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let (new_dataset, stats) = rt() + .spawn(Some(py), job.execute_batches(batches))? + .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; + + let dataset = self.dataset.bind(py); + dataset.borrow_mut().ds = new_dataset; + + Ok(Self::build_stats(&stats, py)?.into()) + } + + /// [`Self::execute_batches`] without committing; returns the transaction. 
+ pub fn execute_uncommitted_batches<'a>( + &mut self, + new_data: &Bound<'a, PyAny>, + ) -> PyResult<(PyLance, Bound<'a, PyDict>)> { + let py = new_data.py(); + let reader = convert_reader(new_data)?; + let batches = reader + .collect::, _>>() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let job = self + .builder + .try_build() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let UncommittedMergeInsert { + transaction, stats, .. + } = rt() + .spawn(Some(py), job.execute_uncommitted_batches(batches))? + .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; + + let stats = Self::build_stats(&stats, py)?; + + Ok((PyLance(transaction), stats)) + } + #[pyo3(signature=(schema = None, verbose = false))] pub fn explain_plan( &mut self, diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index 8f346f45612..e08bd00f6f0 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -16,7 +16,7 @@ use arrow_array::RecordBatch; use arrow_schema::Schema as ArrowSchema; use datafusion::physical_plan::metrics::MetricType; use datafusion::{ - catalog::streaming::StreamingTable, + catalog::{TableProvider, streaming::StreamingTable}, dataframe::DataFrame, execution::{ TaskContext, @@ -39,7 +39,7 @@ use datafusion::{ use datafusion_common::{DataFusionError, Statistics}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use futures::{StreamExt, stream}; +use futures::{StreamExt, TryStreamExt, stream}; use lance_arrow::SchemaExt; use lance_core::{ Error, Result, @@ -867,6 +867,33 @@ impl SessionContextExt for SessionContext { } } +/// Scan a [`TableProvider`] into a single-partition [`SendableRecordBatchStream`]. +/// +/// Multi-partition providers are coalesced into a single partition. This adapts a +/// re-scannable provider back into the one stream the writer pipeline consumes; +/// re-scanning the same provider (e.g. on a write retry) yields a fresh stream. 
+/// +/// The first batch is read eagerly and re-chained onto the stream. This surfaces a +/// scan error from the source directly, before it can be fed into (and obscured by) +/// a downstream plan — preserving the original error type for callers. +pub async fn provider_to_stream( + provider: Arc, +) -> Result { + let ctx = SessionContext::new(); + let plan = provider.scan(&ctx.state(), None, &[], None).await?; + let plan: Arc = + if plan.properties().output_partitioning().partition_count() > 1 { + Arc::new(CoalescePartitionsExec::new(plan)) + } else { + plan + }; + let schema = plan.schema(); + let mut stream = plan.execute(0, ctx.task_ctx())?; + let first = stream.try_next().await?; + let rechained = stream::iter(first.map(Ok)).chain(stream); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, rechained))) +} + #[derive(Clone, Debug)] pub struct StrictBatchSizeExec { input: Arc, diff --git a/rust/lance-datafusion/src/spill.rs b/rust/lance-datafusion/src/spill.rs index 8fa60c93ab6..1ffbe9083c4 100644 --- a/rust/lance-datafusion/src/spill.rs +++ b/rust/lance-datafusion/src/spill.rs @@ -9,13 +9,17 @@ use std::{ use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; use arrow_array::RecordBatch; -use arrow_schema::{ArrowError, Schema}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; use datafusion::{ - execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter, + catalog::{TableProvider, streaming::StreamingTable}, + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{stream::RecordBatchStreamAdapter, streaming::PartitionStream}, }; use datafusion_common::DataFusionError; +use futures::StreamExt; use lance_arrow::memory::MemoryAccumulator; use lance_core::error::LanceOptionExt; +use lance_core::utils::tempfile::TempDir; /// Start a spill of Arrow data to a file that can be read later multiple times. 
/// @@ -60,6 +64,101 @@ pub fn create_replay_spill( (sender, receiver) } +/// Wrap a one-shot [`SendableRecordBatchStream`] in a re-scannable [`TableProvider`]. +/// +/// The source is drained in the background into a replayable spill: up to +/// `memory_limit` bytes are buffered in memory before spilling to a temporary file +/// on disk. Each scan of the returned provider replays the full source from the +/// spill, which is what makes a one-shot stream usable in the write retry loop. +/// +/// The provider reports no statistics — the source size is not known until it has +/// been fully drained — so callers that need source statistics (e.g. to drive join +/// ordering) should prefer a materialized or file-backed provider instead. +pub async fn spilling_table_provider( + mut source: SendableRecordBatchStream, + memory_limit: usize, +) -> Result, DataFusionError> { + let schema = source.schema(); + let tmp_dir = tokio::task::spawn_blocking(TempDir::try_new) + .await + .map_err(|e| DataFusionError::Execution(format!("Failed to spawn temp dir task: {e}")))? + .map_err(|e| DataFusionError::Execution(format!("Failed to create temp dir: {e}")))?; + let tmp_path = tmp_dir.std_path().join("spill.arrows"); + let (mut sender, receiver) = create_replay_spill(tmp_path, schema.clone(), memory_limit); + + // Drain the one-shot source into the spill once, in the background. The spill + // tees to memory/disk so the first reader can consume batches as they arrive + // while later readers replay the complete source. + let drain_handle = tokio::task::spawn(async move { + let mut errored = false; + while let Some(res) = source.next().await { + match res { + Ok(batch) => { + if let Err(e) = sender.write(batch).await { + sender.send_error(e); + errored = true; + break; + } + } + Err(e) => { + sender.send_error(e); + errored = true; + break; + } + } + } + // Only finish on a clean drain. 
Calling finish() after an error would + // overwrite the original (replayable) error with a generic one, losing + // the source error's type (e.g. an external error from user code). + if !errored && let Err(err) = sender.finish().await { + sender.send_error(err); + } + sender + }); + + let partition = Arc::new(SpillPartition { + schema: schema.clone(), + receiver, + _tmp_dir: Arc::new(tmp_dir), + _drain_handle: Arc::new(drain_handle), + }); + Ok(Arc::new(StreamingTable::try_new(schema, vec![partition])?)) +} + +/// A [`PartitionStream`] backed by a replayable spill. +/// +/// Each call to [`PartitionStream::execute`] opens a fresh stream over the spill, +/// so the partition can be scanned repeatedly. The spill file and the background +/// task draining the source are kept alive for as long as this partition exists. +struct SpillPartition { + schema: SchemaRef, + receiver: SpillReceiver, + // The spilled data lives in this temp dir; dropping it deletes the spill file. + _tmp_dir: Arc, + // Keeps the background drain task (which owns the `SpillSender`) alive. The + // `SpillSender` must outlive the readers or they error out, so we hold the + // handle rather than detaching it. 
+ _drain_handle: Arc>, +} + +impl std::fmt::Debug for SpillPartition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SpillPartition") + .field("schema", &self.schema) + .finish() + } +} + +impl PartitionStream for SpillPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + self.receiver.read() + } +} + #[derive(Clone)] pub struct SpillReceiver { status_receiver: tokio::sync::watch::Receiver, diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index ff0a119158c..9af133d2e9c 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -4,8 +4,7 @@ use arrow_array::RecordBatch; use chrono::TimeDelta; use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use lance_arrow::{ ARROW_EXT_NAME_KEY, BLOB_DEDICATED_SIZE_THRESHOLD_META_KEY, BLOB_INLINE_SIZE_THRESHOLD_META_KEY, BLOB_META_KEY, BLOB_V2_EXT_NAME, @@ -13,12 +12,9 @@ use lance_arrow::{ use lance_core::datatypes::{ NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions, }; -use lance_core::error::LanceOptionExt; -use lance_core::utils::tempfile::TempDir; use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DATA, TRACE_FILE_AUDIT}; use lance_core::{Error, Result, datatypes::Schema}; use lance_datafusion::chunker::{break_stream, chunk_stream}; -use lance_datafusion::spill::{SpillReceiver, SpillSender, create_replay_spill}; use lance_datafusion::utils::StreamingWriteSource; use lance_file::previous::writer::{ FileWriter as PreviousFileWriter, ManifestProvider as PreviousManifestProvider, @@ -1510,108 +1506,6 @@ async fn resolve_commit_handler( } } -/// Create an iterator of record batch streams from the given source. 
-/// -/// If `enable_retries` is true, then the source will be saved either in memory -/// or spilled to disk to allow replaying the source in case of a failure. The -/// source will be kept in memory if either (1) the size hint shows that -/// there is only one batch or (2) the stream contains less than 100MB of -/// data. Otherwise, the source will be spilled to a temporary file on disk. -/// -/// This is used to support retries on write operations. -async fn new_source_iter( - source: SendableRecordBatchStream, - enable_retries: bool, -) -> Result + Send + 'static>> { - if enable_retries { - let schema = source.schema(); - - // If size hint shows there is only one batch, spilling has no benefit, just keep that - // in memory. (This is a pretty common case.) - let size_hint = source.size_hint(); - if size_hint.0 == 1 && size_hint.1 == Some(1) { - let batches: Vec = source.try_collect().await?; - Ok(Box::new(std::iter::repeat_with(move || { - Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), - futures::stream::iter(batches.clone().into_iter().map(Ok)), - )) as SendableRecordBatchStream - }))) - } else { - // Allow buffering up to 100MB in memory before spilling to disk. - Ok(Box::new( - SpillStreamIter::try_new(source, 100 * 1024 * 1024).await?, - )) - } - } else { - Ok(Box::new(std::iter::once(source))) - } -} - -struct SpillStreamIter { - receiver: SpillReceiver, - _sender_handle: tokio::task::JoinHandle, - // This temp dir is used to store the spilled data. It is kept alive by - // this struct. When this struct is dropped, the Drop implementation of - // tempfile::TempDir will delete the temp dir. 
- _tmp_dir: TempDir, -} - -impl SpillStreamIter { - pub async fn try_new( - mut source: SendableRecordBatchStream, - memory_limit: usize, - ) -> Result { - let tmp_dir = tokio::task::spawn_blocking(|| { - TempDir::try_new() - .map_err(|e| Error::invalid_input(format!("Failed to create temp dir: {}", e))) - }) - .await - .ok() - .expect_ok()??; - - let tmp_path = tmp_dir.std_path().join("spill.arrows"); - let (mut sender, receiver) = create_replay_spill(tmp_path, source.schema(), memory_limit); - - let sender_handle = tokio::task::spawn(async move { - while let Some(res) = source.next().await { - match res { - Ok(batch) => match sender.write(batch).await { - Ok(_) => {} - Err(e) => { - sender.send_error(e); - break; - } - }, - Err(e) => { - sender.send_error(e); - break; - } - } - } - - if let Err(err) = sender.finish().await { - sender.send_error(err); - } - sender - }); - - Ok(Self { - receiver, - _tmp_dir: tmp_dir, - _sender_handle: sender_handle, - }) - } -} - -impl Iterator for SpillStreamIter { - type Item = SendableRecordBatchStream; - - fn next(&mut self) -> Option { - Some(self.receiver.read()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -3315,6 +3209,7 @@ mod tests { async fn test_write_interruption_recovery() { use super::commit::CommitBuilder; use arrow_array::record_batch; + use lance_core::utils::tempfile::TempDir; // Create a temporary directory for testing let temp_dir = TempDir::default(); diff --git a/rust/lance/src/dataset/write/insert.rs b/rust/lance/src/dataset/write/insert.rs index bfd702c9c3b..2b7d6869f18 100644 --- a/rust/lance/src/dataset/write/insert.rs +++ b/rust/lance/src/dataset/write/insert.rs @@ -5,11 +5,13 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_array::{RecordBatch, RecordBatchIterator}; +use datafusion::catalog::TableProvider; use datafusion::execution::SendableRecordBatchStream; use humantime::format_duration; use lance_core::datatypes::{NullabilityComparison, Schema, SchemaCompareOptions}; use 
lance_core::utils::tracing::{DATASET_WRITING_EVENT, TRACE_DATASET_EVENTS}; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; +use lance_datafusion::exec::provider_to_stream; use lance_datafusion::utils::StreamingWriteSource; use lance_file::version::LanceFileVersion; use lance_io::object_store::ObjectStore; @@ -93,6 +95,19 @@ impl<'a> InsertBuilder<'a> { self.execute_stream_impl(stream, schema).await } + /// Execute the insert operation with a [`TableProvider`] source. + /// + /// The provider is scanned into a stream and written. This mirrors + /// [`crate::dataset::MergeInsertJob::execute_provider`] so the same input + /// shapes are accepted across write operations. Inserts do not retry, so the + /// provider is scanned only once. + /// + /// [`TableProvider`]: datafusion::catalog::TableProvider + pub async fn execute_provider(&self, provider: Arc) -> Result { + let stream = provider_to_stream(provider).await?; + self.execute_stream(stream).await + } + async fn execute_stream_impl( &self, stream: SendableRecordBatchStream, @@ -184,6 +199,18 @@ impl<'a> InsertBuilder<'a> { Ok(transaction) } + /// Write data files from a [`TableProvider`] source without committing. + /// + /// Use [`CommitBuilder`] to commit the returned transaction. See + /// [`Self::execute_provider`]. 
+ pub async fn execute_uncommitted_provider( + &self, + provider: Arc, + ) -> Result { + let stream = provider_to_stream(provider).await?; + self.execute_uncommitted_stream(stream).await + } + async fn write_uncommitted_stream_impl( &self, stream: SendableRecordBatchStream, @@ -496,6 +523,26 @@ mod test { ); } + #[tokio::test] + async fn test_execute_provider() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let provider: Arc = Arc::new( + datafusion::datasource::MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(), + ); + + let dataset = InsertBuilder::new("memory://") + .execute_provider(provider) + .await + .unwrap(); + + assert_eq!(dataset.count_rows(None).await.unwrap(), 3); + } + #[tokio::test] async fn allow_overwrite_to_v2_2_without_blob_upgrade() { let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..27924495d22 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -73,6 +73,8 @@ use datafusion::common::NullEquality; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::error::DataFusionError; use datafusion::{ + catalog::{TableProvider, streaming::StreamingTable}, + datasource::MemTable, execution::{ context::{SessionConfig, SessionContext}, memory_pool::MemoryConsumer, @@ -110,9 +112,10 @@ use lance_datafusion::{ chunker::chunk_stream, dataframe::BatchStreamGrouper, exec::{ - HardCapBatchSizeExec, LanceExecutionOptions, OneShotExec, analyze_plan, execute_plan, - get_session_context, + HardCapBatchSizeExec, LanceExecutionOptions, OneShotExec, OneShotPartitionStream, + analyze_plan, execute_plan, get_session_context, provider_to_stream, }, + 
spill::spilling_table_provider, utils::{StreamingWriteSource, reader_to_stream}, }; use lance_file::version::LanceFileVersion; @@ -336,6 +339,12 @@ struct MergeInsertParams { // Controls whether data that is not matched by the source is deleted or not delete_not_matched_by_source: WhenNotMatchedBySource, conflict_retries: u32, + // When the source is a one-shot stream and `conflict_retries > 0`, the source + // is spilled (memory, then disk) so it can be replayed on each retry. Set to + // false to fail fast on contention instead of buffering the stream. Has no + // effect on re-scannable sources (materialized batches, files), which never + // spill. + spill_for_retry: bool, retry_timeout: Duration, // List of MemWAL region generations to mark as merged when this commit succeeds. merged_generations: Vec, @@ -465,6 +474,7 @@ impl MergeInsertBuilder { insert_not_matched: true, delete_not_matched_by_source: WhenNotMatchedBySource::Keep, conflict_retries: 10, + spill_for_retry: true, retry_timeout: Duration::from_secs(30), merged_generations: Vec::new(), skip_auto_cleanup: false, @@ -512,6 +522,27 @@ impl MergeInsertBuilder { self } + /// Controls whether a one-shot stream source is spilled so it can be replayed + /// across retries. + /// + /// When the source is a one-shot stream (e.g. [`MergeInsertJob::execute`]) and + /// `conflict_retries > 0`, the source is buffered in memory and spilled to disk + /// so each retry can re-read it. Set this to `false` to skip that buffering and + /// fail fast with a contention error instead of writing the stream to disk. + /// + /// This has no effect on re-scannable sources (materialized batches via + /// [`MergeInsertJob::execute_batches`], or a [`TableProvider`] via + /// [`MergeInsertJob::execute_provider`]), which are replayed directly and never + /// spill. + /// + /// Default is true. 
+ /// + /// [`TableProvider`]: datafusion::catalog::TableProvider + pub fn spill_for_retry(&mut self, spill: bool) -> &mut Self { + self.params.spill_for_retry = spill; + self + } + /// Set the timeout used to limit retries. /// /// This is the maximum time to spend on the operation before giving up. At @@ -591,6 +622,16 @@ enum SchemaComparison { Subschema, } +/// Wrap a one-shot stream in a non-replayable [`StreamingTable`] provider. +/// +/// The provider can only be scanned once (its single partition hands out the +/// underlying stream), so it must not be used where retries may re-scan it. +fn one_shot_provider(stream: SendableRecordBatchStream) -> Result> { + let schema = stream.schema(); + let partition = Arc::new(OneShotPartitionStream::new(stream)); + Ok(Arc::new(StreamingTable::try_new(schema, vec![partition])?)) +} + impl MergeInsertJob { pub async fn execute_reader( self, @@ -1403,24 +1444,142 @@ impl MergeInsertJob { )) } - /// Executes the merge insert job + /// Executes the merge insert job from a one-shot stream source. /// /// This will take in the source, merge it with the existing target data, and insert new - /// rows, update existing rows, and delete existing rows + /// rows, update existing rows, and delete existing rows. + /// + /// A stream can only be read once, so when `conflict_retries > 0` the stream is + /// spilled (in memory, then to disk) so it can be replayed on each retry. See + /// [`MergeInsertBuilder::spill_for_retry`] to fail fast instead, and + /// [`Self::execute_batches`] / [`Self::execute_provider`] for re-scannable + /// sources that never spill. 
pub async fn execute( self, source: SendableRecordBatchStream, ) -> Result<(Arc, MergeStats)> { - let source_iter = super::new_source_iter(source, self.params.conflict_retries > 0).await?; + let (provider, replayable) = self.stream_source_to_provider(source).await?; + // A stream-derived provider reports no statistics, so there is nothing to + // gain from planning against it directly; adapting it back to a stream also + // preserves the source's original error type. + self.execute_inner(provider, replayable, false).await + } + + /// Executes the merge insert job from a re-scannable [`TableProvider`]. + /// + /// This is the canonical entry point: [`Self::execute`] and + /// [`Self::execute_batches`] are thin wrappers that build a provider and call + /// this method. Because a provider can be scanned repeatedly, retries re-read + /// the source directly and never spill to disk. The provider's reported + /// statistics (e.g. from a [`MemTable`] or file source) also let DataFusion + /// optimize the merge join. + /// + /// [`MemTable`]: datafusion::datasource::MemTable + pub async fn execute_provider( + self, + provider: Arc, + ) -> Result<(Arc, MergeStats)> { + // A genuine TableProvider is re-scannable by contract, so retries are safe, + // and planning against it directly lets its statistics drive the join. + self.execute_inner(provider, true, true).await + } + + /// Executes the merge insert job from materialized record batches. + /// + /// The batches are wrapped in an in-memory [`MemTable`], which is re-scannable + /// (retries replay from memory, never spilling) and reports exact statistics to + /// the merge join. This is the preferred entry point when the full source is + /// already in memory. 
+ pub async fn execute_batches( + self, + batches: Vec, + ) -> Result<(Arc, MergeStats)> { + let provider = self.batches_to_provider(batches)?; + self.execute_inner(provider, true, true).await + } + + /// Like [`Self::execute_batches`] but returns the uncommitted transaction. + /// + /// Use [`CommitBuilder`] to commit the returned transaction. + pub async fn execute_uncommitted_batches( + self, + batches: Vec, + ) -> Result { + let provider = self.batches_to_provider(batches)?; + self.execute_uncommitted_impl(provider, true).await + } + + /// Wrap materialized batches in a multi-partition in-memory [`MemTable`]. + fn batches_to_provider(&self, batches: Vec) -> Result> { + let schema = batches + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| Arc::new(Schema::from(self.dataset.schema()))); + // Spread batches across partitions so the source can be scanned in parallel + // and reports per-partition statistics. A single inner Vec would be one + // partition with no parallelism. + let partitions = Self::batches_into_partitions(batches); + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) + } + + /// Distribute batches round-robin across up to `num_compute_intensive_cpus` + /// partitions, so a [`MemTable`] built from them can be scanned in parallel. + /// Always returns at least one (possibly empty) partition so an empty source + /// still produces a valid provider. + fn batches_into_partitions(batches: Vec) -> Vec> { + let num_partitions = batches.len().min(get_num_compute_intensive_cpus()).max(1); + let mut partitions = vec![Vec::new(); num_partitions]; + for (idx, batch) in batches.into_iter().enumerate() { + partitions[idx % num_partitions].push(batch); + } + partitions + } + + /// Wrap a one-shot stream source in a provider, returning whether it can be + /// replayed across retries. + /// + /// With retries enabled and spilling allowed, the stream is drained into a + /// replayable spill (memory up to 100MB, then disk). 
Otherwise the stream is + /// wrapped in a non-replayable one-shot provider and any conflict fails fast. + async fn stream_source_to_provider( + &self, + source: SendableRecordBatchStream, + ) -> Result<(Arc, bool)> { + if self.params.conflict_retries > 0 && self.params.spill_for_retry { + // Allow buffering up to 100MB in memory before spilling to disk. + let provider = spilling_table_provider(source, 100 * 1024 * 1024).await?; + Ok((provider, true)) + } else { + Ok((one_shot_provider(source)?, false)) + } + } + + /// Run the retry loop against a provider, re-scanning it on each attempt. + /// + /// `replayable` indicates whether the provider can be scanned more than once. + /// When it cannot (a one-shot stream that was not spilled), retries are + /// disabled so we never scan it twice; the operation runs once and surfaces any + /// commit conflict directly. + async fn execute_inner( + self, + provider: Arc, + replayable: bool, + scan_provider_directly: bool, + ) -> Result<(Arc, MergeStats)> { let dataset = self.dataset.clone(); let config = RetryConfig { - max_retries: self.params.conflict_retries, + max_retries: if replayable { + self.params.conflict_retries + } else { + 0 + }, retry_timeout: self.params.retry_timeout, }; - let wrapper = MergeInsertJobWithIterator { + let wrapper = MergeInsertJobWithProvider { job: self, - source_iter: Arc::new(Mutex::new(source_iter)), + provider, + scan_provider_directly, attempt_count: Arc::new(AtomicU32::new(0)), }; @@ -1435,7 +1594,8 @@ impl MergeInsertJob { source: impl StreamingWriteSource, ) -> Result { let stream = source.into_stream(); - self.execute_uncommitted_impl(stream).await + self.execute_uncommitted_impl(one_shot_provider(stream)?, false) + .await } fn create_plan_join_type(&self) -> JoinType { @@ -1455,14 +1615,22 @@ impl MergeInsertJob { async fn create_plan( self, - source: SendableRecordBatchStream, + provider: Arc, + scan_provider_directly: bool, ) -> Result> { // Goal: we shouldn't manually have to 
specify which columns to scan. // DataFusion's optimizer should be able to automatically perform // projection pushdown for us. // Goal: we shouldn't have to add new branches in this code to handle // indexed vs non-indexed cases. That should be handled by optimizer rules. - let session_config = SessionConfig::default(); + let session_config = if scan_provider_directly { + // A provider may expose multiple partitions; keep the plan single- + // partition so it satisfies the merge write node's contract (the + // provider's statistics still drive join-side selection). + SessionConfig::default().with_target_partitions(1) + } else { + SessionConfig::default() + }; let session_ctx = SessionContext::new_with_config(session_config); let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?; // Wrap column names in double quotes to preserve case (DataFusion lowercases unquoted identifiers) @@ -1473,7 +1641,15 @@ impl MergeInsertJob { .map(|name| format!("\"{}\"", name)) .collect::>(); let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::>(); - let source_df = session_ctx.read_one_shot(source)?; + // Plan against the provider directly so its statistics reach the optimizer; + // otherwise adapt it to a one-shot stream (which carries no statistics but + // preserves the source's original error type). + let source_df = if scan_provider_directly { + session_ctx.read_table(provider)? + } else { + let source = provider_to_stream(provider).await?; + session_ctx.read_one_shot(source)? + }; // Capture the source field names *before* aliasing / joining so we // can tell which dataset columns are missing from the source and // need to be filled from the target side of the join below. 
@@ -1552,14 +1728,15 @@ impl MergeInsertJob { async fn execute_uncommitted_v2( self, - source: SendableRecordBatchStream, + provider: Arc, + scan_provider_directly: bool, ) -> Result<( Transaction, MergeStats, Option, Option, )> { - let plan = self.create_plan(source).await?; + let plan = self.create_plan(provider, scan_provider_directly).await?; // Execute the plan // Assert that we have exactly one partition since we're designed for single-partition execution @@ -1734,14 +1911,16 @@ impl MergeInsertJob { async fn execute_uncommitted_impl( self, - source: SendableRecordBatchStream, + provider: Arc, + scan_provider_directly: bool, ) -> Result { // Check if we can use the fast path - let can_use_fast_path = self.can_use_create_plan(source.schema().as_ref()).await?; + let can_use_fast_path = self.can_use_create_plan(provider.schema().as_ref()).await?; if can_use_fast_path { - let (transaction, stats, affected_rows, inserted_rows_filter) = - self.execute_uncommitted_v2(source).await?; + let (transaction, stats, affected_rows, inserted_rows_filter) = self + .execute_uncommitted_v2(provider, scan_provider_directly) + .await?; return Ok(UncommittedMergeInsert { transaction, affected_rows, @@ -1750,6 +1929,8 @@ impl MergeInsertJob { }); } + // The slow path consumes a single stream; adapt the provider back into one. 
+ let source = provider_to_stream(provider).await?; let source_schema = source.schema(); let lance_schema = lance_core::datatypes::Schema::try_from(source_schema.as_ref())?; let full_schema = self.dataset.schema(); @@ -1995,7 +2176,9 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); - let plan = cloned_job.create_plan(Box::pin(stream)).await?; + let plan = cloned_job + .create_plan(one_shot_provider(Box::pin(stream))?, false) + .await?; let display = DisplayableExecutionPlan::new(plan.as_ref()); Ok(format!("{}", display.indent(verbose))) @@ -2028,7 +2211,9 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); - let plan = cloned_job.create_plan(source).await?; + let plan = cloned_job + .create_plan(one_shot_provider(source)?, false) + .await?; // Use the analyze_plan function from lance_datafusion, but strip out the wrapper lines let options = LanceExecutionOptions::default(); @@ -2078,15 +2263,18 @@ pub struct UncommittedMergeInsert { pub inserted_rows_filter: Option, } -/// Wrapper struct that combines MergeInsertJob with the source iterator for retry functionality +/// Wrapper struct that combines MergeInsertJob with the source provider for retry functionality #[derive(Clone)] -struct MergeInsertJobWithIterator { +struct MergeInsertJobWithProvider { job: MergeInsertJob, - source_iter: Arc + Send + 'static>>>, + provider: Arc, + // Whether to plan against the provider directly (using its statistics) or adapt + // it to a one-shot stream. See `MergeInsertJob::execute_inner`. 
+ scan_provider_directly: bool, attempt_count: Arc, } -impl RetryExecutor for MergeInsertJobWithIterator { +impl RetryExecutor for MergeInsertJobWithProvider { type Data = UncommittedMergeInsert; type Result = (Arc, MergeStats); @@ -2094,10 +2282,11 @@ impl RetryExecutor for MergeInsertJobWithIterator { // Increment attempt counter self.attempt_count.fetch_add(1, Ordering::SeqCst); - // We need to get a fresh stream for each retry attempt - // The source_iter provides unlimited streams from the same source data - let stream = self.source_iter.lock().unwrap().next().unwrap(); - self.job.clone().execute_uncommitted_impl(stream).await + // Re-scan the provider on each retry attempt. + self.job + .clone() + .execute_uncommitted_impl(self.provider.clone(), self.scan_provider_directly) + .await } async fn commit(&self, dataset: Arc, mut data: Self::Data) -> Result { @@ -5422,7 +5611,10 @@ mod tests { let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16)); let new_data_stream = reader_to_stream(Box::new(new_data)); - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .await + .unwrap(); // Assert the plan structure using portable plan matching // The optimized plan should have: @@ -5475,7 +5667,10 @@ mod tests { let new_data_stream = reader_to_stream(Box::new(new_data)); // This should use the fast path (execute_uncommitted_v2) - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .await + .unwrap(); // The optimized plan should use Inner join instead of Right join since we're not // inserting unmatched rows. 
The sentinel IS NOT NULL condition is folded away by @@ -5523,7 +5718,10 @@ mod tests { let new_data_reader = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16)); let new_data_stream = reader_to_stream(Box::new(new_data_reader)); - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .await + .unwrap(); // The optimized plan should use Inner join and include the UpdateIf condition. // The sentinel IS NOT NULL condition is folded away (sentinel is lit(true)). @@ -5575,7 +5773,10 @@ mod tests { // Should reach the v2 fast path (`create_plan` + FullSchemaMergeInsertExec). // Dropping to v1 here would return an error from create_plan instead. - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .await + .unwrap(); // The join is Right because we keep unmatched source rows (InsertAll) // but discard unmatched target rows (DoNothing on when_matched, @@ -8500,7 +8701,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8582,7 +8786,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], id_only_schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8664,7 
+8871,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .await + .unwrap(); assert_plan_node_equals( plan, "MergeInsert: on=[key], when_matched=Delete, when_not_matched=InsertAll, when_not_matched_by_source=Keep...THEN 2 WHEN...THEN 3 ELSE 0 END as __action]...projection=[key, value, filterme]" @@ -8769,7 +8979,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(non_matching_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8889,7 +9102,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch)], schema.clone(), ))); - let plan = job.create_plan(plan_stream).await.unwrap(); + let plan = job + .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .await + .unwrap(); let plan_str = datafusion::physical_plan::displayable(plan.as_ref()) .indent(true) @@ -10058,4 +10274,306 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n "Newly written merge-insert data files should be cleaned up on apply_deletions failure" ); } + + fn id_value_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])) + } + + /// `execute_provider` is the canonical entry point; a `MemTable` source merges + /// the same way a stream does. 
+ #[tokio::test] + async fn test_merge_insert_execute_provider() { + let schema = id_value_schema(); + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0, 1, 2])), + Arc::new(UInt32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Update id=1, insert id=3. + let new_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1, 3])), + Arc::new(UInt32Array::from(vec![10, 30])), + ], + ) + .unwrap(); + let provider: Arc = Arc::new( + datafusion::datasource::MemTable::try_new(schema.clone(), vec![vec![new_data]]) + .unwrap(), + ); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_provider(provider) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 1); + assert_eq!(stats.num_inserted_rows, 1); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!( + merged_rows, + HashMap::from([(0, 0), (1, 10), (2, 0), (3, 30)]) + ); + } + + /// `execute_batches` merges materialized batches; multiple batches are spread + /// across partitions and merged correctly. 
+ #[tokio::test] + async fn test_merge_insert_execute_batches() { + let schema = id_value_schema(); + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0, 1, 2])), + Arc::new(UInt32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Two batches: update id=1 (batch 0), insert id=3 (batch 1). + let batch0 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![1])), + Arc::new(UInt32Array::from(vec![10])), + ], + ) + .unwrap(); + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![3])), + Arc::new(UInt32Array::from(vec![30])), + ], + ) + .unwrap(); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_batches(vec![batch0, batch1]) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 1); + assert_eq!(stats.num_inserted_rows, 1); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!( + merged_rows, + HashMap::from([(0, 0), (1, 10), (2, 0), (3, 30)]) + ); + } + + fn collect_exact_row_counts(plan: &Arc, out: &mut Vec) { + if let Ok(stats) = plan.partition_statistics(None) + && let datafusion::common::stats::Precision::Exact(n) = stats.num_rows + { + out.push(n); + } + for child in plan.children() { + collect_exact_row_counts(child, out); + } + } + + /// Use case 3: planning against a provider directly exposes its exact source + /// statistics to the optimizer. Adapting it to a one-shot stream does not. 
+ #[tokio::test] + async fn test_merge_insert_source_statistics_in_plan() { + let schema = id_value_schema(); + let target_rows = 1000u32; + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..target_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + target_rows as usize, + ))), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // A small source whose exact row count is distinct from the target's. + let source_rows = 10usize; + let new_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..source_rows as u32)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 1, + source_rows, + ))), + ], + ) + .unwrap(); + let provider: Arc = + Arc::new(MemTable::try_new(schema.clone(), vec![vec![new_data]]).unwrap()); + + let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .try_build() + .unwrap(); + + // Planning against the provider directly: its exact row count reaches the plan. + let plan_with_stats = job + .clone() + .create_plan(provider.clone(), true) + .await + .unwrap(); + let mut with_stats = Vec::new(); + collect_exact_row_counts(&plan_with_stats, &mut with_stats); + assert!( + with_stats.contains(&source_rows), + "source provider's exact row count ({source_rows}) should reach the plan; got {with_stats:?}" + ); + + // Adapting to a one-shot stream: the source carries no statistics. 
+ let plan_without_stats = job.create_plan(provider, false).await.unwrap(); + let mut without_stats = Vec::new(); + collect_exact_row_counts(&plan_without_stats, &mut without_stats); + assert!( + !without_stats.contains(&source_rows), + "one-shot stream source should not report exact row count; got {without_stats:?}" + ); + } + + /// With a one-shot stream source and `spill_for_retry(false)`, a commit + /// conflict fails fast instead of replaying the stream. The non-replayable + /// one-shot provider must be scanned exactly once (scanning it twice would + /// panic), proving retries are disabled even though `conflict_retries > 0`. + #[tokio::test] + async fn test_merge_insert_spill_for_retry_false_fails_fast() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false).with_metadata( + vec![( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + )] + .into_iter() + .collect(), + ), + Field::new("value", DataType::UInt32, false), + ])); + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0, 0, 0, 0])), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Merge insert job based on version 1, with retries enabled but spilling off. + let new_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![100])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .conflict_retries(10) + .spill_for_retry(false) + .try_build() + .unwrap(); + + // An append commits first (version 2), so the merge built on version 1 hits + // an unresolvable conflict on commit. 
+ let append_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![50])), + Arc::new(UInt32Array::from(vec![2])), + ], + ) + .unwrap(); + InsertBuilder::new(dataset.clone()) + .with_params(&WriteParams { + mode: WriteMode::Append, + ..Default::default() + }) + .execute(vec![append_batch]) + .await + .unwrap(); + + let source = RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::iter(vec![Ok(new_data)]), + ); + let merge_result = job + .execute(Box::pin(source) as SendableRecordBatchStream) + .await; + + assert!( + matches!( + merge_result, + Err(crate::Error::TooMuchWriteContention { .. }) + ), + "Expected fail-fast TooMuchWriteContention, got: {:?}", + merge_result + ); + } } From 5c4f28f5cfd7c6a73ecee81513aff9f2d29b2dbc Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 16:57:34 -0700 Subject: [PATCH 2/3] refactor: simplify merge_insert source handling per review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Always plan against the source TableProvider directly (read_table), so every source — including spilled and one-shot streams — exposes its statistics to the merge join. The merge write node already requires a single-partition input, so the optimizer coalesces multi-partition providers; the previous target_partitions(1) hack and the scan_provider_directly branch are removed. Drop the provider_to_stream first-batch peek that preserved the source error's concrete type. Source errors are shared across join partitions by DataFusion (DataFusionError::Shared), so the type cannot be recovered, and no caller needs it — Python surfaces these errors by message. The error conversion now handles Shared (recursing when sole-owner, otherwise preserving the message under the execution category). Revert the InsertBuilder provider methods: they only adapt a provider back to a stream and add no parallelism until fragment fan-out exists. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance-core/src/error.rs | 11 ++ rust/lance-datafusion/src/exec.rs | 12 +- rust/lance/src/dataset/write/insert.rs | 47 ----- rust/lance/src/dataset/write/merge_insert.rs | 184 ++++++------------- 4 files changed, 66 insertions(+), 188 deletions(-) diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 3dcde1fc5b2..f12608d15e0 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -625,6 +625,17 @@ impl From for Error { Self::not_supported_source(box_error(e)) } datafusion_common::DataFusionError::Execution(..) => Self::execution(e.to_string()), + datafusion_common::DataFusionError::Shared(shared) => { + // DataFusion shares an error across consumers (e.g. a join's + // build-side error fanned out to every probe partition) behind an + // `Arc`. If we are the sole owner we can recurse for full fidelity; + // otherwise the inner error can't be moved out, so we preserve its + // message under the execution category (its concrete type is lost). + match std::sync::Arc::try_unwrap(shared) { + Ok(inner) => Self::from(inner), + Err(shared) => Self::execution(shared.to_string()), + } + } datafusion_common::DataFusionError::External(source) => { // Try to downcast to lance_core::Error first match source.downcast::() { diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index e08bd00f6f0..375a5b8e887 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -39,7 +39,7 @@ use datafusion::{ use datafusion_common::{DataFusionError, Statistics}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use futures::{StreamExt, TryStreamExt, stream}; +use futures::{StreamExt, stream}; use lance_arrow::SchemaExt; use lance_core::{ Error, Result, @@ -872,10 +872,6 @@ impl SessionContextExt for SessionContext { /// Multi-partition providers are coalesced into a single partition. 
This adapts a /// re-scannable provider back into the one stream the writer pipeline consumes; /// re-scanning the same provider (e.g. on a write retry) yields a fresh stream. -/// -/// The first batch is read eagerly and re-chained onto the stream. This surfaces a -/// scan error from the source directly, before it can be fed into (and obscured by) -/// a downstream plan — preserving the original error type for callers. pub async fn provider_to_stream( provider: Arc, ) -> Result { @@ -887,11 +883,7 @@ pub async fn provider_to_stream( } else { plan }; - let schema = plan.schema(); - let mut stream = plan.execute(0, ctx.task_ctx())?; - let first = stream.try_next().await?; - let rechained = stream::iter(first.map(Ok)).chain(stream); - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, rechained))) + Ok(plan.execute(0, ctx.task_ctx())?) } #[derive(Clone, Debug)] diff --git a/rust/lance/src/dataset/write/insert.rs b/rust/lance/src/dataset/write/insert.rs index 2b7d6869f18..bfd702c9c3b 100644 --- a/rust/lance/src/dataset/write/insert.rs +++ b/rust/lance/src/dataset/write/insert.rs @@ -5,13 +5,11 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_array::{RecordBatch, RecordBatchIterator}; -use datafusion::catalog::TableProvider; use datafusion::execution::SendableRecordBatchStream; use humantime::format_duration; use lance_core::datatypes::{NullabilityComparison, Schema, SchemaCompareOptions}; use lance_core::utils::tracing::{DATASET_WRITING_EVENT, TRACE_DATASET_EVENTS}; use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET}; -use lance_datafusion::exec::provider_to_stream; use lance_datafusion::utils::StreamingWriteSource; use lance_file::version::LanceFileVersion; use lance_io::object_store::ObjectStore; @@ -95,19 +93,6 @@ impl<'a> InsertBuilder<'a> { self.execute_stream_impl(stream, schema).await } - /// Execute the insert operation with a [`TableProvider`] source. - /// - /// The provider is scanned into a stream and written. 
This mirrors - /// [`crate::dataset::MergeInsertJob::execute_provider`] so the same input - /// shapes are accepted across write operations. Inserts do not retry, so the - /// provider is scanned only once. - /// - /// [`TableProvider`]: datafusion::catalog::TableProvider - pub async fn execute_provider(&self, provider: Arc) -> Result { - let stream = provider_to_stream(provider).await?; - self.execute_stream(stream).await - } - async fn execute_stream_impl( &self, stream: SendableRecordBatchStream, @@ -199,18 +184,6 @@ impl<'a> InsertBuilder<'a> { Ok(transaction) } - /// Write data files from a [`TableProvider`] source without committing. - /// - /// Use [`CommitBuilder`] to commit the returned transaction. See - /// [`Self::execute_provider`]. - pub async fn execute_uncommitted_provider( - &self, - provider: Arc, - ) -> Result { - let stream = provider_to_stream(provider).await?; - self.execute_uncommitted_stream(stream).await - } - async fn write_uncommitted_stream_impl( &self, stream: SendableRecordBatchStream, @@ -523,26 +496,6 @@ mod test { ); } - #[tokio::test] - async fn test_execute_provider() { - let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - ) - .unwrap(); - let provider: Arc = Arc::new( - datafusion::datasource::MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap(), - ); - - let dataset = InsertBuilder::new("memory://") - .execute_provider(provider) - .await - .unwrap(); - - assert_eq!(dataset.count_rows(None).await.unwrap(), 3); - } - #[tokio::test] async fn allow_overwrite_to_v2_2_without_blob_upgrade() { let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 27924495d22..a751fa85e2a 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ 
b/rust/lance/src/dataset/write/merge_insert.rs @@ -1459,10 +1459,7 @@ impl MergeInsertJob { source: SendableRecordBatchStream, ) -> Result<(Arc, MergeStats)> { let (provider, replayable) = self.stream_source_to_provider(source).await?; - // A stream-derived provider reports no statistics, so there is nothing to - // gain from planning against it directly; adapting it back to a stream also - // preserves the source's original error type. - self.execute_inner(provider, replayable, false).await + self.execute_inner(provider, replayable).await } /// Executes the merge insert job from a re-scannable [`TableProvider`]. @@ -1479,9 +1476,8 @@ impl MergeInsertJob { self, provider: Arc, ) -> Result<(Arc, MergeStats)> { - // A genuine TableProvider is re-scannable by contract, so retries are safe, - // and planning against it directly lets its statistics drive the join. - self.execute_inner(provider, true, true).await + // A genuine TableProvider is re-scannable by contract, so retries are safe. + self.execute_inner(provider, true).await } /// Executes the merge insert job from materialized record batches. @@ -1495,7 +1491,7 @@ impl MergeInsertJob { batches: Vec, ) -> Result<(Arc, MergeStats)> { let provider = self.batches_to_provider(batches)?; - self.execute_inner(provider, true, true).await + self.execute_inner(provider, true).await } /// Like [`Self::execute_batches`] but returns the uncommitted transaction. @@ -1506,7 +1502,7 @@ impl MergeInsertJob { batches: Vec, ) -> Result { let provider = self.batches_to_provider(batches)?; - self.execute_uncommitted_impl(provider, true).await + self.execute_uncommitted_impl(provider).await } /// Wrap materialized batches in a multi-partition in-memory [`MemTable`]. 
@@ -1564,7 +1560,6 @@ impl MergeInsertJob { self, provider: Arc, replayable: bool, - scan_provider_directly: bool, ) -> Result<(Arc, MergeStats)> { let dataset = self.dataset.clone(); let config = RetryConfig { @@ -1579,7 +1574,6 @@ impl MergeInsertJob { let wrapper = MergeInsertJobWithProvider { job: self, provider, - scan_provider_directly, attempt_count: Arc::new(AtomicU32::new(0)), }; @@ -1594,7 +1588,7 @@ impl MergeInsertJob { source: impl StreamingWriteSource, ) -> Result { let stream = source.into_stream(); - self.execute_uncommitted_impl(one_shot_provider(stream)?, false) + self.execute_uncommitted_impl(one_shot_provider(stream)?) .await } @@ -1613,25 +1607,13 @@ impl MergeInsertJob { } } - async fn create_plan( - self, - provider: Arc, - scan_provider_directly: bool, - ) -> Result> { + async fn create_plan(self, provider: Arc) -> Result> { // Goal: we shouldn't manually have to specify which columns to scan. // DataFusion's optimizer should be able to automatically perform // projection pushdown for us. // Goal: we shouldn't have to add new branches in this code to handle // indexed vs non-indexed cases. That should be handled by optimizer rules. - let session_config = if scan_provider_directly { - // A provider may expose multiple partitions; keep the plan single- - // partition so it satisfies the merge write node's contract (the - // provider's statistics still drive join-side selection). 
- SessionConfig::default().with_target_partitions(1) - } else { - SessionConfig::default() - }; - let session_ctx = SessionContext::new_with_config(session_config); + let session_ctx = SessionContext::new(); let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?; // Wrap column names in double quotes to preserve case (DataFusion lowercases unquoted identifiers) let on_cols = self @@ -1641,15 +1623,11 @@ impl MergeInsertJob { .map(|name| format!("\"{}\"", name)) .collect::>(); let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::>(); - // Plan against the provider directly so its statistics reach the optimizer; - // otherwise adapt it to a one-shot stream (which carries no statistics but - // preserves the source's original error type). - let source_df = if scan_provider_directly { - session_ctx.read_table(provider)? - } else { - let source = provider_to_stream(provider).await?; - session_ctx.read_one_shot(source)? - }; + // Plan against the provider directly so its statistics reach the optimizer. + // The merge write node requires a single-partition input, so the optimizer + // coalesces a multi-partition provider for us (see + // `FullSchemaMergeInsertExec::required_input_distribution`). + let source_df = session_ctx.read_table(provider)?; // Capture the source field names *before* aliasing / joining so we // can tell which dataset columns are missing from the source and // need to be filled from the target side of the join below. 
@@ -1729,14 +1707,13 @@ impl MergeInsertJob { async fn execute_uncommitted_v2( self, provider: Arc, - scan_provider_directly: bool, ) -> Result<( Transaction, MergeStats, Option, Option, )> { - let plan = self.create_plan(provider, scan_provider_directly).await?; + let plan = self.create_plan(provider).await?; // Execute the plan // Assert that we have exactly one partition since we're designed for single-partition execution @@ -1912,15 +1889,13 @@ impl MergeInsertJob { async fn execute_uncommitted_impl( self, provider: Arc, - scan_provider_directly: bool, ) -> Result { // Check if we can use the fast path let can_use_fast_path = self.can_use_create_plan(provider.schema().as_ref()).await?; if can_use_fast_path { - let (transaction, stats, affected_rows, inserted_rows_filter) = self - .execute_uncommitted_v2(provider, scan_provider_directly) - .await?; + let (transaction, stats, affected_rows, inserted_rows_filter) = + self.execute_uncommitted_v2(provider).await?; return Ok(UncommittedMergeInsert { transaction, affected_rows, @@ -2177,7 +2152,7 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); let plan = cloned_job - .create_plan(one_shot_provider(Box::pin(stream))?, false) + .create_plan(one_shot_provider(Box::pin(stream))?) 
.await?; let display = DisplayableExecutionPlan::new(plan.as_ref()); @@ -2211,9 +2186,7 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); - let plan = cloned_job - .create_plan(one_shot_provider(source)?, false) - .await?; + let plan = cloned_job.create_plan(one_shot_provider(source)?).await?; // Use the analyze_plan function from lance_datafusion, but strip out the wrapper lines let options = LanceExecutionOptions::default(); @@ -2268,9 +2241,6 @@ pub struct UncommittedMergeInsert { struct MergeInsertJobWithProvider { job: MergeInsertJob, provider: Arc, - // Whether to plan against the provider directly (using its statistics) or adapt - // it to a one-shot stream. See `MergeInsertJob::execute_inner`. - scan_provider_directly: bool, attempt_count: Arc, } @@ -2285,7 +2255,7 @@ impl RetryExecutor for MergeInsertJobWithProvider { // Re-scan the provider on each retry attempt. self.job .clone() - .execute_uncommitted_impl(self.provider.clone(), self.scan_provider_directly) + .execute_uncommitted_impl(self.provider.clone()) .await } @@ -5612,7 +5582,7 @@ mod tests { let new_data_stream = reader_to_stream(Box::new(new_data)); let plan = merge_insert_job - .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .create_plan(one_shot_provider(new_data_stream).unwrap()) .await .unwrap(); @@ -5668,7 +5638,7 @@ mod tests { // This should use the fast path (execute_uncommitted_v2) let plan = merge_insert_job - .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .create_plan(one_shot_provider(new_data_stream).unwrap()) .await .unwrap(); @@ -5719,7 +5689,7 @@ mod tests { let new_data_stream = reader_to_stream(Box::new(new_data_reader)); let plan = merge_insert_job - .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .create_plan(one_shot_provider(new_data_stream).unwrap()) .await .unwrap(); @@ -5774,7 +5744,7 @@ mod tests { // Should reach the v2 fast path (`create_plan` + 
FullSchemaMergeInsertExec). // Dropping to v1 here would return an error from create_plan instead. let plan = merge_insert_job - .create_plan(one_shot_provider(new_data_stream).unwrap(), false) + .create_plan(one_shot_provider(new_data_stream).unwrap()) .await .unwrap(); @@ -8702,7 +8672,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n schema.clone(), ))); let plan = plan_job - .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .create_plan(one_shot_provider(plan_stream).unwrap()) .await .unwrap(); assert_plan_node_equals( @@ -8787,7 +8757,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n id_only_schema.clone(), ))); let plan = plan_job - .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .create_plan(one_shot_provider(plan_stream).unwrap()) .await .unwrap(); assert_plan_node_equals( @@ -8872,7 +8842,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n schema.clone(), ))); let plan = plan_job - .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .create_plan(one_shot_provider(plan_stream).unwrap()) .await .unwrap(); assert_plan_node_equals( @@ -8980,7 +8950,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n schema.clone(), ))); let plan = plan_job - .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .create_plan(one_shot_provider(plan_stream).unwrap()) .await .unwrap(); assert_plan_node_equals( @@ -9103,7 +9073,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n schema.clone(), ))); let plan = job - .create_plan(one_shot_provider(plan_stream).unwrap(), false) + .create_plan(one_shot_provider(plan_stream).unwrap()) .await .unwrap(); @@ -9196,7 +9166,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n impl std::error::Error for MyTestError {} #[tokio::test] - async fn 
test_merge_insert_execute_reader_preserves_external_error() { + async fn test_merge_insert_execute_reader_preserves_error_message() { let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("key", DataType::Int32, false), ArrowField::new("value", DataType::Int32, false), @@ -9233,14 +9203,14 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n .execute_reader(Box::new(reader) as Box) .await; - match result { - Err(Error::External { source }) => { - let original = source.downcast_ref::().unwrap(); - assert_eq!(original.code, error_code); - } - Err(other) => panic!("Expected External, got: {:?}", other), - Ok(_) => panic!("Expected error"), - } + // The source error is routed through the merge plan, which shares it + // across join partitions, so its concrete type is not recoverable. The + // message must still reach the caller. + let err = result.expect_err("expected the source error to surface"); + assert!( + err.to_string().contains("merge insert failure"), + "source error message should be preserved; got: {err}" + ); } } @@ -10286,15 +10256,8 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n /// the same way a stream does. #[tokio::test] async fn test_merge_insert_execute_provider() { - let schema = id_value_schema(); - let initial = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(UInt32Array::from(vec![0, 1, 2])), - Arc::new(UInt32Array::from(vec![0, 0, 0])), - ], - ) - .unwrap(); + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); let dataset = Arc::new( InsertBuilder::new("memory://") .execute(vec![initial]) @@ -10303,16 +10266,9 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n ); // Update id=1, insert id=3. 
- let new_data = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(UInt32Array::from(vec![1, 3])), - Arc::new(UInt32Array::from(vec![10, 30])), - ], - ) - .unwrap(); + let new_data = record_batch!(("id", UInt32, [1, 3]), ("value", UInt32, [10, 30])).unwrap(); let provider: Arc = Arc::new( - datafusion::datasource::MemTable::try_new(schema.clone(), vec![vec![new_data]]) + datafusion::datasource::MemTable::try_new(new_data.schema(), vec![vec![new_data]]) .unwrap(), ); @@ -10348,15 +10304,8 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n /// across partitions and merged correctly. #[tokio::test] async fn test_merge_insert_execute_batches() { - let schema = id_value_schema(); - let initial = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(UInt32Array::from(vec![0, 1, 2])), - Arc::new(UInt32Array::from(vec![0, 0, 0])), - ], - ) - .unwrap(); + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); let dataset = Arc::new( InsertBuilder::new("memory://") .execute(vec![initial]) @@ -10365,22 +10314,8 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n ); // Two batches: update id=1 (batch 0), insert id=3 (batch 1). 
- let batch0 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(UInt32Array::from(vec![1])), - Arc::new(UInt32Array::from(vec![10])), - ], - ) - .unwrap(); - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(UInt32Array::from(vec![3])), - Arc::new(UInt32Array::from(vec![30])), - ], - ) - .unwrap(); + let batch0 = record_batch!(("id", UInt32, [1]), ("value", UInt32, [10])).unwrap(); + let batch1 = record_batch!(("id", UInt32, [3]), ("value", UInt32, [30])).unwrap(); let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) .unwrap() @@ -10421,8 +10356,8 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n } } - /// Use case 3: planning against a provider directly exposes its exact source - /// statistics to the optimizer. Adapting it to a one-shot stream does not. + /// Use case 3: planning against the provider exposes its exact source + /// statistics to the optimizer. #[tokio::test] async fn test_merge_insert_source_statistics_in_plan() { let schema = id_value_schema(); @@ -10467,26 +10402,13 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n .try_build() .unwrap(); - // Planning against the provider directly: its exact row count reaches the plan. - let plan_with_stats = job - .clone() - .create_plan(provider.clone(), true) - .await - .unwrap(); - let mut with_stats = Vec::new(); - collect_exact_row_counts(&plan_with_stats, &mut with_stats); - assert!( - with_stats.contains(&source_rows), - "source provider's exact row count ({source_rows}) should reach the plan; got {with_stats:?}" - ); - - // Adapting to a one-shot stream: the source carries no statistics. - let plan_without_stats = job.create_plan(provider, false).await.unwrap(); - let mut without_stats = Vec::new(); - collect_exact_row_counts(&plan_without_stats, &mut without_stats); + // The provider's exact row count reaches the plan's statistics. 
+ let plan = job.create_plan(provider).await.unwrap(); + let mut row_counts = Vec::new(); + collect_exact_row_counts(&plan, &mut row_counts); assert!( - !without_stats.contains(&source_rows), - "one-shot stream source should not report exact row count; got {without_stats:?}" + row_counts.contains(&source_rows), + "source provider's exact row count ({source_rows}) should reach the plan; got {row_counts:?}" ); } From b340516d605287caae71e856c045625927caefbd Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 17:04:16 -0700 Subject: [PATCH 3/3] test: cover empty execute_batches input Verify that merging an empty batch list is a no-op that leaves the target unchanged, exercising the empty-source partition and schema-fallback paths in batches_to_provider / batches_into_partitions. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/dataset/write/merge_insert.rs | 38 ++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index a751fa85e2a..8d7aed6cad5 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -10345,6 +10345,44 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n ); } + /// An empty batch list still produces a valid (single, empty) partition, so the + /// merge is a no-op and the target is unchanged. 
+ #[tokio::test] + async fn test_merge_insert_execute_batches_empty() { + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_batches(vec![]) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 0); + assert_eq!(stats.num_inserted_rows, 0); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!(merged_rows, HashMap::from([(0, 0), (1, 0), (2, 0)])); + } + fn collect_exact_row_counts(plan: &Arc, out: &mut Vec) { if let Ok(stats) = plan.partition_statistics(None) && let datafusion::common::stats::Precision::Exact(n) = stats.num_rows