diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..951cf943452 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -1438,6 +1438,27 @@ impl MergeInsertJob { self.execute_uncommitted_impl(stream).await } + /// Join type for the `create_plan` fast path, which builds the join as + /// `source.join(target)` — i.e. the SOURCE is the left input and the TARGET + /// is the right input. + /// + /// At scale the optimizer plans this as a `Partitioned` hash join whose + /// build (left) side is hashed and held in memory per partition. Putting the + /// (typically small) source there keeps the per-partition hash tables + /// bounded by the source size, while the (potentially huge) target streams + /// through as the probe side. Neither input carries row statistics the + /// optimizer can compare (the source is a one-shot stream), so + /// `should_swap_join_order` is `false` and the operands are kept as written + /// rather than swapped to a target build side — which would materialize the + /// entire target per partition and can exhaust memory when the target is + /// large and several partitions run concurrently. + /// + /// Because the source is the left input, the operands are ordered + /// `(keep_unmatched_source_rows, keep_unmatched_target_rows)`: keeping + /// unmatched left (source) rows is a `Left` join, keeping unmatched right + /// (target) rows is a `Right` join. Every column is referenced downstream by + /// qualified name, so the join output is semantically identical to a + /// target-left orientation. fn create_plan_join_type(&self) -> JoinType { let keep_unmatched_source_rows = self.params.insert_not_matched; let keep_unmatched_target_rows = !matches!( @@ -1445,7 +1466,7 @@ impl MergeInsertJob { WhenNotMatchedBySource::Keep ); - match (keep_unmatched_target_rows, keep_unmatched_source_rows) { + match (keep_unmatched_source_rows, keep_unmatched_target_rows) { (false, false) => JoinType::Inner, (false, true) => JoinType::Right, (true, false) => JoinType::Left, @@ -1492,16 +1513,15 @@ impl MergeInsertJob { .map_err(crate::Error::from)?; let source_df_aliased = source_df.alias("source")?; let scan_aliased = scan.alias("target")?; + // Build the join as source.join(target) so the (typically small) source + // is the hash join's build side and the (potentially huge) target is + // streamed as the probe side. See `create_plan_join_type` for why this + // orientation is kept by the optimizer and avoids materializing the + // whole target per partition. let join_type = self.create_plan_join_type(); let dataset_schema: Schema = self.dataset.schema().into(); - let mut df = scan_aliased - .join( - source_df_aliased, - join_type, - &on_cols_refs, - &on_cols_refs, - None, - )? + let mut df = source_df_aliased + .join(scan_aliased, join_type, &on_cols_refs, &on_cols_refs, None)? .with_column( MERGE_ACTION_COLUMN, merge_insert_action(&self.params, Some(&dataset_schema))?, @@ -5436,7 +5456,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \ row_id=true, row_addr=true, full_filter=--, refine_filter=-- @@ -5484,7 +5504,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... @@ -5531,7 +5551,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... @@ -5585,7 +5605,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... @@ -5596,6 +5616,93 @@ mod tests { .unwrap(); } + /// Regression test for the join build-side orientation at scale. + /// + /// When the target exceeds DataFusion's hash-join collect threshold + /// (`hash_join_single_partition_threshold_rows`, 128K), the join is planned + /// as `mode=Partitioned`, and the build (left) side is hashed in memory per + /// partition. `create_plan` must keep the small source on the build side and + /// stream the large target as the probe side; otherwise the entire target is + /// materialized per partition, which can exhaust memory on large tables. + /// + /// The toy-sized plan-snapshot tests above cannot catch a regression here: + /// at small scale the join is `mode=CollectLeft` and the optimizer freely + /// swaps the sides, so they would still pass with the operands reversed. + /// This test exercises the production-representative `Partitioned` plan and + /// asserts the target (`LanceRead`) is the right/probe input. + #[tokio::test] + async fn test_plan_keeps_target_on_probe_side_at_scale() { + use datafusion::physical_plan::{displayable, joins::HashJoinExec, joins::PartitionMode}; + + // Target with > 128K rows so the target cannot be collected and the + // optimizer plans a Partitioned (not CollectLeft) hash join. + let data = lance_datagen::gen_batch() + .with_seed(Seed::from(1)) + .col("value", array::step::()) + .col("key", array::step::()); + let data = data.into_reader_rows(RowCount::from(50_000), BatchCount::from(8)); // 400K rows + let ds = Dataset::write(data, "memory://", None).await.unwrap(); + + let job = + crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()]) + .unwrap() + .when_matched(crate::dataset::WhenMatched::UpdateAll) + .when_not_matched(crate::dataset::WhenNotMatched::InsertAll) + .try_build() + .unwrap(); + + // A small source — the side that should be hashed/built. + let new_data = lance_datagen::gen_batch() + .with_seed(Seed::from(2)) + .col("value", array::step::()) + .col("key", array::step::()); + let new_data = new_data.into_reader_rows(RowCount::from(1000), BatchCount::from(1)); + let stream = reader_to_stream(Box::new(new_data)); + let plan = job.create_plan(stream).await.unwrap(); + + // Locate the HashJoinExec in the physical plan. + fn find_hash_join(plan: &Arc) -> Option> { + if plan.as_any().is::() { + return Some(plan.clone()); + } + for child in plan.children() { + if let Some(found) = find_hash_join(child) { + return Some(found); + } + } + None + } + + let rendered = format!("{}", displayable(plan.as_ref()).indent(true)); + let hash_join = find_hash_join(&plan) + .unwrap_or_else(|| panic!("expected a HashJoinExec in the plan:\n{rendered}")); + let hash_join = hash_join + .as_any() + .downcast_ref::() + .expect("HashJoinExec"); + + // At this scale the join must be Partitioned, not CollectLeft. + assert_eq!( + hash_join.partition_mode(), + &PartitionMode::Partitioned, + "expected a Partitioned hash join at scale; plan was:\n{rendered}" + ); + + // The target scan must be the right (probe) input, not the left (build) + // input — that is the whole point of building source.join(target). + let right_has_lance_scan = + format!("{}", displayable(hash_join.right().as_ref()).indent(true)) + .contains("LanceRead"); + let left_has_lance_scan = + format!("{}", displayable(hash_join.left().as_ref()).indent(true)) + .contains("LanceRead"); + assert!( + right_has_lance_scan && !left_has_lance_scan, + "target (LanceRead) must be the probe (right) side of the hash join so it \ + is streamed rather than materialized; plan was:\n{rendered}" + ); + } + #[tokio::test] async fn test_skip_auto_cleanup() { let tmpdir = TempStrDir::default();