Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
from .lance import __version__ as __version__
from .lance import _Session as Session
from .query import FullTextQuery
from .types import _coerce_reader
from .types import _coerce_reader, _is_materialized
from .udf import BatchUDF, normalize_transform
from .udf import BatchUDFCheckpoint as BatchUDFCheckpoint
from .udf import batch_udf as batch_udf
Expand Down Expand Up @@ -392,6 +392,12 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
"""
reader = _coerce_reader(data_obj, schema)

# Materialized sources are wrapped in an in-memory table so retries never
# spill and the source's statistics can drive the join; everything else is
# treated as a one-shot stream.
if _is_materialized(data_obj):
return super(MergeInsertBuilder, self).execute_batches(reader)

return super(MergeInsertBuilder, self).execute(reader)

def execute_uncommitted(
Expand All @@ -416,6 +422,9 @@ def execute_uncommitted(
"""
reader = _coerce_reader(data_obj, schema)

if _is_materialized(data_obj):
return super(MergeInsertBuilder, self).execute_uncommitted_batches(reader)

return super(MergeInsertBuilder, self).execute_uncommitted(reader)

# These next three overrides exist only to document the methods
Expand Down
7 changes: 7 additions & 0 deletions python/python/lance/lance/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,13 @@ class _MergeInsertBuilder:
def when_not_matched_insert_all(self) -> Self: ...
def when_not_matched_by_source_delete(self, expr: Optional[str] = None) -> Self: ...
def execute(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ...
def execute_batches(self, new_data: pa.RecordBatchReader) -> ExecuteResult: ...
def execute_uncommitted(
self, new_data: pa.RecordBatchReader
) -> tuple[Transaction, ExecuteResult]: ...
def execute_uncommitted_batches(
self, new_data: pa.RecordBatchReader
) -> tuple[Transaction, ExecuteResult]: ...

class _Scanner:
@property
Expand Down
28 changes: 28 additions & 0 deletions python/python/lance/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ def _casting_recordbatch_iter(
yield batch


def _is_materialized(data_obj: ReaderLike) -> bool:
"""Whether ``data_obj`` is fully materialized in memory.

Materialized sources (tables, in-memory frames) can be wrapped in an
in-memory table for replay without spilling and to expose exact statistics.
Streaming or re-readable sources (readers, scanners, datasets, generators)
are not considered materialized.
"""
if _check_for_pandas(data_obj) and isinstance(data_obj, pd.DataFrame):
return True
if isinstance(data_obj, (pa.Table, pa.RecordBatch)):
return True
if (
type(data_obj).__module__.startswith("polars")
and data_obj.__class__.__name__ == "DataFrame"
):
return True
if isinstance(data_obj, dict):
return True
if (
isinstance(data_obj, list)
and len(data_obj) > 0
and isinstance(data_obj[0], dict)
):
return True
return False


def _coerce_reader(
data_obj: ReaderLike, schema: Optional[pa.Schema] = None
) -> pa.RecordBatchReader:
Expand Down
23 changes: 23 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2342,6 +2342,29 @@ def test_merge_insert(tmp_path: Path):
check_merge_stats(merge_dict, (None, None, None))


@pytest.mark.parametrize("materialized", [True, False])
def test_merge_insert_input_kinds(tmp_path: Path, materialized: bool):
# A materialized pa.Table is routed through the in-memory (MemTable) path,
# while a RecordBatchReader is routed through the streaming path. Both must
# produce identical results.
schema = pa.schema([pa.field("id", pa.int64()), pa.field("value", pa.int64())])
base = pa.table({"id": range(5), "value": [0] * 5}, schema=schema)
new = pa.table({"id": [1, 2, 5, 6], "value": [10, 20, 50, 60]}, schema=schema)

dataset = lance.write_dataset(base, tmp_path / "dataset", mode="create")
source = new if materialized else new.to_reader()

dataset.merge_insert(
"id"
).when_matched_update_all().when_not_matched_insert_all().execute(source)

result = dataset.to_table().sort_by("id").to_pydict()
assert result == {
"id": [0, 1, 2, 3, 4, 5, 6],
"value": [0, 10, 20, 0, 0, 50, 60],
}


def test_merge_insert_subcols(tmp_path: Path):
initial_data = pa.table(
{
Expand Down
54 changes: 54 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,60 @@ impl MergeInsertBuilder {
Ok((PyLance(transaction), stats))
}

/// Execute the merge insert from fully-materialized data.
///
/// The data is read into memory and wrapped in an in-memory table, so retries
/// never spill to disk and the source's statistics drive the join. Callers
/// should only route in-memory inputs (e.g. a `pa.Table`) here.
pub fn execute_batches(&mut self, new_data: &Bound<PyAny>) -> PyResult<Py<PyAny>> {
let py = new_data.py();
let reader = convert_reader(new_data)?;
let batches = reader
.collect::<std::result::Result<Vec<RecordBatch>, _>>()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let job = self
.builder
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let (new_dataset, stats) = rt()
.spawn(Some(py), job.execute_batches(batches))?
.map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?;

let dataset = self.dataset.bind(py);
dataset.borrow_mut().ds = new_dataset;

Ok(Self::build_stats(&stats, py)?.into())
}

/// [`Self::execute_batches`] without committing; returns the transaction.
pub fn execute_uncommitted_batches<'a>(
&mut self,
new_data: &Bound<'a, PyAny>,
) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
let py = new_data.py();
let reader = convert_reader(new_data)?;
let batches = reader
.collect::<std::result::Result<Vec<RecordBatch>, _>>()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let job = self
.builder
.try_build()
.map_err(|err| PyValueError::new_err(err.to_string()))?;

let UncommittedMergeInsert {
transaction, stats, ..
} = rt()
.spawn(Some(py), job.execute_uncommitted_batches(batches))?
.map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?;

let stats = Self::build_stats(&stats, py)?;

Ok((PyLance(transaction), stats))
}

#[pyo3(signature=(schema = None, verbose = false))]
pub fn explain_plan(
&mut self,
Expand Down
11 changes: 11 additions & 0 deletions rust/lance-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,17 @@ impl From<datafusion_common::DataFusionError> for Error {
Self::not_supported_source(box_error(e))
}
datafusion_common::DataFusionError::Execution(..) => Self::execution(e.to_string()),
datafusion_common::DataFusionError::Shared(shared) => {
// DataFusion shares an error across consumers (e.g. a join's
// build-side error fanned out to every probe partition) behind an
// `Arc`. If we are the sole owner we can recurse for full fidelity;
// otherwise the inner error can't be moved out, so we preserve its
// message under the execution category (its concrete type is lost).
match std::sync::Arc::try_unwrap(shared) {
Ok(inner) => Self::from(inner),
Err(shared) => Self::execution(shared.to_string()),
}
}
datafusion_common::DataFusionError::External(source) => {
// Try to downcast to lance_core::Error first
match source.downcast::<Self>() {
Expand Down
21 changes: 20 additions & 1 deletion rust/lance-datafusion/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use arrow_array::RecordBatch;
use arrow_schema::Schema as ArrowSchema;
use datafusion::physical_plan::metrics::MetricType;
use datafusion::{
catalog::streaming::StreamingTable,
catalog::{TableProvider, streaming::StreamingTable},
dataframe::DataFrame,
execution::{
TaskContext,
Expand Down Expand Up @@ -867,6 +867,25 @@ impl SessionContextExt for SessionContext {
}
}

/// Scan a [`TableProvider`] into a single-partition [`SendableRecordBatchStream`].
///
/// Multi-partition providers are coalesced into a single partition. This adapts a
/// re-scannable provider back into the one stream the writer pipeline consumes;
/// re-scanning the same provider (e.g. on a write retry) yields a fresh stream.
pub async fn provider_to_stream(
provider: Arc<dyn TableProvider>,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new();
let plan = provider.scan(&ctx.state(), None, &[], None).await?;
let plan: Arc<dyn ExecutionPlan> =
if plan.properties().output_partitioning().partition_count() > 1 {
Arc::new(CoalescePartitionsExec::new(plan))
} else {
plan
};
Ok(plan.execute(0, ctx.task_ctx())?)
}

#[derive(Clone, Debug)]
pub struct StrictBatchSizeExec {
input: Arc<dyn ExecutionPlan>,
Expand Down
103 changes: 101 additions & 2 deletions rust/lance-datafusion/src/spill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ use std::{

use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, Schema};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use datafusion::{
execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
catalog::{TableProvider, streaming::StreamingTable},
execution::{SendableRecordBatchStream, TaskContext},
physical_plan::{stream::RecordBatchStreamAdapter, streaming::PartitionStream},
};
use datafusion_common::DataFusionError;
use futures::StreamExt;
use lance_arrow::memory::MemoryAccumulator;
use lance_core::error::LanceOptionExt;
use lance_core::utils::tempfile::TempDir;

/// Start a spill of Arrow data to a file that can be read later multiple times.
///
Expand Down Expand Up @@ -60,6 +64,101 @@ pub fn create_replay_spill(
(sender, receiver)
}

/// Wrap a one-shot [`SendableRecordBatchStream`] in a re-scannable [`TableProvider`].
///
/// The source is drained in the background into a replayable spill: up to
/// `memory_limit` bytes are buffered in memory before spilling to a temporary file
/// on disk. Each scan of the returned provider replays the full source from the
/// spill, which is what makes a one-shot stream usable in the write retry loop.
///
/// The provider reports no statistics — the source size is not known until it has
/// been fully drained — so callers that need source statistics (e.g. to drive join
/// ordering) should prefer a materialized or file-backed provider instead.
pub async fn spilling_table_provider(
mut source: SendableRecordBatchStream,
memory_limit: usize,
) -> Result<Arc<dyn TableProvider>, DataFusionError> {
let schema = source.schema();
let tmp_dir = tokio::task::spawn_blocking(TempDir::try_new)
.await
.map_err(|e| DataFusionError::Execution(format!("Failed to spawn temp dir task: {e}")))?
.map_err(|e| DataFusionError::Execution(format!("Failed to create temp dir: {e}")))?;
let tmp_path = tmp_dir.std_path().join("spill.arrows");
let (mut sender, receiver) = create_replay_spill(tmp_path, schema.clone(), memory_limit);

// Drain the one-shot source into the spill once, in the background. The spill
// tees to memory/disk so the first reader can consume batches as they arrive
// while later readers replay the complete source.
let drain_handle = tokio::task::spawn(async move {
let mut errored = false;
while let Some(res) = source.next().await {
match res {
Ok(batch) => {
if let Err(e) = sender.write(batch).await {
sender.send_error(e);
errored = true;
break;
}
}
Err(e) => {
sender.send_error(e);
errored = true;
break;
}
}
}
// Only finish on a clean drain. Calling finish() after an error would
// overwrite the original (replayable) error with a generic one, losing
// the source error's type (e.g. an external error from user code).
if !errored && let Err(err) = sender.finish().await {
sender.send_error(err);
}
sender
});

let partition = Arc::new(SpillPartition {
schema: schema.clone(),
receiver,
_tmp_dir: Arc::new(tmp_dir),
_drain_handle: Arc::new(drain_handle),
});
Ok(Arc::new(StreamingTable::try_new(schema, vec![partition])?))
}

/// A [`PartitionStream`] backed by a replayable spill.
///
/// Each call to [`PartitionStream::execute`] opens a fresh stream over the spill,
/// so the partition can be scanned repeatedly. The spill file and the background
/// task draining the source are kept alive for as long as this partition exists.
struct SpillPartition {
schema: SchemaRef,
receiver: SpillReceiver,
// The spilled data lives in this temp dir; dropping it deletes the spill file.
_tmp_dir: Arc<TempDir>,
// Keeps the background drain task (which owns the `SpillSender`) alive. The
// `SpillSender` must outlive the readers or they error out, so we hold the
// handle rather than detaching it.
_drain_handle: Arc<tokio::task::JoinHandle<SpillSender>>,
}

impl std::fmt::Debug for SpillPartition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SpillPartition")
.field("schema", &self.schema)
.finish()
}
}

impl PartitionStream for SpillPartition {
fn schema(&self) -> &SchemaRef {
&self.schema
}

fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
self.receiver.read()
}
}

#[derive(Clone)]
pub struct SpillReceiver {
status_receiver: tokio::sync::watch::Receiver<WriteStatus>,
Expand Down
Loading
Loading