diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 45dc1b253d3..f361122bb4f 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -70,7 +70,7 @@ from .lance import __version__ as __version__ from .lance import _Session as Session from .query import FullTextQuery -from .types import _coerce_reader +from .types import _coerce_reader, _is_materialized from .udf import BatchUDF, normalize_transform from .udf import BatchUDFCheckpoint as BatchUDFCheckpoint from .udf import batch_udf as batch_udf @@ -392,6 +392,12 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): """ reader = _coerce_reader(data_obj, schema) + # Materialized sources are wrapped in an in-memory table so retries never + # spill and the source's statistics can drive the join; everything else is + # treated as a one-shot stream. + if _is_materialized(data_obj): + return super(MergeInsertBuilder, self).execute_batches(reader) + return super(MergeInsertBuilder, self).execute(reader) def execute_uncommitted( @@ -416,6 +422,9 @@ def execute_uncommitted( """ reader = _coerce_reader(data_obj, schema) + if _is_materialized(data_obj): + return super(MergeInsertBuilder, self).execute_uncommitted_batches(reader) + return super(MergeInsertBuilder, self).execute_uncommitted(reader) # These next three overrides exist only to document the methods diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 74db076db41..323e78879cb 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -471,6 +471,13 @@ class _MergeInsertBuilder: def when_not_matched_insert_all(self) -> Self: ... def when_not_matched_by_source_delete(self, expr: Optional[str] = None) -> Self: ... def execute(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ... + def execute_batches(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ... + def execute_uncommitted( + self, new_data: pa.RecordBatchReader + ) -> tuple[Transaction, ExecuteResult]: ... + def execute_uncommitted_batches( + self, new_data: pa.RecordBatchReader + ) -> tuple[Transaction, ExecuteResult]: ... class _Scanner: @property diff --git a/python/python/lance/types.py b/python/python/lance/types.py index 41cc191e4d6..0ed2ae9f95e 100644 --- a/python/python/lance/types.py +++ b/python/python/lance/types.py @@ -52,6 +52,34 @@ def _casting_recordbatch_iter( yield batch +def _is_materialized(data_obj: ReaderLike) -> bool: + """Whether ``data_obj`` is fully materialized in memory. + + Materialized sources (tables, in-memory frames) can be wrapped in an + in-memory table for replay without spilling and to expose exact statistics. + Streaming or re-readable sources (readers, scanners, datasets, generators) + are not considered materialized. + """ + if _check_for_pandas(data_obj) and isinstance(data_obj, pd.DataFrame): + return True + if isinstance(data_obj, (pa.Table, pa.RecordBatch)): + return True + if ( + type(data_obj).__module__.startswith("polars") + and data_obj.__class__.__name__ == "DataFrame" + ): + return True + if isinstance(data_obj, dict): + return True + if ( + isinstance(data_obj, list) + and len(data_obj) > 0 + and isinstance(data_obj[0], dict) + ): + return True + return False + + def _coerce_reader( data_obj: ReaderLike, schema: Optional[pa.Schema] = None ) -> pa.RecordBatchReader: diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 45866f3c4da..4c373ad0180 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -2342,6 +2342,29 @@ def test_merge_insert(tmp_path: Path): check_merge_stats(merge_dict, (None, None, None)) +@pytest.mark.parametrize("materialized", [True, False]) +def test_merge_insert_input_kinds(tmp_path: Path, materialized: bool): + # A materialized pa.Table is routed through the in-memory (MemTable) path, + # while a RecordBatchReader is routed through the streaming path. Both must + # produce identical results. + schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())]) + base = pa.table({"id": range(5), "value": [0] * 5}, schema=schema) + new = pa.table({"id": [1, 2, 5, 6], "value": [10, 20, 50, 60]}, schema=schema) + + dataset = lance.write_dataset(base, tmp_path / "dataset", mode="create") + source = new if materialized else new.to_reader() + + dataset.merge_insert( + "id" + ).when_matched_update_all().when_not_matched_insert_all().execute(source) + + result = dataset.to_table().sort_by("id").to_pydict() + assert result == { + "id": [0, 1, 2, 3, 4, 5, 6], + "value": [0, 10, 20, 0, 0, 50, 60], + } + + def test_merge_insert_subcols(tmp_path: Path): initial_data = pa.table( { diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8bfa81aeae4..1d75d9bd128 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -434,6 +434,60 @@ impl MergeInsertBuilder { Ok((PyLance(transaction), stats)) } + /// Execute the merge insert from fully-materialized data. + /// + /// The data is read into memory and wrapped in an in-memory table, so retries + /// never spill to disk and the source's statistics drive the join. Callers + /// should only route in-memory inputs (e.g. a `pa.Table`) here. + pub fn execute_batches(&mut self, new_data: &Bound) -> PyResult> { + let py = new_data.py(); + let reader = convert_reader(new_data)?; + let batches = reader + .collect::, _>>() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let job = self + .builder + .try_build() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let (new_dataset, stats) = rt() + .spawn(Some(py), job.execute_batches(batches))? + .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; + + let dataset = self.dataset.bind(py); + dataset.borrow_mut().ds = new_dataset; + + Ok(Self::build_stats(&stats, py)?.into()) + } + + /// [`Self::execute_batches`] without committing; returns the transaction. + pub fn execute_uncommitted_batches<'a>( + &mut self, + new_data: &Bound<'a, PyAny>, + ) -> PyResult<(PyLance, Bound<'a, PyDict>)> { + let py = new_data.py(); + let reader = convert_reader(new_data)?; + let batches = reader + .collect::, _>>() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let job = self + .builder + .try_build() + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let UncommittedMergeInsert { + transaction, stats, .. + } = rt() + .spawn(Some(py), job.execute_uncommitted_batches(batches))? + .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; + + let stats = Self::build_stats(&stats, py)?; + + Ok((PyLance(transaction), stats)) + } + #[pyo3(signature=(schema = None, verbose = false))] pub fn explain_plan( &mut self, diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 3dcde1fc5b2..f12608d15e0 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -625,6 +625,17 @@ impl From for Error { Self::not_supported_source(box_error(e)) } datafusion_common::DataFusionError::Execution(..) => Self::execution(e.to_string()), + datafusion_common::DataFusionError::Shared(shared) => { + // DataFusion shares an error across consumers (e.g. a join's + // build-side error fanned out to every probe partition) behind an + // `Arc`. If we are the sole owner we can recurse for full fidelity; + // otherwise the inner error can't be moved out, so we preserve its + // message under the execution category (its concrete type is lost). + match std::sync::Arc::try_unwrap(shared) { + Ok(inner) => Self::from(inner), + Err(shared) => Self::execution(shared.to_string()), + } + } datafusion_common::DataFusionError::External(source) => { // Try to downcast to lance_core::Error first match source.downcast::() { diff --git a/rust/lance-datafusion/src/exec.rs b/rust/lance-datafusion/src/exec.rs index 8f346f45612..375a5b8e887 100644 --- a/rust/lance-datafusion/src/exec.rs +++ b/rust/lance-datafusion/src/exec.rs @@ -16,7 +16,7 @@ use arrow_array::RecordBatch; use arrow_schema::Schema as ArrowSchema; use datafusion::physical_plan::metrics::MetricType; use datafusion::{ - catalog::streaming::StreamingTable, + catalog::{TableProvider, streaming::StreamingTable}, dataframe::DataFrame, execution::{ TaskContext, @@ -867,6 +867,25 @@ impl SessionContextExt for SessionContext { } } +/// Scan a [`TableProvider`] into a single-partition [`SendableRecordBatchStream`]. +/// +/// Multi-partition providers are coalesced into a single partition. This adapts a +/// re-scannable provider back into the one stream the writer pipeline consumes; +/// re-scanning the same provider (e.g. on a write retry) yields a fresh stream. +pub async fn provider_to_stream( + provider: Arc, +) -> Result { + let ctx = SessionContext::new(); + let plan = provider.scan(&ctx.state(), None, &[], None).await?; + let plan: Arc = + if plan.properties().output_partitioning().partition_count() > 1 { + Arc::new(CoalescePartitionsExec::new(plan)) + } else { + plan + }; + Ok(plan.execute(0, ctx.task_ctx())?) +} + #[derive(Clone, Debug)] pub struct StrictBatchSizeExec { input: Arc, diff --git a/rust/lance-datafusion/src/spill.rs b/rust/lance-datafusion/src/spill.rs index 8fa60c93ab6..1ffbe9083c4 100644 --- a/rust/lance-datafusion/src/spill.rs +++ b/rust/lance-datafusion/src/spill.rs @@ -9,13 +9,17 @@ use std::{ use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; use arrow_array::RecordBatch; -use arrow_schema::{ArrowError, Schema}; +use arrow_schema::{ArrowError, Schema, SchemaRef}; use datafusion::{ - execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter, + catalog::{TableProvider, streaming::StreamingTable}, + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{stream::RecordBatchStreamAdapter, streaming::PartitionStream}, }; use datafusion_common::DataFusionError; +use futures::StreamExt; use lance_arrow::memory::MemoryAccumulator; use lance_core::error::LanceOptionExt; +use lance_core::utils::tempfile::TempDir; /// Start a spill of Arrow data to a file that can be read later multiple times. /// @@ -60,6 +64,101 @@ pub fn create_replay_spill( (sender, receiver) } +/// Wrap a one-shot [`SendableRecordBatchStream`] in a re-scannable [`TableProvider`]. +/// +/// The source is drained in the background into a replayable spill: up to +/// `memory_limit` bytes are buffered in memory before spilling to a temporary file +/// on disk. Each scan of the returned provider replays the full source from the +/// spill, which is what makes a one-shot stream usable in the write retry loop. +/// +/// The provider reports no statistics — the source size is not known until it has +/// been fully drained — so callers that need source statistics (e.g. to drive join +/// ordering) should prefer a materialized or file-backed provider instead. +pub async fn spilling_table_provider( + mut source: SendableRecordBatchStream, + memory_limit: usize, +) -> Result, DataFusionError> { + let schema = source.schema(); + let tmp_dir = tokio::task::spawn_blocking(TempDir::try_new) + .await + .map_err(|e| DataFusionError::Execution(format!("Failed to spawn temp dir task: {e}")))? + .map_err(|e| DataFusionError::Execution(format!("Failed to create temp dir: {e}")))?; + let tmp_path = tmp_dir.std_path().join("spill.arrows"); + let (mut sender, receiver) = create_replay_spill(tmp_path, schema.clone(), memory_limit); + + // Drain the one-shot source into the spill once, in the background. The spill + // tees to memory/disk so the first reader can consume batches as they arrive + // while later readers replay the complete source. + let drain_handle = tokio::task::spawn(async move { + let mut errored = false; + while let Some(res) = source.next().await { + match res { + Ok(batch) => { + if let Err(e) = sender.write(batch).await { + sender.send_error(e); + errored = true; + break; + } + } + Err(e) => { + sender.send_error(e); + errored = true; + break; + } + } + } + // Only finish on a clean drain. Calling finish() after an error would + // overwrite the original (replayable) error with a generic one, losing + // the source error's type (e.g. an external error from user code). + if !errored && let Err(err) = sender.finish().await { + sender.send_error(err); + } + sender + }); + + let partition = Arc::new(SpillPartition { + schema: schema.clone(), + receiver, + _tmp_dir: Arc::new(tmp_dir), + _drain_handle: Arc::new(drain_handle), + }); + Ok(Arc::new(StreamingTable::try_new(schema, vec![partition])?)) +} + +/// A [`PartitionStream`] backed by a replayable spill. +/// +/// Each call to [`PartitionStream::execute`] opens a fresh stream over the spill, +/// so the partition can be scanned repeatedly. The spill file and the background +/// task draining the source are kept alive for as long as this partition exists. +struct SpillPartition { + schema: SchemaRef, + receiver: SpillReceiver, + // The spilled data lives in this temp dir; dropping it deletes the spill file. + _tmp_dir: Arc, + // Keeps the background drain task (which owns the `SpillSender`) alive. The + // `SpillSender` must outlive the readers or they error out, so we hold the + // handle rather than detaching it. + _drain_handle: Arc>, +} + +impl std::fmt::Debug for SpillPartition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SpillPartition") + .field("schema", &self.schema) + .finish() + } +} + +impl PartitionStream for SpillPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + self.receiver.read() + } +} + #[derive(Clone)] pub struct SpillReceiver { status_receiver: tokio::sync::watch::Receiver, diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index ff0a119158c..9af133d2e9c 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -4,8 +4,7 @@ use arrow_array::RecordBatch; use chrono::TimeDelta; use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use lance_arrow::{ ARROW_EXT_NAME_KEY, BLOB_DEDICATED_SIZE_THRESHOLD_META_KEY, BLOB_INLINE_SIZE_THRESHOLD_META_KEY, BLOB_META_KEY, BLOB_V2_EXT_NAME, @@ -13,12 +12,9 @@ use lance_arrow::{ use lance_core::datatypes::{ NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions, }; -use lance_core::error::LanceOptionExt; -use lance_core::utils::tempfile::TempDir; use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DATA, TRACE_FILE_AUDIT}; use lance_core::{Error, Result, datatypes::Schema}; use lance_datafusion::chunker::{break_stream, chunk_stream}; -use lance_datafusion::spill::{SpillReceiver, SpillSender, create_replay_spill}; use lance_datafusion::utils::StreamingWriteSource; use lance_file::previous::writer::{ FileWriter as PreviousFileWriter, ManifestProvider as PreviousManifestProvider, @@ -1510,108 +1506,6 @@ async fn resolve_commit_handler( } } -/// Create an iterator of record batch streams from the given source. -/// -/// If `enable_retries` is true, then the source will be saved either in memory -/// or spilled to disk to allow replaying the source in case of a failure. The -/// source will be kept in memory if either (1) the size hint shows that -/// there is only one batch or (2) the stream contains less than 100MB of -/// data. Otherwise, the source will be spilled to a temporary file on disk. -/// -/// This is used to support retries on write operations. -async fn new_source_iter( - source: SendableRecordBatchStream, - enable_retries: bool, -) -> Result + Send + 'static>> { - if enable_retries { - let schema = source.schema(); - - // If size hint shows there is only one batch, spilling has no benefit, just keep that - // in memory. (This is a pretty common case.) - let size_hint = source.size_hint(); - if size_hint.0 == 1 && size_hint.1 == Some(1) { - let batches: Vec = source.try_collect().await?; - Ok(Box::new(std::iter::repeat_with(move || { - Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), - futures::stream::iter(batches.clone().into_iter().map(Ok)), - )) as SendableRecordBatchStream - }))) - } else { - // Allow buffering up to 100MB in memory before spilling to disk. - Ok(Box::new( - SpillStreamIter::try_new(source, 100 * 1024 * 1024).await?, - )) - } - } else { - Ok(Box::new(std::iter::once(source))) - } -} - -struct SpillStreamIter { - receiver: SpillReceiver, - _sender_handle: tokio::task::JoinHandle, - // This temp dir is used to store the spilled data. It is kept alive by - // this struct. When this struct is dropped, the Drop implementation of - // tempfile::TempDir will delete the temp dir. - _tmp_dir: TempDir, -} - -impl SpillStreamIter { - pub async fn try_new( - mut source: SendableRecordBatchStream, - memory_limit: usize, - ) -> Result { - let tmp_dir = tokio::task::spawn_blocking(|| { - TempDir::try_new() - .map_err(|e| Error::invalid_input(format!("Failed to create temp dir: {}", e))) - }) - .await - .ok() - .expect_ok()??; - - let tmp_path = tmp_dir.std_path().join("spill.arrows"); - let (mut sender, receiver) = create_replay_spill(tmp_path, source.schema(), memory_limit); - - let sender_handle = tokio::task::spawn(async move { - while let Some(res) = source.next().await { - match res { - Ok(batch) => match sender.write(batch).await { - Ok(_) => {} - Err(e) => { - sender.send_error(e); - break; - } - }, - Err(e) => { - sender.send_error(e); - break; - } - } - } - - if let Err(err) = sender.finish().await { - sender.send_error(err); - } - sender - }); - - Ok(Self { - receiver, - _tmp_dir: tmp_dir, - _sender_handle: sender_handle, - }) - } -} - -impl Iterator for SpillStreamIter { - type Item = SendableRecordBatchStream; - - fn next(&mut self) -> Option { - Some(self.receiver.read()) - } -} - #[cfg(test)] mod tests { use super::*; @@ -3315,6 +3209,7 @@ mod tests { async fn test_write_interruption_recovery() { use super::commit::CommitBuilder; use arrow_array::record_batch; + use lance_core::utils::tempfile::TempDir; // Create a temporary directory for testing let temp_dir = TempDir::default(); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..8d7aed6cad5 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -73,6 +73,8 @@ use datafusion::common::NullEquality; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::error::DataFusionError; use datafusion::{ + catalog::{TableProvider, streaming::StreamingTable}, + datasource::MemTable, execution::{ context::{SessionConfig, SessionContext}, memory_pool::MemoryConsumer, @@ -110,9 +112,10 @@ use lance_datafusion::{ chunker::chunk_stream, dataframe::BatchStreamGrouper, exec::{ - HardCapBatchSizeExec, LanceExecutionOptions, OneShotExec, analyze_plan, execute_plan, - get_session_context, + HardCapBatchSizeExec, LanceExecutionOptions, OneShotExec, OneShotPartitionStream, + analyze_plan, execute_plan, get_session_context, provider_to_stream, }, + spill::spilling_table_provider, utils::{StreamingWriteSource, reader_to_stream}, }; use lance_file::version::LanceFileVersion; @@ -336,6 +339,12 @@ struct MergeInsertParams { // Controls whether data that is not matched by the source is deleted or not delete_not_matched_by_source: WhenNotMatchedBySource, conflict_retries: u32, + // When the source is a one-shot stream and `conflict_retries > 0`, the source + // is spilled (memory, then disk) so it can be replayed on each retry. Set to + // false to fail fast on contention instead of buffering the stream. Has no + // effect on re-scannable sources (materialized batches, files), which never + // spill. + spill_for_retry: bool, retry_timeout: Duration, // List of MemWAL region generations to mark as merged when this commit succeeds. merged_generations: Vec, @@ -465,6 +474,7 @@ impl MergeInsertBuilder { insert_not_matched: true, delete_not_matched_by_source: WhenNotMatchedBySource::Keep, conflict_retries: 10, + spill_for_retry: true, retry_timeout: Duration::from_secs(30), merged_generations: Vec::new(), skip_auto_cleanup: false, @@ -512,6 +522,27 @@ impl MergeInsertBuilder { self } + /// Controls whether a one-shot stream source is spilled so it can be replayed + /// across retries. + /// + /// When the source is a one-shot stream (e.g. [`MergeInsertJob::execute`]) and + /// `conflict_retries > 0`, the source is buffered in memory and spilled to disk + /// so each retry can re-read it. Set this to `false` to skip that buffering and + /// fail fast with a contention error instead of writing the stream to disk. + /// + /// This has no effect on re-scannable sources (materialized batches via + /// [`MergeInsertJob::execute_batches`], or a [`TableProvider`] via + /// [`MergeInsertJob::execute_provider`]), which are replayed directly and never + /// spill. + /// + /// Default is true. + /// + /// [`TableProvider`]: datafusion::catalog::TableProvider + pub fn spill_for_retry(&mut self, spill: bool) -> &mut Self { + self.params.spill_for_retry = spill; + self + } + /// Set the timeout used to limit retries. /// /// This is the maximum time to spend on the operation before giving up. At @@ -591,6 +622,16 @@ enum SchemaComparison { Subschema, } +/// Wrap a one-shot stream in a non-replayable [`StreamingTable`] provider. +/// +/// The provider can only be scanned once (its single partition hands out the +/// underlying stream), so it must not be used where retries may re-scan it. +fn one_shot_provider(stream: SendableRecordBatchStream) -> Result> { + let schema = stream.schema(); + let partition = Arc::new(OneShotPartitionStream::new(stream)); + Ok(Arc::new(StreamingTable::try_new(schema, vec![partition])?)) +} + impl MergeInsertJob { pub async fn execute_reader( self, @@ -1403,24 +1444,136 @@ impl MergeInsertJob { )) } - /// Executes the merge insert job + /// Executes the merge insert job from a one-shot stream source. /// /// This will take in the source, merge it with the existing target data, and insert new - /// rows, update existing rows, and delete existing rows + /// rows, update existing rows, and delete existing rows. + /// + /// A stream can only be read once, so when `conflict_retries > 0` the stream is + /// spilled (in memory, then to disk) so it can be replayed on each retry. See + /// [`MergeInsertBuilder::spill_for_retry`] to fail fast instead, and + /// [`Self::execute_batches`] / [`Self::execute_provider`] for re-scannable + /// sources that never spill. pub async fn execute( self, source: SendableRecordBatchStream, ) -> Result<(Arc, MergeStats)> { - let source_iter = super::new_source_iter(source, self.params.conflict_retries > 0).await?; + let (provider, replayable) = self.stream_source_to_provider(source).await?; + self.execute_inner(provider, replayable).await + } + + /// Executes the merge insert job from a re-scannable [`TableProvider`]. + /// + /// This is the canonical entry point: [`Self::execute`] and + /// [`Self::execute_batches`] are thin wrappers that build a provider and call + /// this method. Because a provider can be scanned repeatedly, retries re-read + /// the source directly and never spill to disk. The provider's reported + /// statistics (e.g. from a [`MemTable`] or file source) also let DataFusion + /// optimize the merge join. + /// + /// [`MemTable`]: datafusion::datasource::MemTable + pub async fn execute_provider( + self, + provider: Arc, + ) -> Result<(Arc, MergeStats)> { + // A genuine TableProvider is re-scannable by contract, so retries are safe. + self.execute_inner(provider, true).await + } + + /// Executes the merge insert job from materialized record batches. + /// + /// The batches are wrapped in an in-memory [`MemTable`], which is re-scannable + /// (retries replay from memory, never spilling) and reports exact statistics to + /// the merge join. This is the preferred entry point when the full source is + /// already in memory. + pub async fn execute_batches( + self, + batches: Vec, + ) -> Result<(Arc, MergeStats)> { + let provider = self.batches_to_provider(batches)?; + self.execute_inner(provider, true).await + } + + /// Like [`Self::execute_batches`] but returns the uncommitted transaction. + /// + /// Use [`CommitBuilder`] to commit the returned transaction. + pub async fn execute_uncommitted_batches( + self, + batches: Vec, + ) -> Result { + let provider = self.batches_to_provider(batches)?; + self.execute_uncommitted_impl(provider).await + } + + /// Wrap materialized batches in a multi-partition in-memory [`MemTable`]. + fn batches_to_provider(&self, batches: Vec) -> Result> { + let schema = batches + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| Arc::new(Schema::from(self.dataset.schema()))); + // Spread batches across partitions so the source can be scanned in parallel + // and reports per-partition statistics. A single inner Vec would be one + // partition with no parallelism. + let partitions = Self::batches_into_partitions(batches); + Ok(Arc::new(MemTable::try_new(schema, partitions)?)) + } + + /// Distribute batches round-robin across up to `num_compute_intensive_cpus` + /// partitions, so a [`MemTable`] built from them can be scanned in parallel. + /// Always returns at least one (possibly empty) partition so an empty source + /// still produces a valid provider. + fn batches_into_partitions(batches: Vec) -> Vec> { + let num_partitions = batches.len().min(get_num_compute_intensive_cpus()).max(1); + let mut partitions = vec![Vec::new(); num_partitions]; + for (idx, batch) in batches.into_iter().enumerate() { + partitions[idx % num_partitions].push(batch); + } + partitions + } + + /// Wrap a one-shot stream source in a provider, returning whether it can be + /// replayed across retries. + /// + /// With retries enabled and spilling allowed, the stream is drained into a + /// replayable spill (memory up to 100MB, then disk). Otherwise the stream is + /// wrapped in a non-replayable one-shot provider and any conflict fails fast. + async fn stream_source_to_provider( + &self, + source: SendableRecordBatchStream, + ) -> Result<(Arc, bool)> { + if self.params.conflict_retries > 0 && self.params.spill_for_retry { + // Allow buffering up to 100MB in memory before spilling to disk. + let provider = spilling_table_provider(source, 100 * 1024 * 1024).await?; + Ok((provider, true)) + } else { + Ok((one_shot_provider(source)?, false)) + } + } + + /// Run the retry loop against a provider, re-scanning it on each attempt. + /// + /// `replayable` indicates whether the provider can be scanned more than once. + /// When it cannot (a one-shot stream that was not spilled), retries are + /// disabled so we never scan it twice; the operation runs once and surfaces any + /// commit conflict directly. + async fn execute_inner( + self, + provider: Arc, + replayable: bool, + ) -> Result<(Arc, MergeStats)> { let dataset = self.dataset.clone(); let config = RetryConfig { - max_retries: self.params.conflict_retries, + max_retries: if replayable { + self.params.conflict_retries + } else { + 0 + }, retry_timeout: self.params.retry_timeout, }; - let wrapper = MergeInsertJobWithIterator { + let wrapper = MergeInsertJobWithProvider { job: self, - source_iter: Arc::new(Mutex::new(source_iter)), + provider, attempt_count: Arc::new(AtomicU32::new(0)), }; @@ -1435,7 +1588,8 @@ impl MergeInsertJob { source: impl StreamingWriteSource, ) -> Result { let stream = source.into_stream(); - self.execute_uncommitted_impl(stream).await + self.execute_uncommitted_impl(one_shot_provider(stream)?) + .await } fn create_plan_join_type(&self) -> JoinType { @@ -1453,17 +1607,13 @@ impl MergeInsertJob { } } - async fn create_plan( - self, - source: SendableRecordBatchStream, - ) -> Result> { + async fn create_plan(self, provider: Arc) -> Result> { // Goal: we shouldn't manually have to specify which columns to scan. // DataFusion's optimizer should be able to automatically perform // projection pushdown for us. // Goal: we shouldn't have to add new branches in this code to handle // indexed vs non-indexed cases. That should be handled by optimizer rules. - let session_config = SessionConfig::default(); - let session_ctx = SessionContext::new_with_config(session_config); + let session_ctx = SessionContext::new(); let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?; // Wrap column names in double quotes to preserve case (DataFusion lowercases unquoted identifiers) let on_cols = self @@ -1473,7 +1623,11 @@ impl MergeInsertJob { .map(|name| format!("\"{}\"", name)) .collect::>(); let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::>(); - let source_df = session_ctx.read_one_shot(source)?; + // Plan against the provider directly so its statistics reach the optimizer. + // The merge write node requires a single-partition input, so the optimizer + // coalesces a multi-partition provider for us (see + // `FullSchemaMergeInsertExec::required_input_distribution`). + let source_df = session_ctx.read_table(provider)?; // Capture the source field names *before* aliasing / joining so we // can tell which dataset columns are missing from the source and // need to be filled from the target side of the join below. @@ -1552,14 +1706,14 @@ impl MergeInsertJob { async fn execute_uncommitted_v2( self, - source: SendableRecordBatchStream, + provider: Arc, ) -> Result<( Transaction, MergeStats, Option, Option, )> { - let plan = self.create_plan(source).await?; + let plan = self.create_plan(provider).await?; // Execute the plan // Assert that we have exactly one partition since we're designed for single-partition execution @@ -1734,14 +1888,14 @@ impl MergeInsertJob { async fn execute_uncommitted_impl( self, - source: SendableRecordBatchStream, + provider: Arc, ) -> Result { // Check if we can use the fast path - let can_use_fast_path = self.can_use_create_plan(source.schema().as_ref()).await?; + let can_use_fast_path = self.can_use_create_plan(provider.schema().as_ref()).await?; if can_use_fast_path { let (transaction, stats, affected_rows, inserted_rows_filter) = - self.execute_uncommitted_v2(source).await?; + self.execute_uncommitted_v2(provider).await?; return Ok(UncommittedMergeInsert { transaction, affected_rows, @@ -1750,6 +1904,8 @@ impl MergeInsertJob { }); } + // The slow path consumes a single stream; adapt the provider back into one. + let source = provider_to_stream(provider).await?; let source_schema = source.schema(); let lance_schema = lance_core::datatypes::Schema::try_from(source_schema.as_ref())?; let full_schema = self.dataset.schema(); @@ -1995,7 +2151,9 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); - let plan = cloned_job.create_plan(Box::pin(stream)).await?; + let plan = cloned_job + .create_plan(one_shot_provider(Box::pin(stream))?) + .await?; let display = DisplayableExecutionPlan::new(plan.as_ref()); Ok(format!("{}", display.indent(verbose))) @@ -2028,7 +2186,7 @@ impl MergeInsertJob { // Clone self since create_plan consumes the job let cloned_job = self.clone(); - let plan = cloned_job.create_plan(source).await?; + let plan = cloned_job.create_plan(one_shot_provider(source)?).await?; // Use the analyze_plan function from lance_datafusion, but strip out the wrapper lines let options = LanceExecutionOptions::default(); @@ -2078,15 +2236,15 @@ pub struct UncommittedMergeInsert { pub inserted_rows_filter: Option, } -/// Wrapper struct that combines MergeInsertJob with the source iterator for retry functionality +/// Wrapper struct that combines MergeInsertJob with the source provider for retry functionality #[derive(Clone)] -struct MergeInsertJobWithIterator { +struct MergeInsertJobWithProvider { job: MergeInsertJob, - source_iter: Arc + Send + 'static>>>, + provider: Arc, attempt_count: Arc, } -impl RetryExecutor for MergeInsertJobWithIterator { +impl RetryExecutor for MergeInsertJobWithProvider { type Data = UncommittedMergeInsert; type Result = (Arc, MergeStats); @@ -2094,10 +2252,11 @@ impl RetryExecutor for MergeInsertJobWithIterator { // Increment attempt counter self.attempt_count.fetch_add(1, Ordering::SeqCst); - // We need to get a fresh stream for each retry attempt - // The source_iter provides unlimited streams from the same source data - let stream = self.source_iter.lock().unwrap().next().unwrap(); - self.job.clone().execute_uncommitted_impl(stream).await + // Re-scan the provider on each retry attempt. + self.job + .clone() + .execute_uncommitted_impl(self.provider.clone()) + .await } async fn commit(&self, dataset: Arc, mut data: Self::Data) -> Result { @@ -5422,7 +5581,10 @@ mod tests { let new_data = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16)); let new_data_stream = reader_to_stream(Box::new(new_data)); - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap()) + .await + .unwrap(); // Assert the plan structure using portable plan matching // The optimized plan should have: @@ -5475,7 +5637,10 @@ mod tests { let new_data_stream = reader_to_stream(Box::new(new_data)); // This should use the fast path (execute_uncommitted_v2) - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap()) + .await + .unwrap(); // The optimized plan should use Inner join instead of Right join since we're not // inserting unmatched rows. The sentinel IS NOT NULL condition is folded away by @@ -5523,7 +5688,10 @@ mod tests { let new_data_reader = new_data.into_reader_rows(RowCount::from(512), BatchCount::from(16)); let new_data_stream = reader_to_stream(Box::new(new_data_reader)); - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap()) + .await + .unwrap(); // The optimized plan should use Inner join and include the UpdateIf condition. // The sentinel IS NOT NULL condition is folded away (sentinel is lit(true)). @@ -5575,7 +5743,10 @@ mod tests { // Should reach the v2 fast path (`create_plan` + FullSchemaMergeInsertExec). // Dropping to v1 here would return an error from create_plan instead. - let plan = merge_insert_job.create_plan(new_data_stream).await.unwrap(); + let plan = merge_insert_job + .create_plan(one_shot_provider(new_data_stream).unwrap()) + .await + .unwrap(); // The join is Right because we keep unmatched source rows (InsertAll) // but discard unmatched target rows (DoNothing on when_matched, @@ -8500,7 +8671,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap()) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8582,7 +8756,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], id_only_schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap()) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8664,7 +8841,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap()) + .await + .unwrap(); assert_plan_node_equals( plan, "MergeInsert: on=[key], when_matched=Delete, when_not_matched=InsertAll, when_not_matched_by_source=Keep...THEN 2 WHEN...THEN 3 ELSE 0 END as __action]...projection=[key, value, filterme]" @@ -8769,7 +8949,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(non_matching_batch.clone())], schema.clone(), ))); - let plan = plan_job.create_plan(plan_stream).await.unwrap(); + let plan = plan_job + .create_plan(one_shot_provider(plan_stream).unwrap()) + .await + .unwrap(); assert_plan_node_equals( plan, "DeleteOnlyMergeInsert: on=[key], when_matched=Delete, when_not_matched=DoNothing @@ -8889,7 +9072,10 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n [Ok(new_batch)], schema.clone(), ))); - let plan = job.create_plan(plan_stream).await.unwrap(); + let plan = job + .create_plan(one_shot_provider(plan_stream).unwrap()) + .await + .unwrap(); let plan_str = datafusion::physical_plan::displayable(plan.as_ref()) .indent(true) @@ -8980,7 +9166,7 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n impl std::error::Error for MyTestError {} #[tokio::test] - async fn test_merge_insert_execute_reader_preserves_external_error() { + async fn test_merge_insert_execute_reader_preserves_error_message() { let schema = Arc::new(ArrowSchema::new(vec![ ArrowField::new("key", DataType::Int32, false), ArrowField::new("value", DataType::Int32, false), @@ -9017,14 +9203,14 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n .execute_reader(Box::new(reader) as Box) .await; - match result { - Err(Error::External { source }) => { - let original = source.downcast_ref::().unwrap(); - assert_eq!(original.code, error_code); - } - Err(other) => panic!("Expected External, got: {:?}", other), - Ok(_) => panic!("Expected error"), - } + // The source error is routed through the merge plan, which shares it + // across join partitions, so its concrete type is not recoverable. The + // message must still reach the caller. + let err = result.expect_err("expected the source error to surface"); + assert!( + err.to_string().contains("merge insert failure"), + "source error message should be preserved; got: {err}" + ); } } @@ -10058,4 +10244,296 @@ MergeInsert: on=[id], when_matched=DoNothing, when_not_matched=InsertAll, when_n "Newly written merge-insert data files should be cleaned up on apply_deletions failure" ); } + + fn id_value_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("value", DataType::UInt32, false), + ])) + } + + /// `execute_provider` is the canonical entry point; a `MemTable` source merges + /// the same way a stream does. + #[tokio::test] + async fn test_merge_insert_execute_provider() { + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Update id=1, insert id=3. + let new_data = record_batch!(("id", UInt32, [1, 3]), ("value", UInt32, [10, 30])).unwrap(); + let provider: Arc = Arc::new( + datafusion::datasource::MemTable::try_new(new_data.schema(), vec![vec![new_data]]) + .unwrap(), + ); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_provider(provider) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 1); + assert_eq!(stats.num_inserted_rows, 1); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!( + merged_rows, + HashMap::from([(0, 0), (1, 10), (2, 0), (3, 30)]) + ); + } + + /// `execute_batches` merges materialized batches; multiple batches are spread + /// across partitions and merged correctly. + #[tokio::test] + async fn test_merge_insert_execute_batches() { + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Two batches: update id=1 (batch 0), insert id=3 (batch 1). + let batch0 = record_batch!(("id", UInt32, [1]), ("value", UInt32, [10])).unwrap(); + let batch1 = record_batch!(("id", UInt32, [3]), ("value", UInt32, [30])).unwrap(); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_batches(vec![batch0, batch1]) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 1); + assert_eq!(stats.num_inserted_rows, 1); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!( + merged_rows, + HashMap::from([(0, 0), (1, 10), (2, 0), (3, 30)]) + ); + } + + /// An empty batch list still produces a valid (single, empty) partition, so the + /// merge is a no-op and the target is unchanged. + #[tokio::test] + async fn test_merge_insert_execute_batches_empty() { + let initial = + record_batch!(("id", UInt32, [0, 1, 2]), ("value", UInt32, [0, 0, 0])).unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + let (merged, stats) = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .try_build() + .unwrap() + .execute_batches(vec![]) + .await + .unwrap(); + + assert_eq!(stats.num_updated_rows, 0); + assert_eq!(stats.num_inserted_rows, 0); + + let batch = merged.scan().try_into_batch().await.unwrap(); + let ids = batch["id"].as_primitive::(); + let values = batch["value"].as_primitive::(); + let merged_rows: HashMap = ids + .values() + .iter() + .zip(values.values().iter()) + .map(|(id, value)| (*id, *value)) + .collect(); + assert_eq!(merged_rows, HashMap::from([(0, 0), (1, 0), (2, 0)])); + } + + fn collect_exact_row_counts(plan: &Arc, out: &mut Vec) { + if let Ok(stats) = plan.partition_statistics(None) + && let datafusion::common::stats::Precision::Exact(n) = stats.num_rows + { + out.push(n); + } + for child in plan.children() { + collect_exact_row_counts(child, out); + } + } + + /// Use case 3: planning against the provider exposes its exact source + /// statistics to the optimizer. + #[tokio::test] + async fn test_merge_insert_source_statistics_in_plan() { + let schema = id_value_schema(); + let target_rows = 1000u32; + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..target_rows)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 0, + target_rows as usize, + ))), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // A small source whose exact row count is distinct from the target's. + let source_rows = 10usize; + let new_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from_iter_values(0..source_rows as u32)), + Arc::new(UInt32Array::from_iter_values(std::iter::repeat_n( + 1, + source_rows, + ))), + ], + ) + .unwrap(); + let provider: Arc = + Arc::new(MemTable::try_new(schema.clone(), vec![vec![new_data]]).unwrap()); + + let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .try_build() + .unwrap(); + + // The provider's exact row count reaches the plan's statistics. + let plan = job.create_plan(provider).await.unwrap(); + let mut row_counts = Vec::new(); + collect_exact_row_counts(&plan, &mut row_counts); + assert!( + row_counts.contains(&source_rows), + "source provider's exact row count ({source_rows}) should reach the plan; got {row_counts:?}" + ); + } + + /// With a one-shot stream source and `spill_for_retry(false)`, a commit + /// conflict fails fast instead of replaying the stream. The non-replayable + /// one-shot provider must be scanned exactly once (scanning it twice would + /// panic), proving retries are disabled even though `conflict_retries > 0`. + #[tokio::test] + async fn test_merge_insert_spill_for_retry_false_fails_fast() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt32, false).with_metadata( + vec![( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + )] + .into_iter() + .collect(), + ), + Field::new("value", DataType::UInt32, false), + ])); + let initial = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0, 0, 0, 0])), + ], + ) + .unwrap(); + let dataset = Arc::new( + InsertBuilder::new("memory://") + .execute(vec![initial]) + .await + .unwrap(), + ); + + // Merge insert job based on version 1, with retries enabled but spilling off. + let new_data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![100])), + Arc::new(UInt32Array::from(vec![1])), + ], + ) + .unwrap(); + let job = MergeInsertBuilder::try_new(dataset.clone(), vec!["id".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll) + .conflict_retries(10) + .spill_for_retry(false) + .try_build() + .unwrap(); + + // An append commits first (version 2), so the merge built on version 1 hits + // an unresolvable conflict on commit. + let append_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![50])), + Arc::new(UInt32Array::from(vec![2])), + ], + ) + .unwrap(); + InsertBuilder::new(dataset.clone()) + .with_params(&WriteParams { + mode: WriteMode::Append, + ..Default::default() + }) + .execute(vec![append_batch]) + .await + .unwrap(); + + let source = RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::iter(vec![Ok(new_data)]), + ); + let merge_result = job + .execute(Box::pin(source) as SendableRecordBatchStream) + .await; + + assert!( + matches!( + merge_result, + Err(crate::Error::TooMuchWriteContention { .. }) + ), + "Expected fail-fast TooMuchWriteContention, got: {:?}", + merge_result + ); + } }