Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 120 additions & 13 deletions rust/lance/src/dataset/write/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1438,14 +1438,35 @@ impl MergeInsertJob {
self.execute_uncommitted_impl(stream).await
}

/// Join type for the `create_plan` fast path, which builds the join as
/// `source.join(target)` — i.e. the SOURCE is the left input and the TARGET
/// is the right input.
///
/// At scale the optimizer plans this as a `Partitioned` hash join whose
/// build (left) side is hashed and held in memory per partition. Putting the
/// (typically small) source there keeps the per-partition hash tables
/// bounded by the source size, while the (potentially huge) target streams
/// through as the probe side. Neither input carries row statistics the
/// optimizer can compare (the source is a one-shot stream), so
/// `should_swap_join_order` is `false` and the operands are kept as written
/// rather than swapped to a target build side — which would materialize the
/// entire target per partition and can exhaust memory when the target is
/// large and several partitions run concurrently.
///
/// Because the source is the left input, the operands are ordered
/// `(keep_unmatched_source_rows, keep_unmatched_target_rows)`: keeping
/// unmatched left (source) rows is a `Left` join, keeping unmatched right
/// (target) rows is a `Right` join. Every column is referenced downstream by
/// qualified name, so the join output is semantically identical to a
/// target-left orientation.
fn create_plan_join_type(&self) -> JoinType {
let keep_unmatched_source_rows = self.params.insert_not_matched;
let keep_unmatched_target_rows = !matches!(
self.params.delete_not_matched_by_source,
WhenNotMatchedBySource::Keep
);

match (keep_unmatched_target_rows, keep_unmatched_source_rows) {
match (keep_unmatched_source_rows, keep_unmatched_target_rows) {
(false, false) => JoinType::Inner,
(false, true) => JoinType::Right,
(true, false) => JoinType::Left,
Expand Down Expand Up @@ -1492,16 +1513,15 @@ impl MergeInsertJob {
.map_err(crate::Error::from)?;
let source_df_aliased = source_df.alias("source")?;
let scan_aliased = scan.alias("target")?;
// Build the join as source.join(target) so the (typically small) source
// is the hash join's build side and the (potentially huge) target is
// streamed as the probe side. See `create_plan_join_type` for why this
// orientation is kept by the optimizer and avoids materializing the
// whole target per partition.
let join_type = self.create_plan_join_type();
let dataset_schema: Schema = self.dataset.schema().into();
let mut df = scan_aliased
.join(
source_df_aliased,
join_type,
&on_cols_refs,
&on_cols_refs,
None,
)?
let mut df = source_df_aliased
.join(scan_aliased, join_type, &on_cols_refs, &on_cols_refs, None)?
.with_column(
MERGE_ACTION_COLUMN,
merge_insert_action(&self.params, Some(&dataset_schema))?,
Expand Down Expand Up @@ -5436,7 +5456,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \
row_id=true, row_addr=true, full_filter=--, refine_filter=--
Expand Down Expand Up @@ -5484,7 +5504,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand Down Expand Up @@ -5531,7 +5551,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand Down Expand Up @@ -5585,7 +5605,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand All @@ -5596,6 +5616,93 @@ mod tests {
.unwrap();
}

/// Guards the build/probe orientation of the merge-insert hash join when the
/// target is large.
///
/// Once the target is bigger than DataFusion's hash-join collect threshold
/// (`hash_join_single_partition_threshold_rows`, 128K), the join is planned
/// with `mode=Partitioned` and each partition hashes its build (left) input
/// in memory. `create_plan` therefore has to leave the small source on the
/// build side and stream the large target through as the probe side; the
/// reverse orientation would materialize the whole target per partition and
/// can run large tables out of memory.
///
/// The small plan-snapshot tests earlier in this module are blind to this:
/// at toy scale the join is `mode=CollectLeft` and the optimizer is free to
/// swap the inputs, so they keep passing even with the operands reversed.
/// This test builds the production-like `Partitioned` plan and checks that
/// the target scan (`LanceRead`) ends up as the right/probe input.
#[tokio::test]
async fn test_plan_keeps_target_on_probe_side_at_scale() {
    use datafusion::physical_plan::{displayable, joins::HashJoinExec, joins::PartitionMode};

    /// Pre-order, depth-first search returning the first `HashJoinExec`
    /// node found in `plan`, or `None` when the tree has no hash join.
    fn find_hash_join(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
        if plan.as_any().is::<HashJoinExec>() {
            Some(plan.clone())
        } else {
            plan.children().into_iter().find_map(find_hash_join)
        }
    }

    /// Renders a plan subtree using the indented, verbose display format.
    fn render(node: &dyn ExecutionPlan) -> String {
        displayable(node).indent(true).to_string()
    }

    // Target: 8 batches x 50_000 rows = 400K rows, i.e. more than 128K, so
    // the target cannot be collected and the optimizer has to plan a
    // Partitioned (not CollectLeft) hash join.
    let target_reader = lance_datagen::gen_batch()
        .with_seed(Seed::from(1))
        .col("value", array::step::<UInt32Type>())
        .col("key", array::step::<UInt64Type>())
        .into_reader_rows(RowCount::from(50_000), BatchCount::from(8));
    let target = Dataset::write(target_reader, "memory://", None)
        .await
        .unwrap();

    let merge_job =
        crate::dataset::MergeInsertBuilder::try_new(Arc::new(target), vec!["key".to_string()])
            .unwrap()
            .when_matched(crate::dataset::WhenMatched::UpdateAll)
            .when_not_matched(crate::dataset::WhenNotMatched::InsertAll)
            .try_build()
            .unwrap();

    // Source: a single 1000-row batch — the input that should be hashed/built.
    let source_reader = lance_datagen::gen_batch()
        .with_seed(Seed::from(2))
        .col("value", array::step::<UInt32Type>())
        .col("key", array::step::<UInt64Type>())
        .into_reader_rows(RowCount::from(1000), BatchCount::from(1));
    let plan = merge_job
        .create_plan(reader_to_stream(Box::new(source_reader)))
        .await
        .unwrap();

    // Full plan text, kept around for the failure messages below.
    let rendered = render(plan.as_ref());

    // Find the hash join node and view it through its concrete type.
    let Some(join_node) = find_hash_join(&plan) else {
        panic!("expected a HashJoinExec in the plan:\n{rendered}");
    };
    let hash_join = join_node
        .as_any()
        .downcast_ref::<HashJoinExec>()
        .expect("HashJoinExec");

    // With a target this large the join has to be Partitioned; a CollectLeft
    // plan would not exercise the orientation being guarded here.
    assert_eq!(
        hash_join.partition_mode(),
        &PartitionMode::Partitioned,
        "expected a Partitioned hash join at scale; plan was:\n{rendered}"
    );

    // The target scan has to sit under the right (probe) input and must not
    // appear under the left (build) input — the reason `create_plan` builds
    // source.join(target).
    let build_side_scans_target = render(hash_join.left().as_ref()).contains("LanceRead");
    let probe_side_scans_target = render(hash_join.right().as_ref()).contains("LanceRead");
    assert!(
        probe_side_scans_target && !build_side_scans_target,
        "target (LanceRead) must be the probe (right) side of the hash join so it \
         is streamed rather than materialized; plan was:\n{rendered}"
    );
}

#[tokio::test]
async fn test_skip_auto_cleanup() {
let tmpdir = TempStrDir::default();
Expand Down
Loading