From 4a31cea63e232b93ac985949f9f66a78edf4f1f4 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 13:29:16 -0700 Subject: [PATCH 1/7] perf(merge_insert): defer reading non-source columns via late-materialization rule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Partial-schema upserts fill missing (non-source) columns from the target side of the join via `col("target.<column>")` above the join, so projection pushdown keeps those columns in the target scan — reading wide columns for every target row even when only a few rows match. This adds a late-materialization physical optimizer rule that drops those carried-through columns from the target scan and re-fetches them by `_rowaddr` with a `TakeExec` inserted above the join. The parent projection (including the `__action` expression) is re-indexed by name onto the take, so the plan's output schema is unchanged. The rule is applied only at the merge_insert call site, not the session-wide optimizer, to bound its blast radius. A width/storage gate (reusing the scanner's late-materialization heuristic) only defers columns wide enough that re-fetching beats scanning. `TakeExec` is now null-tolerant: outer-join insert rows have a null `target._rowaddr`, so the take fetches only the non-null addresses and scatters results back with NULLs at the insert positions. The taken fields are marked nullable when the take sits above an outer join. Read-side only; write-side amplification stays on #4193. 
Closes #7363 --- rust/lance/src/dataset/write/merge_insert.rs | 384 +++++++++++++++++- rust/lance/src/io/exec.rs | 2 + .../lance/src/io/exec/late_materialization.rs | 329 +++++++++++++++ rust/lance/src/io/exec/take.rs | 229 ++++++++++- 4 files changed, 927 insertions(+), 17 deletions(-) create mode 100644 rust/lance/src/io/exec/late_materialization.rs diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..3b696f5d2f1 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -1547,6 +1547,15 @@ impl MergeInsertJob { .create_physical_plan(&logical_plan, &session_state) .await?; + // Defer reading non-source columns: a partial-schema upsert reads the + // missing columns from the target side of the join only to rewrite full + // rows. This rule pushes those reads past the join so a selective match + // does not scan wide columns for every target row. It is applied only + // here (not in the session-wide optimizer) to bound its blast radius. + use datafusion::physical_optimizer::PhysicalOptimizerRule; + let physical_plan = crate::io::exec::LateMaterializeOverReducingJoin + .optimize(physical_plan, &datafusion::config::ConfigOptions::default())?; + Ok(physical_plan) } @@ -4521,20 +4530,379 @@ mod tests { "expected HashJoinExec in plan, got: {}", plan ); - // Evidence that the partial-schema fix is active: the target - // side of the join reads the `other` column (which is missing - // from the source) and an explicit projection carries it - // through to the write exec alongside source columns. + // Late materialization is active: the `other` column (missing from + // the source) is *not* read by the target scan. Instead it is + // fetched by a `Take` inserted above the join, so a selective match + // does not scan `other` for every target row. 
+ assert!( + plan.contains("LanceRead") && plan.contains("projection=[key]"), + "target-side scan should only read the join key, not `other`: {}", + plan + ); + assert!( + !plan.contains("projection=[other"), + "deferred `other` column must not be in the target scan projection: {}", + plan + ); + assert!( + plan.contains("Take") && plan.contains("(other)"), + "expected a Take above the join fetching the deferred `other` column: {}", + plan + ); + } + + /// Extract the `output_rows` metric of the first plan node whose + /// (trimmed) display line starts with `node_prefix`, from an + /// `analyze_plan` string. + fn output_rows_for_node(analysis: &str, node_prefix: &str) -> Option { + let line = analysis + .lines() + .find(|l| l.trim_start().starts_with(node_prefix))?; + let start = line.find("output_rows=")? + "output_rows=".len(); + let rest = &line[start..]; + let end = rest + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(rest.len()); + rest[..end].parse().ok() + } + + /// The read-amplification payoff: on a selective partial-schema update, + /// the wide non-source column must be fetched only for the matched rows + /// (via the `Take`), not scanned for every target row. + #[tokio::test] + async fn test_merge_insert_subcols_defers_wide_column_reads() { + // 100 rows across 4 fragments; `other` is a wide (string) column, + // `key`/`value` are narrow. Keys are 0..100 so matches are precise. + let batch = lance_datagen::gen_batch() + .with_seed(Seed::from(1)) + .col("other", array::rand_utf8(64.into(), false)) + .col("value", array::step::()) + .col("key", array::step_custom::(0, 1)) + .into_batch_rows(RowCount::from(100)) + .unwrap(); + let schema = batch.schema(); + let ds = Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 25, + ..Default::default() + }), + ) + .await + .unwrap(); + let ds = Arc::new(ds); + + // Partial-schema source (key, value) updating only 5 of 100 keys. 
+ let update_schema = Arc::new(schema.project(&[2, 1]).unwrap()); + let new_data = RecordBatch::try_new( + update_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4])), + Arc::new(UInt32Array::from(vec![1000u32, 1001, 1002, 1003, 1004])), + ], + ) + .unwrap(); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + + let source = reader_to_stream(Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + ))); + let analysis = job.analyze_plan(source).await.unwrap(); + + // The target scan reads every row but only the narrow key column... + let lance_read = analysis + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node:\n{}", analysis)); + assert!( + lance_read.contains("projection=[key]"), + "target scan must defer the wide `other` column: {}", + lance_read + ); + assert_eq!( + output_rows_for_node(&analysis, "LanceRead"), + Some(100), + "target scan should still visit all rows:\n{}", + analysis + ); + + // ...and `other` is materialized only for the 5 matched rows. + assert_eq!( + output_rows_for_node(&analysis, "Take"), + Some(5), + "wide column should be taken for only the matched rows:\n{}", + analysis + ); + } + + /// A partial-schema `UpdateIf` whose condition references the deferred + /// (non-source) target column must still evaluate correctly: the column + /// is fetched once by the `Take` and read from there by the action + /// expression, not double-fetched or lost. + #[tokio::test] + async fn test_merge_insert_subcols_update_if_on_deferred_column() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("other", DataType::Utf8, true), + ])); + // `other` gates the update; keys 0,2,4 are "keep", 1,3,5 are "skip". 
+ let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4, 5])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4, 5])), + Arc::new(StringArray::from(vec![ + "keep", "skip", "keep", "skip", "keep", "skip", + ])), + ], + ) + .unwrap(); + let ds = Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 3, // two fragments + ..Default::default() + }), + ) + .await + .unwrap(); + let ds = Arc::new(ds); + + // Partial-schema source (key, value) matching keys 0..=3. + let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let new_data = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![100u32, 100, 100, 100])), + ], + ) + .unwrap(); + let reader = Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + )); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::update_if(&ds, "target.other = 'keep'").unwrap()) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + let (updated, _stats) = job.execute_reader(reader).await.unwrap(); + + let batch = updated.scan().try_into_batch().await.unwrap(); + let keys = batch["key"].as_any().downcast_ref::().unwrap(); + let values = batch["value"] + .as_any() + .downcast_ref::() + .unwrap(); + let others = batch["other"] + .as_any() + .downcast_ref::() + .unwrap(); + let by_key = (0..batch.num_rows()) + .map(|i| { + ( + keys.value(i), + (values.value(i), others.value(i).to_string()), + ) + }) + .collect::>(); + + // Matched + condition true (other == "keep"): value updated to 100. 
+ assert_eq!(by_key[&0], (100, "keep".to_string())); + assert_eq!(by_key[&2], (100, "keep".to_string())); + // Matched + condition false (other == "skip"): value unchanged. + assert_eq!(by_key[&1], (1, "skip".to_string())); + assert_eq!(by_key[&3], (3, "skip".to_string())); + // Unmatched rows untouched. + assert_eq!(by_key[&4], (4, "keep".to_string())); + assert_eq!(by_key[&5], (5, "skip".to_string())); + } + + /// The width gate's negative branch: a *narrow* missing column must NOT + /// be deferred — a sequential scan of it is cheaper than a per-row take, + /// so it stays in the target scan and no `Take` is introduced. + #[tokio::test] + async fn test_merge_insert_subcols_narrow_column_not_deferred() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("small", DataType::UInt32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![10u32, 11, 12, 13])), + ], + ) + .unwrap(); + let ds = Arc::new( + Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + None, + ) + .await + .unwrap(), + ); + + // Source omits the narrow `small` column. 
+ let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + let plan = job.explain_plan(Some(&source_schema), false).await.unwrap(); + assert!( - plan.contains("LanceRead") && plan.contains("projection=[other"), - "target-side scan should include the filled `other` column: {}", + !plan.contains("Take"), + "a narrow column must not be deferred via a Take: {}", plan ); + let lance_read = plan + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node: {}", plan)); + assert!( + lance_read.contains("small"), + "narrow `small` should be read directly by the target scan: {}", + lance_read + ); + } + + /// Deferral must remain correct when more than one wide column is + /// dropped from the scan: this exercises the multi-column index remap of + /// the join output and the name-based re-index of the parent projection. 
+ #[tokio::test] + async fn test_merge_insert_subcols_defers_multiple_wide_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("wide_a", DataType::Utf8, true), + Field::new("wide_b", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(StringArray::from(vec!["a0", "a1", "a2", "a3"])), + Arc::new(StringArray::from(vec!["b0", "b1", "b2", "b3"])), + ], + ) + .unwrap(); + let ds = Arc::new( + Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 2, // two fragments + ..Default::default() + }), + ) + .await + .unwrap(), + ); + + // Source omits both wide columns; updates keys 0 and 1. + let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let new_data = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1])), + Arc::new(UInt32Array::from(vec![100u32, 100])), + ], + ) + .unwrap(); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + + // Both wide columns are deferred to the take, not the scan. 
+ let plan = job + .explain_plan(Some(&new_data.schema().as_ref().clone()), false) + .await + .unwrap(); + let lance_read = plan + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node: {}", plan)); + assert!( + !lance_read.contains("wide_a") && !lance_read.contains("wide_b"), + "both wide columns must be deferred out of the scan: {}", + lance_read + ); assert!( - plan.contains("other@0 as other"), - "expected post-join projection to carry `other` from the target side: {}", + plan.contains("(wide_a)") && plan.contains("(wide_b)"), + "the take must fetch both deferred columns: {}", plan ); + + // And the result is correct: matched rows updated, wide columns preserved. + let reader = Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + )); + let (updated, _stats) = job.execute_reader(reader).await.unwrap(); + let batch = updated.scan().try_into_batch().await.unwrap(); + let keys = batch["key"].as_any().downcast_ref::().unwrap(); + let values = batch["value"] + .as_any() + .downcast_ref::() + .unwrap(); + let wide_a = batch["wide_a"] + .as_any() + .downcast_ref::() + .unwrap(); + let wide_b = batch["wide_b"] + .as_any() + .downcast_ref::() + .unwrap(); + let by_key = (0..batch.num_rows()) + .map(|i| { + ( + keys.value(i), + ( + values.value(i), + wide_a.value(i).to_string(), + wide_b.value(i).to_string(), + ), + ) + }) + .collect::>(); + assert_eq!(by_key[&0], (100, "a0".to_string(), "b0".to_string())); + assert_eq!(by_key[&1], (100, "a1".to_string(), "b1".to_string())); + assert_eq!(by_key[&2], (2, "a2".to_string(), "b2".to_string())); + assert_eq!(by_key[&3], (3, "a3".to_string(), "b3".to_string())); } /// Partial-schema upserts with `insert_not_matched=InsertAll` must diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index a477d60d56d..3631db4e359 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -15,6 +15,7 @@ pub mod filtered_read; pub 
mod filtered_read_proto; pub mod fts; pub(crate) mod knn; +mod late_materialization; mod optimizer; mod projection; mod pushdown_scan; @@ -32,6 +33,7 @@ pub use filter::LanceFilterExec; pub use knn::{ANNIvfPartitionExec, ANNIvfSubIndexExec, KNNVectorDistanceExec}; pub use lance_datafusion::planner::Planner; pub use lance_index::scalar::expression::FilterPlan; +pub use late_materialization::LateMaterializeOverReducingJoin; pub use optimizer::get_physical_optimizer; pub use projection::project; pub use pushdown_scan::{LancePushdownScanExec, ScanConfig}; diff --git a/rust/lance/src/io/exec/late_materialization.rs b/rust/lance/src/io/exec/late_materialization.rs new file mode 100644 index 00000000000..76689f47958 --- /dev/null +++ b/rust/lance/src/io/exec/late_materialization.rs @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Late-materialization physical optimizer rule. +//! +//! Defers reading data columns that a row-reducing operator only carries +//! through, fetching them by `_rowaddr` after the row count has shrunk. Used by +//! `merge_insert` to avoid scanning wide non-source columns for every target +//! row of a selective partial-schema upsert. + +use std::collections::HashSet; +use std::sync::Arc; + +use arrow_schema::Schema as ArrowSchema; +use datafusion::{ + common::tree_node::{Transformed, TreeNode}, + config::ConfigOptions, + error::Result as DFResult, + logical_expr::JoinType, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + ExecutionPlan, + joins::HashJoinExec, + projection::{ProjectionExec, ProjectionExpr}, + }, +}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, expressions::Column}; +use lance_arrow::DataTypeExt; +use lance_core::datatypes::OnMissing; +use lance_core::{ROW_ADDR, ROW_ID}; + +use super::TakeExec; +use super::filtered_read::FilteredReadExec; + +/// Rewrite every [`Column`] in `expr` to reference `schema` by name. 
Used to +/// re-index a projection's expressions after the column layout of its input +/// changed (e.g. a column moved because it is now sourced from a [`TakeExec`]). +fn reindex_columns_by_name( + expr: Arc, + schema: &ArrowSchema, +) -> DFResult> { + Ok(expr + .transform_down(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + let new_col = Column::new_with_schema(col.name(), schema)?; + Ok(Transformed::yes(Arc::new(new_col) as Arc)) + } else { + Ok(Transformed::no(e)) + } + })? + .data) +} + +/// Width/storage gate mirroring the scanner's late-materialization heuristic +/// ([`crate::dataset::scanner::MaterializationStyle::Heuristic`]): a column is +/// worth deferring only if it is "wide" for the backing storage — a +/// variable-width type (strings, lists, vectors) or a fixed-width type above the +/// per-row byte threshold (1KB on cloud storage, 10 bytes on local). Narrow +/// columns are cheaper to read in the sequential scan than to re-fetch by +/// address. +/// +/// Without a join-cardinality estimate (tracked in #4583) we cannot gate on +/// match selectivity, so we fall back to width alone. This covers the +/// inherently selective backfill case the feature targets; a follow-up can +/// incorporate cardinality once it is available. +fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { + if field.is_blob() { + return false; + } + let byte_width = field.data_type().byte_width_opt(); + if is_cloud { + byte_width.is_none_or(|bw| bw >= 1000) + } else { + byte_width.is_none_or(|bw| bw >= 10) + } +} + +/// Late-materialization rule: defer reading data columns that a row-reducing +/// operator (here, a [`HashJoinExec`]) only passes through, fetching them by +/// `_rowaddr` *after* the row count has been reduced. 
+/// +/// Concretely, for a `ProjectionExec -> HashJoinExec` where the join's build +/// (left) side is a [`FilteredReadExec`] that emits `_rowaddr`, any data column +/// the scan reads but the join only carries through (not a join key, not used +/// in a join filter) is dropped from the scan, re-fetched by a [`TakeExec`] +/// inserted above the join, and the parent projection is re-indexed to read it +/// from there. The projection's *output* schema is unchanged, so nothing above +/// it is affected. +/// +/// This is written generically but is currently applied only at the +/// merge_insert call site (not registered in the session-wide optimizer), which +/// bounds its blast radius. A column missing from the source of a partial-schema +/// upsert is exactly such a "carried-through" column, so deferring it avoids +/// scanning wide columns for every target row when only a few rows match. +#[derive(Debug, Default)] +pub struct LateMaterializeOverReducingJoin; + +impl LateMaterializeOverReducingJoin { + /// Attempt the rewrite for a `ProjectionExec` sitting directly above a + /// `HashJoinExec`. Returns the rewritten projection, or `None` if the + /// pattern does not apply or cannot be safely transformed. + fn try_defer(proj: &ProjectionExec) -> DFResult>> { + let Some(join) = proj.input().as_any().downcast_ref::() else { + return Ok(None); + }; + + // Only column-preserving join types have a (left ++ right) intermediate + // schema, which the index remapping below relies on. + if !matches!( + join.join_type(), + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) { + return Ok(None); + } + + // A join filter may reference the columns we want to defer; bail rather + // than reason about its intermediate schema. + if join.filter().is_some() { + return Ok(None); + } + + // We only handle deferral on the build (left) side, which is where the + // scanned target relation sits in a merge_insert plan. 
Deferring on the + // probe side would need a symmetric (and untested) index remapping. + let Some(scan) = join.left().as_any().downcast_ref::() else { + return Ok(None); + }; + + let left_schema = join.left().schema(); + let right_schema = join.right().schema(); + let left_field_count = left_schema.fields().len(); + let right_field_count = right_schema.fields().len(); + + // The scan must emit `_rowaddr` so the deferred columns remain + // fetchable by address after the join. + if left_schema.column_with_name(ROW_ADDR).is_none() { + return Ok(None); + } + + // Join keys on the build side must stay in the scan. + let mut left_key_names = HashSet::new(); + for (left, _) in join.on() { + let Some(col) = left.as_any().downcast_ref::() else { + return Ok(None); + }; + left_key_names.insert(col.name().to_string()); + } + + let dataset = scan.dataset(); + let is_cloud = dataset.object_store.is_cloud(); + + // Candidates = scan-side data columns that aren't join keys (and aren't + // the system `_rowid`/`_rowaddr` columns). These are only used above the + // join, so they could be fetched after the row count shrinks. + let candidate_names = left_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .filter(|name| { + name != ROW_ADDR + && name != ROW_ID + && !left_key_names.contains(name) + && dataset.schema().field(name).is_some() + }) + .collect::>(); + if candidate_names.is_empty() { + return Ok(None); + } + + // Width/storage gate: only defer columns wide enough that re-fetching + // by address beats scanning them for every target row. 
+ let deferred_names = candidate_names + .iter() + .filter(|name| { + dataset + .schema() + .field(name) + .is_some_and(|f| is_wide_column(f, is_cloud)) + }) + .cloned() + .collect::>(); + if deferred_names.is_empty() { + tracing::debug!( + candidates = ?candidate_names, + is_cloud, + "merge_insert late-materialization skipped: no candidate column is wide enough to defer", + ); + return Ok(None); + } + let deferred_set = deferred_names.iter().cloned().collect::>(); + + // Map each old intermediate (left ++ right) column index to its index + // after the deferred columns are dropped from the left side. + let deferred_left_indices = left_schema + .fields() + .iter() + .enumerate() + .filter_map(|(i, f)| deferred_set.contains(f.name()).then_some(i)) + .collect::>(); + let mut old_to_new = vec![None; left_field_count + right_field_count]; + let mut new_left_len = 0; + for (i, slot) in old_to_new.iter_mut().enumerate().take(left_field_count) { + if !deferred_left_indices.contains(&i) { + *slot = Some(new_left_len); + new_left_len += 1; + } + } + for j in 0..right_field_count { + old_to_new[left_field_count + j] = Some(new_left_len + j); + } + + // Narrow the scan: drop the deferred columns, keep `_rowaddr`. + let mut narrowed = scan.options().projection.clone(); + narrowed = narrowed.subtract_predicate(|f| deferred_set.contains(&f.name)); + narrowed = narrowed.with_row_addr(); + let new_scan = Arc::new(FilteredReadExec::try_new( + dataset.clone(), + scan.options().clone().with_projection(narrowed), + scan.index_input().cloned(), + )?) as Arc; + + // Rebuild the join with the narrowed left child: re-index the keys and + // the join's output projection, dropping the deferred columns. 
+ let new_on = join + .on() + .iter() + .map(|(left, right)| { + let col = left.as_any().downcast_ref::().unwrap(); + let new_idx = old_to_new[col.index()].expect("join key must not be deferred"); + ( + Arc::new(Column::new(col.name(), new_idx)) as PhysicalExprRef, + right.clone(), + ) + }) + .collect::>(); + let new_join_projection = join.projection.as_ref().map(|p| { + p.iter() + .filter_map(|&idx| old_to_new[idx]) + .collect::>() + }); + let new_join = HashJoinExec::try_new( + new_scan, + join.right().clone(), + new_on, + None, + join.join_type(), + new_join_projection, + *join.partition_mode(), + join.null_equality(), + join.null_aware, + )?; + + // Defensive: `_rowaddr` must survive into the take's input. + let join_schema = new_join.schema(); + if join_schema.column_with_name(ROW_ADDR).is_none() { + return Ok(None); + } + // If the join emits duplicate column names — e.g. both the left and + // right join keys survive because the join has no output projection — + // we can neither build the take's (lance) schema nor re-index the + // parent projection by name. Leave such plans untransformed. + let mut seen = HashSet::with_capacity(join_schema.fields().len()); + if join_schema.fields().iter().any(|f| !seen.insert(f.name())) { + return Ok(None); + } + let join_input = Arc::new(new_join) as Arc; + + // Insert the take that re-fetches the deferred columns by `_rowaddr`. + // For an outer join, unmatched (insert) rows have a null `_rowaddr` and + // must yield NULL deferred values, so the taken fields are nullable. + let mut take_projection = dataset.empty_projection(); + for name in &deferred_names { + take_projection = take_projection.union_column(name, OnMissing::Error)?; + } + let scan_side_null_extended = matches!(join.join_type(), JoinType::Right | JoinType::Full); + let take = if scan_side_null_extended { + TakeExec::try_new_nullable_extra(dataset.clone(), join_input, take_projection)? 
+ } else { + TakeExec::try_new(dataset.clone(), join_input, take_projection)? + }; + let Some(take) = take else { + return Ok(None); + }; + let take = Arc::new(take) as Arc; + + // Re-index the parent projection's expressions onto the take output. + // Post-join column names are unique, so name-based reindexing is safe. + let take_schema = take.schema(); + let new_exprs = proj + .expr() + .iter() + .map(|pe| { + Ok(ProjectionExpr { + expr: reindex_columns_by_name(pe.expr.clone(), take_schema.as_ref())?, + alias: pe.alias.clone(), + }) + }) + .collect::>>()?; + let new_proj = ProjectionExec::try_new(new_exprs, take)?; + Ok(Some(Arc::new(new_proj))) + } +} + +impl PhysicalOptimizerRule for LateMaterializeOverReducingJoin { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> DFResult> { + Ok(plan + .transform_down(|plan| { + if let Some(proj) = plan.as_any().downcast_ref::() + && let Some(rewritten) = Self::try_defer(proj)? + { + return Ok(Transformed::yes(rewritten)); + } + Ok(Transformed::no(plan)) + })? + .data) + } + + fn name(&self) -> &str { + "late_materialize_over_reducing_join" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index c3642cdb043..9c947be9e08 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -219,11 +219,33 @@ impl TakeStream { let row_addrs = row_addrs_arr.as_primitive::(); - debug_assert!( - row_addrs.null_count() == 0, - "{} nulls in row addresses", - row_addrs.null_count() - ); + // Rows with a null address have no row in the dataset to fetch. This + // happens for insert (source-only) rows in an outer-join take, where + // the target side contributes no `_rowaddr`. We fetch only the + // non-null addresses and, after reading, scatter the fetched columns + // back into the original row positions, leaving the taken columns NULL + // for the null-address rows. 
+ let null_count = row_addrs.null_count(); + let (row_addrs_arr, scatter_indices): (Arc, Option) = + if null_count > 0 { + let mut compacted = Vec::with_capacity(row_addrs.len() - null_count); + let mut scatter = Vec::with_capacity(row_addrs.len()); + for addr in row_addrs.iter() { + if let Some(addr) = addr { + scatter.push(Some(compacted.len() as u32)); + compacted.push(addr); + } else { + scatter.push(None); + } + } + ( + Arc::new(UInt64Array::from(compacted)), + Some(UInt32Array::from(scatter)), + ) + } else { + (row_addrs_arr, None) + }; + let row_addrs = row_addrs_arr.as_primitive::(); // Fast path: check if addresses are already sorted with no duplicates (common case). // This avoids all sorting, dedup, and permutation overhead. @@ -319,7 +341,19 @@ impl TakeStream { let batches = futures.try_collect::>().await?; if batches.is_empty() { - return Ok(RecordBatch::new_empty(self.output_schema.clone())); + match scatter_indices { + // Genuinely empty input (no rows at all). + None => return Ok(RecordBatch::new_empty(self.output_schema.clone())), + // Every address was null (e.g. an all-insert take): emit the + // input rows with NULL taken columns. + Some(scatter) => { + let empty = RecordBatch::new_empty(Arc::new(ArrowSchema::from( + self.fields_to_take.as_ref(), + ))); + let new_data = scatter_taken(&empty, &scatter)?; + return Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?); + } + } } let _compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer(); @@ -355,6 +389,14 @@ impl TakeStream { (None, None) => {} } + // Scatter the fetched columns back to the caller's original row count, + // inserting NULLs for rows whose address was null. + let new_data = if let Some(scatter) = scatter_indices { + scatter_taken(&new_data, &scatter)? + } else { + new_data + }; + Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?) 
} @@ -399,6 +441,32 @@ impl TakeStream { } } +/// Expand a batch of fetched ("taken") columns back to the caller's original +/// row count using `scatter`: a map from each original row index to the +/// position of that row within the fetched set, or null for rows that had no +/// address (and therefore were not fetched). +/// +/// The taken columns are returned as nullable, since null-address rows yield +/// NULL values that a non-nullable field could not hold. +fn scatter_taken(taken: &RecordBatch, scatter: &UInt32Array) -> DataFusionResult { + // Take per-column rather than via `take_record_batch` so we can attach a + // nullable schema before constructing the batch; otherwise `RecordBatch` + // validation would reject the new NULLs in non-nullable columns. + let columns = taken + .columns() + .iter() + .map(|c| arrow::compute::take(c, scatter, None)) + .collect::>>()?; + let fields = taken + .schema() + .fields() + .iter() + .map(|f| Arc::new(f.as_ref().clone().with_nullable(true))) + .collect::>(); + let schema = Arc::new(ArrowSchema::new(fields)); + Ok(RecordBatch::try_new(schema, columns)?) +} + #[derive(Debug)] pub struct TakeExec { // The dataset to take from @@ -415,6 +483,11 @@ pub struct TakeExec { input: Arc, properties: Arc, metrics: ExecutionPlanMetricsSet, + // When true, the taken (extra) fields are marked nullable in the output + // schema even if the dataset declares them non-null. This is used when the + // take sits above an outer join, where rows with a null address (inserts) + // legitimately produce NULL taken values. + nullable_extra_fields: bool, } impl DisplayAs for TakeExec { @@ -462,6 +535,27 @@ impl TakeExec { dataset: Arc, input: Arc, projection: Projection, + ) -> Result> { + Self::try_new_impl(dataset, input, projection, false) + } + + /// Like [`TakeExec::try_new`], but marks the taken (extra) fields nullable + /// in the output schema. 
Use this when the take sits above an outer join, + /// where rows with a null address (inserts) yield NULL taken values that a + /// non-nullable field could not hold. + pub fn try_new_nullable_extra( + dataset: Arc, + input: Arc, + projection: Projection, + ) -> Result> { + Self::try_new_impl(dataset, input, projection, true) + } + + fn try_new_impl( + dataset: Arc, + input: Arc, + projection: Projection, + nullable_extra_fields: bool, ) -> Result> { let original_projection = projection.clone(); let projection = @@ -492,7 +586,30 @@ impl TakeExec { &input.schema(), &projection, )); - let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref())); + let schema_to_take = projection.into_schema_ref(); + let output_arrow = ArrowSchema::from(output_schema.as_ref()); + let output_arrow = if nullable_extra_fields { + let extra_names = schema_to_take + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(); + let fields = output_arrow + .fields() + .iter() + .map(|f| { + if extra_names.contains(f.name().as_str()) && !f.is_nullable() { + Arc::new(f.as_ref().clone().with_nullable(true)) + } else { + f.clone() + } + }) + .collect::>(); + ArrowSchema::new_with_metadata(fields, output_arrow.metadata().clone()) + } else { + output_arrow + }; + let output_arrow = Arc::new(output_arrow); let properties = Arc::new( input .properties() @@ -504,11 +621,12 @@ impl TakeExec { Ok(Some(Self { dataset, output_projection: original_projection, - schema_to_take: projection.into_schema_ref(), + schema_to_take, input, output_schema: output_arrow, properties, metrics: ExecutionPlanMetricsSet::new(), + nullable_extra_fields, })) } @@ -609,7 +727,12 @@ impl ExecutionPlan for TakeExec { let projection = self.output_projection.clone(); - let plan = Self::try_new(self.dataset.clone(), children[0].clone(), projection)?; + let plan = Self::try_new_impl( + self.dataset.clone(), + children[0].clone(), + projection, + self.nullable_extra_fields, + )?; if let Some(plan) = plan { 
Ok(Arc::new(plan)) @@ -1094,6 +1217,94 @@ mod tests { assert_eq!(s_col.value(1), s_col.value(3)); // both row 0 } + /// A null row address means "no row to fetch" (e.g. an insert row in an + /// outer-join take). The taken column must come back NULL for that row and + /// the row must be preserved (not dropped). + #[tokio::test] + async fn test_take_with_null_row_addrs() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + let row_addrs = UInt64Array::from(vec![Some(0u64), None, Some(2), None, Some(1)]); + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + ROW_ADDR, + DataType::UInt64, + true, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(row_addrs)]).unwrap(); + let stream = futures::stream::iter(vec![Ok(batch)]); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + let input = Arc::new(OneShotExec::new(stream)); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let all = concat_batches(&batches[0].schema(), &batches).unwrap(); + assert_eq!(all.num_rows(), 5); + + let s = all + .column_by_name("s") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(s.value(0), "str-0"); + assert!(s.is_null(1)); + assert_eq!(s.value(2), "str-2"); + assert!(s.is_null(3)); + assert_eq!(s.value(4), "str-1"); + } + + /// When *every* address is null (e.g. an all-insert take), all input rows + /// must be preserved with NULL taken columns. 
+ #[tokio::test] + async fn test_take_all_null_row_addrs() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + let row_addrs = UInt64Array::from(vec![None, None, None]); + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + ROW_ADDR, + DataType::UInt64, + true, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(row_addrs)]).unwrap(); + let stream = futures::stream::iter(vec![Ok(batch)]); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + let input = Arc::new(OneShotExec::new(stream)); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let all = concat_batches(&batches[0].schema(), &batches).unwrap(); + assert_eq!(all.num_rows(), 3); + + let s = all.column_by_name("s").unwrap(); + assert_eq!(s.null_count(), 3); + } + #[tokio::test] async fn test_take_struct() { // When taking fields into an existing struct, the field order should be maintained From 5b9cc6989c027634827291e0f9eb1060306899b7 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 14:48:13 -0700 Subject: [PATCH 2/7] feat(merge_insert): add generic logical late-materialization rule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prototype a logical alternative to the physical `LateMaterializeOverReducingJoin` rule. A `LateTakeNode` (`UserDefinedLogicalNodeCore`) is inserted above a row-reducing join and advertises an output schema of "join columns minus deferred, plus the deferred columns appended". 
Its `necessary_children_exprs` reports that it does not need the deferred columns from its child (only `_rowaddr`), so DataFusion's stock `OptimizeProjections` prunes them from the scan automatically — no manual index remapping, and downstream references resolve the deferred columns from the take by name. `LateMaterializeJoin` (a logical `OptimizerRule`) detects candidates by qualifier rather than physical side, gates on actual usage above the join and the existing width heuristic, and bails on join filters or duplicate join-output names. `LateTakePlanner` lowers the node to the existing null-tolerant `TakeExec`. This is generic and unit-tested in isolation; wiring it into `merge_insert::create_plan` (and retiring the physical rule) is a follow-up. The physical rule and its tests are left untouched. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/io/exec.rs | 2 + .../lance/src/io/exec/late_materialization.rs | 2 +- rust/lance/src/io/exec/late_take.rs | 940 ++++++++++++++++++ 3 files changed, 943 insertions(+), 1 deletion(-) create mode 100644 rust/lance/src/io/exec/late_take.rs diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index 3631db4e359..d5fcef6ef24 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -16,6 +16,7 @@ pub mod filtered_read_proto; pub mod fts; pub(crate) mod knn; mod late_materialization; +mod late_take; mod optimizer; mod projection; mod pushdown_scan; @@ -34,6 +35,7 @@ pub use knn::{ANNIvfPartitionExec, ANNIvfSubIndexExec, KNNVectorDistanceExec}; pub use lance_datafusion::planner::Planner; pub use lance_index::scalar::expression::FilterPlan; pub use late_materialization::LateMaterializeOverReducingJoin; +pub use late_take::{LateMaterializeJoin, LateTakeNode, LateTakePlanner}; pub use optimizer::get_physical_optimizer; pub use projection::project; pub use pushdown_scan::{LancePushdownScanExec, ScanConfig}; diff --git a/rust/lance/src/io/exec/late_materialization.rs 
b/rust/lance/src/io/exec/late_materialization.rs index 76689f47958..6674a6ba10e 100644 --- a/rust/lance/src/io/exec/late_materialization.rs +++ b/rust/lance/src/io/exec/late_materialization.rs @@ -63,7 +63,7 @@ fn reindex_columns_by_name( /// match selectivity, so we fall back to width alone. This covers the /// inherently selective backfill case the feature targets; a follow-up can /// incorporate cardinality once it is available. -fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { +pub fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { if field.is_blob() { return false; } diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs new file mode 100644 index 00000000000..b817e227e81 --- /dev/null +++ b/rust/lance/src/io/exec/late_take.rs @@ -0,0 +1,940 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Late-materialization *logical* optimizer rule. +//! +//! Defers reading wide data columns that a row-reducing join only carries +//! through, fetching them by `_rowaddr` *after* the row count has shrunk. +//! +//! Unlike the physical [`super::LateMaterializeOverReducingJoin`] rule (which +//! re-indexes a `HashJoinExec` and its parent projection by position), this +//! works at the logical level: a [`LateTakeNode`] is inserted above the join +//! and advertises an output schema of "join columns minus deferred, plus the +//! deferred columns appended". Its [`UserDefinedLogicalNodeCore::necessary_children_exprs`] +//! reports that it does *not* need the deferred columns from its child, only +//! `_rowaddr`. DataFusion's stock `OptimizeProjections` rule then prunes those +//! columns from the scan automatically — no manual index remapping — and +//! downstream column references resolve the deferred columns from the take by +//! name. +//! +//! The node lowers to the existing physical [`super::TakeExec`] via +//! 
[`LateTakePlanner`]. + +use std::collections::{BTreeSet, HashSet}; +use std::sync::Arc; + +use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; +use async_trait::async_trait; +use datafusion::{ + common::{ + DFSchema, DFSchemaRef, Result as DFResult, TableReference, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, + }, + datasource::DefaultTableSource, + execution::SessionState, + logical_expr::{Expr, Extension, Join, JoinType, LogicalPlan}, + optimizer::{OptimizerConfig, OptimizerRule}, + physical_plan::ExecutionPlan, + physical_planner::{ExtensionPlanner, PhysicalPlanner}, +}; +use datafusion_expr::{UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; +use lance_core::datatypes::OnMissing; +use lance_core::{ROW_ADDR, ROW_ID}; + +use super::TakeExec; +use super::late_materialization::is_wide_column; +use crate::Dataset; +use crate::datafusion::dataframe::LanceTableProvider; + +/// Logical plan node that re-fetches `deferred_columns` from `dataset` by +/// `_rowaddr` after a row-reducing operator. +/// +/// Output schema = the input's columns with `deferred_columns` removed, +/// followed by the deferred columns appended in dataset-schema order (mirroring +/// the physical [`TakeExec`], which appends taken columns). Constructing the +/// schema this way makes it invariant under projection pushdown: whether or not +/// the child still produces a deferred column, the node advertises the same +/// output, so the rule can be inserted before pushdown prunes the scan. +#[derive(Debug)] +pub struct LateTakeNode { + input: LogicalPlan, + dataset: Arc, + /// Dataset field names to re-fetch by address, in dataset-schema order. + deferred_columns: Vec, + /// Qualifier for the appended deferred fields (e.g. `target`), matching the + /// relation the scanned columns came from. + qualifier: Option, + /// When true the deferred fields are nullable in the output even if the + /// dataset declares them non-null. 
Set above an outer join where the scan + /// side can be null-extended (its `_rowaddr` is null → NULL deferred value). + nullable_extra: bool, + schema: DFSchemaRef, +} + +impl PartialEq for LateTakeNode { + fn eq(&self, other: &Self) -> bool { + self.dataset.base == other.dataset.base + && self.deferred_columns == other.deferred_columns + && self.qualifier == other.qualifier + && self.nullable_extra == other.nullable_extra + && self.input == other.input + } +} + +impl Eq for LateTakeNode {} + +impl std::hash::Hash for LateTakeNode { + fn hash(&self, state: &mut H) { + self.dataset.base.hash(state); + self.deferred_columns.hash(state); + self.qualifier.hash(state); + self.nullable_extra.hash(state); + self.input.hash(state); + } +} + +impl PartialOrd for LateTakeNode { + fn partial_cmp(&self, other: &Self) -> Option { + match self.deferred_columns.partial_cmp(&other.deferred_columns) { + Some(std::cmp::Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + } + } +} + +impl LateTakeNode { + pub fn try_new( + input: LogicalPlan, + dataset: Arc, + deferred_columns: Vec, + qualifier: Option, + nullable_extra: bool, + ) -> DFResult { + let schema = Self::build_output_schema( + &input, + &dataset, + &deferred_columns, + &qualifier, + nullable_extra, + )?; + Ok(Self { + input, + dataset, + deferred_columns, + qualifier, + nullable_extra, + schema, + }) + } + + /// Build `input columns (minus deferred) ++ deferred columns appended`. 
+ fn build_output_schema( + input: &LogicalPlan, + dataset: &Dataset, + deferred_columns: &[String], + qualifier: &Option, + nullable_extra: bool, + ) -> DFResult { + let input_schema = input.schema(); + let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); + + let mut qualified_fields: Vec<(Option, Arc)> = input_schema + .iter() + .filter(|(_, f)| !deferred_set.contains(f.name().as_str())) + .map(|(q, f)| (q.cloned(), f.clone())) + .collect(); + + let dataset_arrow = ArrowSchema::from(dataset.schema()); + for name in deferred_columns { + let field = dataset_arrow.field_with_name(name).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "late-materialization: deferred column '{name}' not found in dataset schema: {e}" + )) + })?; + let field = if nullable_extra && !field.is_nullable() { + field.clone().with_nullable(true) + } else { + field.clone() + }; + qualified_fields.push((qualifier.clone(), Arc::new(field))); + } + + Ok(Arc::new(DFSchema::new_with_metadata( + qualified_fields, + input_schema.metadata().clone(), + )?)) + } + + /// Index of the row-address (or, failing that, row-id) column in the child. 
+ fn row_locator_index(&self) -> Option { + let input_schema = self.input.schema(); + input_schema + .index_of_column_by_name(self.qualifier.as_ref(), ROW_ADDR) + .or_else(|| input_schema.index_of_column_by_name(self.qualifier.as_ref(), ROW_ID)) + } +} + +impl UserDefinedLogicalNodeCore for LateTakeNode { + fn name(&self) -> &str { + "LateTake" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "LateTake: deferred=[{}], nullable_extra={}", + self.deferred_columns.join(", "), + self.nullable_extra + ) + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + mut inputs: Vec, + ) -> DFResult { + if !exprs.is_empty() { + return Err(datafusion::error::DataFusionError::Internal( + "LateTakeNode does not accept expressions".to_string(), + )); + } + if inputs.len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "LateTakeNode requires exactly one input".to_string(), + )); + } + Self::try_new( + inputs.remove(0), + self.dataset.clone(), + self.deferred_columns.clone(), + self.qualifier.clone(), + self.nullable_extra, + ) + } + + /// Drive projection pushdown: the deferred columns are produced by this + /// node (fetched by address), so they are never requested from the child; + /// `_rowaddr` is always required so the fetch remains possible. + fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option>> { + let input_schema = self.input.schema(); + let deferred_set: HashSet<&str> = + self.deferred_columns.iter().map(|s| s.as_str()).collect(); + + // Output positions [0..passthrough_len) map back to these child indices, + // in order; positions beyond are the appended (fetched) deferred columns. 
+ let passthrough: Vec = input_schema + .iter() + .enumerate() + .filter(|(_, (_, f))| !deferred_set.contains(f.name().as_str())) + .map(|(i, _)| i) + .collect(); + + let row_locator = self.row_locator_index()?; + + let mut needed = BTreeSet::new(); + for &oc in output_columns { + if let Some(child_idx) = passthrough.get(oc) { + needed.insert(*child_idx); + } + } + needed.insert(row_locator); + Some(vec![needed.into_iter().collect()]) + } +} + +/// Logical optimizer rule that inserts a [`LateTakeNode`] above a join when a +/// wide column from a Lance table relation is only carried through the join. +/// +/// Detection is qualifier-driven, not side-specific: it inspects both join +/// inputs for a [`LanceTableProvider`] scan that emits `_rowaddr`, so it keeps +/// working if build/probe sides are swapped. The actual scan narrowing is left +/// to `OptimizeProjections`, which must run after this rule. +#[derive(Debug, Default)] +pub struct LateMaterializeJoin; + +impl LateMaterializeJoin { + pub fn new() -> Self { + Self + } + + /// Recover the Lance dataset backing a join input, descending only through + /// single-input nodes (alias/projection/filter) so a nested join's scan is + /// never picked up by mistake. + fn find_lance_dataset(plan: &LogicalPlan) -> Option> { + if let LogicalPlan::TableScan(scan) = plan { + let source = scan.source.as_any().downcast_ref::()?; + let provider = source + .table_provider + .as_any() + .downcast_ref::()?; + return Some(provider.dataset()); + } + let inputs = plan.inputs(); + if inputs.len() == 1 { + Self::find_lance_dataset(inputs[0]) + } else { + None + } + } + + /// Names of the join's equi-keys on `side` (the columns that must stay in + /// the scan because the join reads them). 
+ fn join_key_names(join: &Join, side: JoinSide) -> HashSet { + join.on + .iter() + .filter_map(|(left, right)| { + let expr = match side { + JoinSide::Left => left, + JoinSide::Right => right, + }; + match expr { + Expr::Column(col) => Some(col.name.clone()), + _ => None, + } + }) + .collect() + } + + /// Collect every `(qualifier, name)` column reference that appears in an + /// expression anywhere in `plan`. Used to tell which scan-side columns are + /// actually consumed *above* a join (and so worth re-fetching) rather than + /// merely produced by the scan. + fn collect_referenced_columns( + plan: &LogicalPlan, + ) -> DFResult, String)>> { + let mut referenced = HashSet::new(); + plan.apply(|node| { + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::Column(col) = e { + referenced.insert((col.relation.clone(), col.name.clone())); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(referenced) + } + + fn try_defer_join( + join: &Join, + referenced: &HashSet<(Option, String)>, + ) -> DFResult> { + // Only column-preserving joins; a join filter may reference a column we + // would defer, so bail rather than reason about it. + if !matches!( + join.join_type, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) || join.filter.is_some() + { + return Ok(None); + } + + // The take re-fetches by `_rowaddr` and merges columns by name, so it + // cannot sit above a join whose output has duplicate field names (e.g. + // both sides' equi-keys share a name). Matches the physical rule. 
+ let mut seen = HashSet::with_capacity(join.schema.fields().len()); + if join + .schema + .fields() + .iter() + .any(|f| !seen.insert(f.name().as_str())) + { + return Ok(None); + } + + for side in [JoinSide::Left, JoinSide::Right] { + let side_plan: &LogicalPlan = match side { + JoinSide::Left => &join.left, + JoinSide::Right => &join.right, + }; + let Some(dataset) = Self::find_lance_dataset(side_plan) else { + continue; + }; + let side_schema = side_plan.schema(); + + // The scan side must emit `_rowaddr` so deferred columns stay + // fetchable by address after the join. + if side_schema + .index_of_column_by_name(None, ROW_ADDR) + .is_none() + { + continue; + } + let qualifier = side_schema + .iter() + .find(|(_, f)| f.name() == ROW_ADDR) + .and_then(|(q, _)| q.cloned()); + + let key_names = Self::join_key_names(join, side); + let is_cloud = dataset.object_store.is_cloud(); + let dataset_arrow = ArrowSchema::from(dataset.schema()); + + // Candidates = scan-side data columns that aren't join keys nor the + // system columns, that are consumed above the join, and are wide + // enough to be worth re-fetching by address. + let mut deferred: Vec<(usize, String)> = Vec::new(); + for (col_qualifier, field) in side_schema.iter() { + let name = field.name(); + if name == ROW_ADDR || name == ROW_ID || key_names.contains(name) { + continue; + } + // Only defer columns actually used above the join; a column the + // scan produces but nobody references would just be pruned, so + // re-fetching it would be wasted work. + if !referenced.contains(&(col_qualifier.cloned(), name.clone())) { + continue; + } + let Some(ds_field) = dataset.schema().field(name) else { + continue; + }; + if !is_wide_column(ds_field, is_cloud) { + continue; + } + let Ok(ds_idx) = dataset_arrow.index_of(name) else { + continue; + }; + deferred.push((ds_idx, name.clone())); + } + if deferred.is_empty() { + continue; + } + // Append in dataset-schema order to match TakeExec's output order. 
+ deferred.sort_by_key(|(idx, _)| *idx); + let deferred_columns: Vec = + deferred.into_iter().map(|(_, name)| name).collect(); + + // The scan side is null-extended when it is the optional side of an + // outer join; its unmatched rows then have a null `_rowaddr` and + // must yield NULL deferred values. + let nullable_extra = matches!( + (side, join.join_type), + (JoinSide::Left, JoinType::Right) + | (JoinSide::Left, JoinType::Full) + | (JoinSide::Right, JoinType::Left) + | (JoinSide::Right, JoinType::Full) + ); + + let node = LateTakeNode::try_new( + LogicalPlan::Join(join.clone()), + dataset, + deferred_columns, + qualifier, + nullable_extra, + )?; + return Ok(Some(LogicalPlan::Extension(Extension { + node: Arc::new(node), + }))); + } + + Ok(None) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum JoinSide { + Left, + Right, +} + +impl OptimizerRule for LateMaterializeJoin { + fn name(&self) -> &str { + "late_materialize_join" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> DFResult> { + let referenced = Self::collect_referenced_columns(&plan)?; + plan.transform_down(|node| { + // Never descend into an already-deferred subtree, otherwise the + // inner join would be wrapped again on this or a later pass. + if let LogicalPlan::Extension(ext) = &node + && ext.node.as_any().is::() + { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); + } + if let LogicalPlan::Join(join) = &node + && let Some(wrapped) = Self::try_defer_join(join, &referenced)? + { + // Jump: the freshly wrapped join must not be revisited. + return Ok(Transformed::new(wrapped, true, TreeNodeRecursion::Jump)); + } + Ok(Transformed::no(node)) + }) + } +} + +/// Lowers a [`LateTakeNode`] to the physical [`TakeExec`]. 
+#[derive(Debug)] +pub struct LateTakePlanner; + +#[async_trait] +impl ExtensionPlanner for LateTakePlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> DFResult>> { + let Some(take_node) = node.as_any().downcast_ref::() else { + return Ok(None); + }; + assert_eq!(physical_inputs.len(), 1, "LateTake requires one input"); + let input = physical_inputs[0].clone(); + + let mut projection = take_node.dataset.empty_projection(); + for name in &take_node.deferred_columns { + projection = projection.union_column(name, OnMissing::Error)?; + } + + let take = if take_node.nullable_extra { + TakeExec::try_new_nullable_extra(take_node.dataset.clone(), input.clone(), projection)? + } else { + TakeExec::try_new(take_node.dataset.clone(), input.clone(), projection)? + }; + + // `try_new` returns None when no extra columns are needed; fall back to + // the raw input so the plan still lowers. + Ok(Some(match take { + Some(take) => Arc::new(take) as Arc, + None => input, + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_array::{Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; + use datafusion::execution::{SessionStateBuilder, TaskContext}; + use datafusion::logical_expr::TableScan; + use datafusion::optimizer::{ + Optimizer, OptimizerContext, optimize_projections::OptimizeProjections, + }; + use datafusion::physical_plan::{collect, displayable}; + use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; + use datafusion::prelude::*; + use lance_core::utils::tempfile::TempStrDir; + + use crate::datafusion::dataframe::SessionContextExt; + use crate::dataset::WriteParams; + + /// Write a `{id: Int32 (key), payload: Utf8 (wide), tag: Int32 (narrow)}` + /// dataset to a temp dir and open it. 
+ async fn test_dataset() -> (Arc, TempStrDir) { + let schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + Field::new("tag", DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(StringArray::from(vec![ + "alpha", "bravo", "charlie", "delta", "echo", + ])), + Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50])), + ], + ) + .unwrap(); + + let tmp = TempStrDir::default(); + let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone()); + Dataset::write(reader, tmp.as_str(), Some(WriteParams::default())) + .await + .unwrap(); + let dataset = Arc::new(Dataset::open(tmp.as_str()).await.unwrap()); + (dataset, tmp) + } + + /// A source DataFrame `{sid}` aliased "source". A distinct key name (vs the + /// target's `id`) keeps the join output free of duplicate names so the take + /// can sit above it. 
+ fn source_df(ctx: &SessionContext, ids: Vec) -> DataFrame { + let schema: SchemaRef = Arc::new(ArrowSchema::new(vec![Field::new( + "sid", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap(); + ctx.read_batch(batch).unwrap().alias("source").unwrap() + } + + fn find_table_scan(plan: &LogicalPlan) -> Option<&TableScan> { + if let LogicalPlan::TableScan(scan) = plan { + return Some(scan); + } + plan.inputs().into_iter().find_map(find_table_scan) + } + + fn has_late_take(plan: &LogicalPlan) -> bool { + if let LogicalPlan::Extension(ext) = plan + && ext.node.as_any().is::() + { + return true; + } + plan.inputs().iter().any(|p| has_late_take(p)) + } + + fn scan_column_names(plan: &LogicalPlan) -> Vec { + find_table_scan(plan) + .unwrap() + .projected_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + } + + fn run_rule_and_pushdown(plan: LogicalPlan) -> LogicalPlan { + let optimizer = Optimizer::with_rules(vec![ + Arc::new(LateMaterializeJoin::new()), + Arc::new(OptimizeProjections::new()), + ]); + optimizer + .optimize(plan, &OptimizerContext::new(), |_, _| {}) + .unwrap() + } + + /// Build `Projection(select cols) <- Join(target, source) <- scan`. 
+ async fn join_plan( + ctx: &SessionContext, + dataset: Arc, + join_type: JoinType, + select: &[&str], + source_ids: Vec, + ) -> LogicalPlan { + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source = source_df(ctx, source_ids); + let exprs = select.iter().map(|c| col(*c)).collect::>(); + target + .join(source, join_type, &["id"], &["sid"], None) + .unwrap() + .select(exprs) + .unwrap() + .into_unoptimized_plan() + } + + #[tokio::test] + async fn test_wide_column_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let plan = join_plan( + &ctx, + dataset, + JoinType::Inner, + &["target.id", "target.payload", "target.tag"], + vec![1, 3], + ) + .await; + + let before = plan.schema().clone(); + let optimized = run_rule_and_pushdown(plan); + + assert!( + has_late_take(&optimized), + "expected a LateTake node:\n{}", + optimized.display_indent() + ); + + let scan_cols = scan_column_names(&optimized); + assert!(scan_cols.contains(&"id".to_string()), "scan: {scan_cols:?}"); + assert!( + scan_cols.contains(&ROW_ADDR.to_string()), + "scan: {scan_cols:?}" + ); + // wide column dropped from the scan; narrow `tag` (used above the join) + // is not deferred and stays in the scan. + assert!( + !scan_cols.contains(&"payload".to_string()), + "payload should be deferred, scan: {scan_cols:?}" + ); + assert!( + scan_cols.contains(&"tag".to_string()), + "scan: {scan_cols:?}" + ); + + // The plan's output schema is unchanged by the rewrite. + assert_eq!(before.fields(), optimized.schema().fields()); + } + + #[tokio::test] + async fn test_no_wide_columns_no_take() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Only narrow columns used above the join → nothing to defer. 
+ let plan = join_plan( + &ctx, + dataset, + JoinType::Inner, + &["target.id", "target.tag"], + vec![1, 3], + ) + .await; + let optimized = run_rule_and_pushdown(plan); + + assert!(!has_late_take(&optimized), "no take expected"); + let scan_cols = scan_column_names(&optimized); + assert!( + scan_cols.contains(&"tag".to_string()), + "scan: {scan_cols:?}" + ); + } + + #[tokio::test] + async fn test_join_key_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Join on the wide `payload` column: as a key it must stay in the scan, + // and there is no other wide column to defer. + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![Field::new( + "spayload", + DataType::Utf8, + true, + )])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![Arc::new(StringArray::from(vec!["alpha", "charlie"]))], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let plan = target + .join(source, JoinType::Inner, &["payload"], &["spayload"], None) + .unwrap() + .select(vec![col("target.id"), col("target.payload")]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!(!has_late_take(&optimized), "join key must not be deferred"); + let scan_cols = scan_column_names(&optimized); + assert!( + scan_cols.contains(&"payload".to_string()), + "join key stays in scan: {scan_cols:?}" + ); + } + + #[tokio::test] + async fn test_duplicate_join_output_names_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Both sides expose `payload`; the join output has a duplicate name, so + // the take (which merges by name) cannot be inserted. 
+ let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let plan = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + .select(vec![col("target.id"), col("target.payload")]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!( + !has_late_take(&optimized), + "duplicate join-output names must not be deferred" + ); + } + + #[tokio::test] + async fn test_outer_join_marks_deferred_nullable() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // RIGHT join: target (left input) is the null-extended side. + let plan = join_plan( + &ctx, + dataset, + JoinType::Right, + &["target.id", "target.payload"], + vec![1, 3], + ) + .await; + let optimized = run_rule_and_pushdown(plan); + assert!(has_late_take(&optimized)); + + // Locate the node and confirm the deferred field is nullable. 
+ fn find_node(plan: &LogicalPlan) -> Option<&LateTakeNode> { + if let LogicalPlan::Extension(ext) = plan + && let Some(n) = ext.node.as_any().downcast_ref::() + { + return Some(n); + } + plan.inputs().into_iter().find_map(find_node) + } + let node = find_node(&optimized).unwrap(); + assert!(node.nullable_extra); + let payload = UserDefinedLogicalNodeCore::schema(node) + .field_with_unqualified_name("payload") + .unwrap(); + assert!(payload.is_nullable()); + } + + #[tokio::test] + async fn test_necessary_children_exprs() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let input = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap() + .into_unoptimized_plan(); + // child schema order: id(0), payload(1), tag(2), _rowaddr(3) + let rowaddr_idx = input + .schema() + .index_of_column_by_name(None, ROW_ADDR) + .unwrap(); + let node = LateTakeNode::try_new( + input, + dataset, + vec!["payload".to_string()], + Some(TableReference::bare("target")), + false, + ) + .unwrap(); + + // Output: id(0), tag(1), _rowaddr(2), payload(3, appended/fetched). + // All outputs requested → child needs id, tag, _rowaddr (not payload). + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[0, 1, 2, 3]), + Some(vec![vec![0, 2, rowaddr_idx]]) + ); + // Only the deferred column requested → child still only needs _rowaddr. + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[3]), + Some(vec![vec![rowaddr_idx]]) + ); + // Nothing requested → _rowaddr is still forced in. + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[]), + Some(vec![vec![rowaddr_idx]]) + ); + } + + /// Lower the rewritten plan to physical and confirm it (a) inserts a + /// `TakeExec` and (b) produces exactly the same rows as the un-deferred plan, + /// including NULL deferred values for source-only rows of an outer join. 
+ async fn assert_execution_parity(join_type: JoinType, source_ids: Vec) { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let state = SessionStateBuilder::new().with_default_features().build(); + + let plan = join_plan( + &ctx, + dataset, + join_type, + &["target.id", "target.payload"], + source_ids, + ) + .await; + + // Baseline: standard optimization + default planner. + let baseline_logical = state.optimize(&plan).unwrap(); + let baseline = state.create_physical_plan(&baseline_logical).await.unwrap(); + let baseline_rows = collect(baseline, Arc::new(TaskContext::default())) + .await + .unwrap(); + + // Deferred: insert the take, prune, then lower with our planner. + let optimized = run_rule_and_pushdown(plan); + let planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(LateTakePlanner)]); + let physical = planner + .create_physical_plan(&optimized, &state) + .await + .unwrap(); + let rendered = displayable(physical.as_ref()).indent(true).to_string(); + assert!( + rendered.contains("Take"), + "expected a TakeExec in the plan:\n{rendered}" + ); + let deferred_rows = collect(physical, Arc::new(TaskContext::default())) + .await + .unwrap(); + + // Compare row sets (order-independent). 
+ let sort = |batches: &[RecordBatch]| { + let mut rows: Vec<(Option, Option)> = Vec::new(); + for b in batches { + let ids = b.column(0).as_any().downcast_ref::().unwrap(); + let payloads = b.column(1).as_any().downcast_ref::().unwrap(); + for i in 0..b.num_rows() { + rows.push(( + (!ids.is_null(i)).then(|| ids.value(i)), + (!payloads.is_null(i)).then(|| payloads.value(i).to_string()), + )); + } + } + rows.sort(); + rows + }; + assert_eq!(sort(&baseline_rows), sort(&deferred_rows)); + } + + #[tokio::test] + async fn test_execution_parity_inner() { + assert_execution_parity(JoinType::Inner, vec![1, 3]).await; + } + + #[tokio::test] + async fn test_execution_parity_outer_null_rowaddr() { + // RIGHT join with a source-only id (99) that has no target match: that + // row's `target._rowaddr` is null, so the take must scatter a NULL + // payload for it — matching the un-deferred plan. + assert_execution_parity(JoinType::Right, vec![1, 99]).await; + } +} From 440ea944e7946b0ce522957b65017a32baccff8a Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 15:16:42 -0700 Subject: [PATCH 3/7] perf(merge_insert): use logical late-materialization rule, drop physical rule Wire the generic logical `LateMaterializeJoin` rule + `LateTakePlanner` into `merge_insert::create_plan`, replacing the physical `LateMaterializeOverReducingJoin` rule (now deleted). The rule runs between two `session_state.optimize` calls: the first keeps the wide non-source columns in the target scan, the rule inserts a `LateTake` above the join, and the second prunes those columns from the scan via projection pushdown. The merge_insert join is an equi-join on a shared key name, which surfaced two issues the Stage-1 unit tests (distinct key names) did not: - `collect_referenced_columns` now skips a join's own on-clause expressions. Those keys are consumed by the join, not above it, so counting them made both sides' same-named keys look "used above the join". 
- The rule inserts a normalizing projection between the join and the take, selecting only the columns carried past the join (deferred columns dropped). Without it the opaque take node blocks the physical planner from giving `HashJoinExec` a tight output projection, so both same-named keys reach `TakeExec` as duplicate arrow names (qualifiers are erased at the physical level) and the lance schema build rejects the duplicate. The projection is also where scan narrowing now happens. `is_wide_column` moves into `late_take.rs` as a private fn. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/dataset/write/merge_insert.rs | 39 ++- rust/lance/src/io/exec.rs | 2 - .../lance/src/io/exec/late_materialization.rs | 329 ------------------ rust/lance/src/io/exec/late_take.rs | 133 +++++-- 4 files changed, 131 insertions(+), 372 deletions(-) delete mode 100644 rust/lance/src/io/exec/late_materialization.rs diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 3b696f5d2f1..075c4551f73 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -58,7 +58,7 @@ use crate::{ }, index::DatasetIndexInternalExt, io::exec::{ - AddRowAddrExec, Planner, TakeExec, project, + AddRowAddrExec, LateMaterializeJoin, LateTakePlanner, Planner, TakeExec, project, scalar_index::{IndexLookup, MapIndexExec}, utils::ReplayExec, }, @@ -72,6 +72,7 @@ use arrow_select::take::take_record_batch; use datafusion::common::NullEquality; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::error::DataFusionError; +use datafusion::optimizer::{OptimizerContext, OptimizerRule}; use datafusion::{ execution::{ context::{SessionConfig, SessionContext}, @@ -88,7 +89,7 @@ use datafusion::{ stream::RecordBatchStreamAdapter, union::UnionExec, }, - physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, 
PhysicalPlanner}, prelude::DataFrame, scalar::ScalarValue, }; @@ -1538,24 +1539,34 @@ impl MergeInsertJob { node: Arc::new(write_node), }); + // First pass: standard optimization. The non-source "fill" columns are + // still referenced (each copies `target.<col>`), so projection pushdown + // keeps them — and `target._rowaddr` — in the target scan. let logical_plan = session_state.optimize(&logical_plan)?; - let planner = - DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(MergeInsertPlanner {})]); - // This method already does the optimization for us. + // Defer reading non-source columns: a partial-schema upsert reads the + // missing columns from the target side of the join only to rewrite full + // rows. This rule inserts a `LateTake` above the join so those wide + // columns are fetched by `_rowaddr` for the matched rows only, instead + // of being scanned for every target row. It is applied only here (not in + // the session-wide optimizer) to bound its blast radius. + let logical_plan = LateMaterializeJoin::new() + .rewrite(logical_plan, &OptimizerContext::default())? + .data; + + // Second pass: the deferred columns are now absent from the `LateTake`'s + // input, so projection pushdown narrows the target scan to drop them + // while keeping `_rowaddr` (which the take forces in). + let logical_plan = session_state.optimize(&logical_plan)?; + + let planner = DefaultPhysicalPlanner::with_extension_planners(vec![ + Arc::new(MergeInsertPlanner {}) as Arc<dyn ExtensionPlanner + Send + Sync>, + Arc::new(LateTakePlanner) as Arc<dyn ExtensionPlanner + Send + Sync>, + ]); let physical_plan = planner .create_physical_plan(&logical_plan, &session_state) .await?; - // Defer reading non-source columns: a partial-schema upsert reads the - // missing columns from the target side of the join only to rewrite full - // rows. This rule pushes those reads past the join so a selective match - // does not scan wide columns for every target row. It is applied only - // here (not in the session-wide optimizer) to bound its blast radius.
- use datafusion::physical_optimizer::PhysicalOptimizerRule; - let physical_plan = crate::io::exec::LateMaterializeOverReducingJoin - .optimize(physical_plan, &datafusion::config::ConfigOptions::default())?; - Ok(physical_plan) } diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index d5fcef6ef24..428362d83d8 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -15,7 +15,6 @@ pub mod filtered_read; pub mod filtered_read_proto; pub mod fts; pub(crate) mod knn; -mod late_materialization; mod late_take; mod optimizer; mod projection; @@ -34,7 +33,6 @@ pub use filter::LanceFilterExec; pub use knn::{ANNIvfPartitionExec, ANNIvfSubIndexExec, KNNVectorDistanceExec}; pub use lance_datafusion::planner::Planner; pub use lance_index::scalar::expression::FilterPlan; -pub use late_materialization::LateMaterializeOverReducingJoin; pub use late_take::{LateMaterializeJoin, LateTakeNode, LateTakePlanner}; pub use optimizer::get_physical_optimizer; pub use projection::project; diff --git a/rust/lance/src/io/exec/late_materialization.rs b/rust/lance/src/io/exec/late_materialization.rs deleted file mode 100644 index 6674a6ba10e..00000000000 --- a/rust/lance/src/io/exec/late_materialization.rs +++ /dev/null @@ -1,329 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -//! Late-materialization physical optimizer rule. -//! -//! Defers reading data columns that a row-reducing operator only carries -//! through, fetching them by `_rowaddr` after the row count has shrunk. Used by -//! `merge_insert` to avoid scanning wide non-source columns for every target -//! row of a selective partial-schema upsert. 
- -use std::collections::HashSet; -use std::sync::Arc; - -use arrow_schema::Schema as ArrowSchema; -use datafusion::{ - common::tree_node::{Transformed, TreeNode}, - config::ConfigOptions, - error::Result as DFResult, - logical_expr::JoinType, - physical_optimizer::PhysicalOptimizerRule, - physical_plan::{ - ExecutionPlan, - joins::HashJoinExec, - projection::{ProjectionExec, ProjectionExpr}, - }, -}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, expressions::Column}; -use lance_arrow::DataTypeExt; -use lance_core::datatypes::OnMissing; -use lance_core::{ROW_ADDR, ROW_ID}; - -use super::TakeExec; -use super::filtered_read::FilteredReadExec; - -/// Rewrite every [`Column`] in `expr` to reference `schema` by name. Used to -/// re-index a projection's expressions after the column layout of its input -/// changed (e.g. a column moved because it is now sourced from a [`TakeExec`]). -fn reindex_columns_by_name( - expr: Arc, - schema: &ArrowSchema, -) -> DFResult> { - Ok(expr - .transform_down(|e| { - if let Some(col) = e.as_any().downcast_ref::() { - let new_col = Column::new_with_schema(col.name(), schema)?; - Ok(Transformed::yes(Arc::new(new_col) as Arc)) - } else { - Ok(Transformed::no(e)) - } - })? - .data) -} - -/// Width/storage gate mirroring the scanner's late-materialization heuristic -/// ([`crate::dataset::scanner::MaterializationStyle::Heuristic`]): a column is -/// worth deferring only if it is "wide" for the backing storage — a -/// variable-width type (strings, lists, vectors) or a fixed-width type above the -/// per-row byte threshold (1KB on cloud storage, 10 bytes on local). Narrow -/// columns are cheaper to read in the sequential scan than to re-fetch by -/// address. -/// -/// Without a join-cardinality estimate (tracked in #4583) we cannot gate on -/// match selectivity, so we fall back to width alone. 
This covers the -/// inherently selective backfill case the feature targets; a follow-up can -/// incorporate cardinality once it is available. -pub fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { - if field.is_blob() { - return false; - } - let byte_width = field.data_type().byte_width_opt(); - if is_cloud { - byte_width.is_none_or(|bw| bw >= 1000) - } else { - byte_width.is_none_or(|bw| bw >= 10) - } -} - -/// Late-materialization rule: defer reading data columns that a row-reducing -/// operator (here, a [`HashJoinExec`]) only passes through, fetching them by -/// `_rowaddr` *after* the row count has been reduced. -/// -/// Concretely, for a `ProjectionExec -> HashJoinExec` where the join's build -/// (left) side is a [`FilteredReadExec`] that emits `_rowaddr`, any data column -/// the scan reads but the join only carries through (not a join key, not used -/// in a join filter) is dropped from the scan, re-fetched by a [`TakeExec`] -/// inserted above the join, and the parent projection is re-indexed to read it -/// from there. The projection's *output* schema is unchanged, so nothing above -/// it is affected. -/// -/// This is written generically but is currently applied only at the -/// merge_insert call site (not registered in the session-wide optimizer), which -/// bounds its blast radius. A column missing from the source of a partial-schema -/// upsert is exactly such a "carried-through" column, so deferring it avoids -/// scanning wide columns for every target row when only a few rows match. -#[derive(Debug, Default)] -pub struct LateMaterializeOverReducingJoin; - -impl LateMaterializeOverReducingJoin { - /// Attempt the rewrite for a `ProjectionExec` sitting directly above a - /// `HashJoinExec`. Returns the rewritten projection, or `None` if the - /// pattern does not apply or cannot be safely transformed. 
- fn try_defer(proj: &ProjectionExec) -> DFResult>> { - let Some(join) = proj.input().as_any().downcast_ref::() else { - return Ok(None); - }; - - // Only column-preserving join types have a (left ++ right) intermediate - // schema, which the index remapping below relies on. - if !matches!( - join.join_type(), - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full - ) { - return Ok(None); - } - - // A join filter may reference the columns we want to defer; bail rather - // than reason about its intermediate schema. - if join.filter().is_some() { - return Ok(None); - } - - // We only handle deferral on the build (left) side, which is where the - // scanned target relation sits in a merge_insert plan. Deferring on the - // probe side would need a symmetric (and untested) index remapping. - let Some(scan) = join.left().as_any().downcast_ref::() else { - return Ok(None); - }; - - let left_schema = join.left().schema(); - let right_schema = join.right().schema(); - let left_field_count = left_schema.fields().len(); - let right_field_count = right_schema.fields().len(); - - // The scan must emit `_rowaddr` so the deferred columns remain - // fetchable by address after the join. - if left_schema.column_with_name(ROW_ADDR).is_none() { - return Ok(None); - } - - // Join keys on the build side must stay in the scan. - let mut left_key_names = HashSet::new(); - for (left, _) in join.on() { - let Some(col) = left.as_any().downcast_ref::() else { - return Ok(None); - }; - left_key_names.insert(col.name().to_string()); - } - - let dataset = scan.dataset(); - let is_cloud = dataset.object_store.is_cloud(); - - // Candidates = scan-side data columns that aren't join keys (and aren't - // the system `_rowid`/`_rowaddr` columns). These are only used above the - // join, so they could be fetched after the row count shrinks. 
- let candidate_names = left_schema - .fields() - .iter() - .map(|f| f.name().clone()) - .filter(|name| { - name != ROW_ADDR - && name != ROW_ID - && !left_key_names.contains(name) - && dataset.schema().field(name).is_some() - }) - .collect::>(); - if candidate_names.is_empty() { - return Ok(None); - } - - // Width/storage gate: only defer columns wide enough that re-fetching - // by address beats scanning them for every target row. - let deferred_names = candidate_names - .iter() - .filter(|name| { - dataset - .schema() - .field(name) - .is_some_and(|f| is_wide_column(f, is_cloud)) - }) - .cloned() - .collect::>(); - if deferred_names.is_empty() { - tracing::debug!( - candidates = ?candidate_names, - is_cloud, - "merge_insert late-materialization skipped: no candidate column is wide enough to defer", - ); - return Ok(None); - } - let deferred_set = deferred_names.iter().cloned().collect::>(); - - // Map each old intermediate (left ++ right) column index to its index - // after the deferred columns are dropped from the left side. - let deferred_left_indices = left_schema - .fields() - .iter() - .enumerate() - .filter_map(|(i, f)| deferred_set.contains(f.name()).then_some(i)) - .collect::>(); - let mut old_to_new = vec![None; left_field_count + right_field_count]; - let mut new_left_len = 0; - for (i, slot) in old_to_new.iter_mut().enumerate().take(left_field_count) { - if !deferred_left_indices.contains(&i) { - *slot = Some(new_left_len); - new_left_len += 1; - } - } - for j in 0..right_field_count { - old_to_new[left_field_count + j] = Some(new_left_len + j); - } - - // Narrow the scan: drop the deferred columns, keep `_rowaddr`. - let mut narrowed = scan.options().projection.clone(); - narrowed = narrowed.subtract_predicate(|f| deferred_set.contains(&f.name)); - narrowed = narrowed.with_row_addr(); - let new_scan = Arc::new(FilteredReadExec::try_new( - dataset.clone(), - scan.options().clone().with_projection(narrowed), - scan.index_input().cloned(), - )?) 
as Arc; - - // Rebuild the join with the narrowed left child: re-index the keys and - // the join's output projection, dropping the deferred columns. - let new_on = join - .on() - .iter() - .map(|(left, right)| { - let col = left.as_any().downcast_ref::().unwrap(); - let new_idx = old_to_new[col.index()].expect("join key must not be deferred"); - ( - Arc::new(Column::new(col.name(), new_idx)) as PhysicalExprRef, - right.clone(), - ) - }) - .collect::>(); - let new_join_projection = join.projection.as_ref().map(|p| { - p.iter() - .filter_map(|&idx| old_to_new[idx]) - .collect::>() - }); - let new_join = HashJoinExec::try_new( - new_scan, - join.right().clone(), - new_on, - None, - join.join_type(), - new_join_projection, - *join.partition_mode(), - join.null_equality(), - join.null_aware, - )?; - - // Defensive: `_rowaddr` must survive into the take's input. - let join_schema = new_join.schema(); - if join_schema.column_with_name(ROW_ADDR).is_none() { - return Ok(None); - } - // If the join emits duplicate column names — e.g. both the left and - // right join keys survive because the join has no output projection — - // we can neither build the take's (lance) schema nor re-index the - // parent projection by name. Leave such plans untransformed. - let mut seen = HashSet::with_capacity(join_schema.fields().len()); - if join_schema.fields().iter().any(|f| !seen.insert(f.name())) { - return Ok(None); - } - let join_input = Arc::new(new_join) as Arc; - - // Insert the take that re-fetches the deferred columns by `_rowaddr`. - // For an outer join, unmatched (insert) rows have a null `_rowaddr` and - // must yield NULL deferred values, so the taken fields are nullable. 
- let mut take_projection = dataset.empty_projection(); - for name in &deferred_names { - take_projection = take_projection.union_column(name, OnMissing::Error)?; - } - let scan_side_null_extended = matches!(join.join_type(), JoinType::Right | JoinType::Full); - let take = if scan_side_null_extended { - TakeExec::try_new_nullable_extra(dataset.clone(), join_input, take_projection)? - } else { - TakeExec::try_new(dataset.clone(), join_input, take_projection)? - }; - let Some(take) = take else { - return Ok(None); - }; - let take = Arc::new(take) as Arc; - - // Re-index the parent projection's expressions onto the take output. - // Post-join column names are unique, so name-based reindexing is safe. - let take_schema = take.schema(); - let new_exprs = proj - .expr() - .iter() - .map(|pe| { - Ok(ProjectionExpr { - expr: reindex_columns_by_name(pe.expr.clone(), take_schema.as_ref())?, - alias: pe.alias.clone(), - }) - }) - .collect::>>()?; - let new_proj = ProjectionExec::try_new(new_exprs, take)?; - Ok(Some(Arc::new(new_proj))) - } -} - -impl PhysicalOptimizerRule for LateMaterializeOverReducingJoin { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> DFResult> { - Ok(plan - .transform_down(|plan| { - if let Some(proj) = plan.as_any().downcast_ref::() - && let Some(rewritten) = Self::try_defer(proj)? - { - return Ok(Transformed::yes(rewritten)); - } - Ok(Transformed::no(plan)) - })? - .data) - } - - fn name(&self) -> &str { - "late_materialize_over_reducing_join" - } - - fn schema_check(&self) -> bool { - true - } -} diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs index b817e227e81..9f8deae6c37 100644 --- a/rust/lance/src/io/exec/late_take.rs +++ b/rust/lance/src/io/exec/late_take.rs @@ -4,13 +4,15 @@ //! Late-materialization *logical* optimizer rule. //! //! Defers reading wide data columns that a row-reducing join only carries -//! through, fetching them by `_rowaddr` *after* the row count has shrunk. +//! 
through, fetching them by `_rowaddr` *after* the row count has shrunk. Used +//! by `merge_insert` to avoid scanning wide non-source columns for every target +//! row of a selective partial-schema upsert. //! -//! Unlike the physical [`super::LateMaterializeOverReducingJoin`] rule (which -//! re-indexes a `HashJoinExec` and its parent projection by position), this -//! works at the logical level: a [`LateTakeNode`] is inserted above the join -//! and advertises an output schema of "join columns minus deferred, plus the -//! deferred columns appended". Its [`UserDefinedLogicalNodeCore::necessary_children_exprs`] +//! Working at the logical level (rather than rewriting a physical +//! `HashJoinExec` by position), a [`LateTakeNode`] is inserted above the join — +//! fed by a projection that keeps only the columns carried past the join — and +//! advertises an output schema of "carried columns plus the deferred columns +//! appended". Its [`UserDefinedLogicalNodeCore::necessary_children_exprs`] //! reports that it does *not* need the deferred columns from its child, only //! `_rowaddr`. DataFusion's stock `OptimizeProjections` rule then prunes those //! 
columns from the scan automatically — no manual index remapping — and @@ -27,25 +29,49 @@ use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use async_trait::async_trait; use datafusion::{ common::{ - DFSchema, DFSchemaRef, Result as DFResult, TableReference, + Column, DFSchema, DFSchemaRef, Result as DFResult, TableReference, tree_node::{Transformed, TreeNode, TreeNodeRecursion}, }, datasource::DefaultTableSource, execution::SessionState, - logical_expr::{Expr, Extension, Join, JoinType, LogicalPlan}, + logical_expr::{Expr, Extension, Join, JoinType, LogicalPlan, Projection}, optimizer::{OptimizerConfig, OptimizerRule}, physical_plan::ExecutionPlan, physical_planner::{ExtensionPlanner, PhysicalPlanner}, }; use datafusion_expr::{UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; +use lance_arrow::DataTypeExt; use lance_core::datatypes::OnMissing; use lance_core::{ROW_ADDR, ROW_ID}; use super::TakeExec; -use super::late_materialization::is_wide_column; use crate::Dataset; use crate::datafusion::dataframe::LanceTableProvider; +/// Width/storage gate mirroring the scanner's late-materialization heuristic +/// ([`crate::dataset::scanner::MaterializationStyle::Heuristic`]): a column is +/// worth deferring only if it is "wide" for the backing storage — a +/// variable-width type (strings, lists, vectors) or a fixed-width type above the +/// per-row byte threshold (1KB on cloud storage, 10 bytes on local). Narrow +/// columns are cheaper to read in the sequential scan than to re-fetch by +/// address. +/// +/// Without a join-cardinality estimate (tracked in #4583) we cannot gate on +/// match selectivity, so we fall back to width alone. This covers the +/// inherently selective backfill case the feature targets; a follow-up can +/// incorporate cardinality once it is available. 
+fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { + if field.is_blob() { + return false; + } + let byte_width = field.data_type().byte_width_opt(); + if is_cloud { + byte_width.is_none_or(|bw| bw >= 1000) + } else { + byte_width.is_none_or(|bw| bw >= 10) + } +} + /// Logical plan node that re-fetches `deferred_columns` from `dataset` by /// `_rowaddr` after a row-reducing operator. /// @@ -316,6 +342,14 @@ impl LateMaterializeJoin { ) -> DFResult, String)>> { let mut referenced = HashSet::new(); plan.apply(|node| { + // A join's own on-clause / filter columns are consumed by the join + // itself, not by an operator above it, so they must not count as + // "used above the join". (Equi-keys are the common case: both sides' + // keys share a name, which would otherwise look like a duplicate + // column flowing into the take.) Children are still visited. + if matches!(node, LogicalPlan::Join(_)) { + return Ok(TreeNodeRecursion::Continue); + } for expr in node.expressions() { expr.apply(|e| { if let Expr::Column(col) = e { @@ -343,19 +377,6 @@ impl LateMaterializeJoin { return Ok(None); } - // The take re-fetches by `_rowaddr` and merges columns by name, so it - // cannot sit above a join whose output has duplicate field names (e.g. - // both sides' equi-keys share a name). Matches the physical rule. 
- let mut seen = HashSet::with_capacity(join.schema.fields().len()); - if join - .schema - .fields() - .iter() - .any(|f| !seen.insert(f.name().as_str())) - { - return Ok(None); - } - for side in [JoinSide::Left, JoinSide::Right] { let side_plan: &LogicalPlan = match side { JoinSide::Left => &join.left, @@ -417,6 +438,57 @@ impl LateMaterializeJoin { let deferred_columns: Vec = deferred.into_iter().map(|(_, name)| name).collect(); + // Insert a normalizing projection between the join and the take that + // keeps only the columns the take must carry: those referenced above + // the join plus the row locator, with the deferred columns dropped + // (they are re-fetched). This serves two purposes: + // * It lets the physical planner give the join a tight output + // projection. Without it, the opaque take node forces the + // `HashJoinExec` to emit its full `left ++ right` output, which + // for an equi-join includes both sides' key columns — duplicate + // arrow names that `TakeExec` (which merges by name) cannot + // handle once qualifiers are erased at the physical level. + // * It is where scan narrowing happens: the deferred columns are + // simply absent from the projection, so projection pushdown drops + // them from the scan while keeping `_rowaddr`. + // `referenced` excludes the join's own on-clause columns, so a + // redundant equi-key (used only by the join) is pruned here rather + // than colliding with the surviving key. If two *referenced* columns + // still share a name, the take cannot disambiguate them; bail. + let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); + let mut kept_exprs: Vec = Vec::new(); + // Seed with the deferred names: they are appended to the take output, + // so a kept column sharing one of those names also conflicts. 
+ let mut kept_names: HashSet = deferred_columns.iter().cloned().collect(); + let mut name_conflict = false; + for (col_qualifier, field) in join.schema.iter() { + let name = field.name(); + // Drop the deferred columns: they are re-fetched and appended by + // the take, not carried through its input. + if col_qualifier == qualifier.as_ref() && deferred_set.contains(name.as_str()) { + continue; + } + let kept = name == ROW_ADDR + || name == ROW_ID + || referenced.contains(&(col_qualifier.cloned(), name.clone())); + if !kept { + continue; + } + if !kept_names.insert(name.clone()) { + name_conflict = true; + break; + } + kept_exprs.push(Expr::Column(Column::new(col_qualifier.cloned(), name))); + } + if name_conflict { + continue; + } + + let take_input = LogicalPlan::Projection(Projection::try_new( + kept_exprs, + Arc::new(LogicalPlan::Join(join.clone())), + )?); + // The scan side is null-extended when it is the optional side of an // outer join; its unmatched rows then have a null `_rowaddr` and // must yield NULL deferred values. @@ -429,7 +501,7 @@ impl LateMaterializeJoin { ); let node = LateTakeNode::try_new( - LogicalPlan::Join(join.clone()), + take_input, dataset, deferred_columns, qualifier, @@ -752,8 +824,9 @@ mod tests { async fn test_duplicate_join_output_names_not_deferred() { let (dataset, _tmp) = test_dataset().await; let ctx = SessionContext::new(); - // Both sides expose `payload`; the join output has a duplicate name, so - // the take (which merges by name) cannot be inserted. + // Both sides expose `payload` and both are consumed above the join, so + // the take (which merges appended columns by name) would produce a + // duplicate `payload` and must not be inserted. 
let target = ctx .read_lance_unordered(dataset, false, true) .unwrap() @@ -779,14 +852,20 @@ mod tests { let plan = target .join(source, JoinType::Inner, &["id"], &["sid"], None) .unwrap() - .select(vec![col("target.id"), col("target.payload")]) + // Reference both sides' `payload` above the join: deferring + // `target.payload` would collide with the kept `source.payload`. + .select(vec![ + col("target.id"), + col("target.payload"), + col("source.payload"), + ]) .unwrap() .into_unoptimized_plan(); let optimized = run_rule_and_pushdown(plan); assert!( !has_late_take(&optimized), - "duplicate join-output names must not be deferred" + "duplicate output names must not be deferred" ); } From c2d76d0bc05301051b3162ce16e24430fea1ca25 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 16:15:47 -0700 Subject: [PATCH 4/7] review: qualify deferred-column matching; generalize comments Address review on the logical late-materialization node: - The node identified deferred columns by name alone, so a same-named column from another relation would be wrongly dropped from the passthrough. Match on (qualifier, name) instead, in both `build_output_schema` and `necessary_children_exprs`. Adds a unit test. - Strip use-case/issue-specific context from the module doc and `is_wide_column` so the rule reads as generic. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/io/exec/late_take.rs | 121 +++++++++++++++++++++------- 1 file changed, 92 insertions(+), 29 deletions(-) diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs index 9f8deae6c37..6e2cf97dd3c 100644 --- a/rust/lance/src/io/exec/late_take.rs +++ b/rust/lance/src/io/exec/late_take.rs @@ -4,23 +4,18 @@ //! Late-materialization *logical* optimizer rule. //! //! Defers reading wide data columns that a row-reducing join only carries -//! through, fetching them by `_rowaddr` *after* the row count has shrunk. Used -//! 
by `merge_insert` to avoid scanning wide non-source columns for every target -//! row of a selective partial-schema upsert. +//! through, fetching them by `_rowaddr` *after* the row count has shrunk. //! -//! Working at the logical level (rather than rewriting a physical -//! `HashJoinExec` by position), a [`LateTakeNode`] is inserted above the join — -//! fed by a projection that keeps only the columns carried past the join — and -//! advertises an output schema of "carried columns plus the deferred columns -//! appended". Its [`UserDefinedLogicalNodeCore::necessary_children_exprs`] -//! reports that it does *not* need the deferred columns from its child, only -//! `_rowaddr`. DataFusion's stock `OptimizeProjections` rule then prunes those -//! columns from the scan automatically — no manual index remapping — and -//! downstream column references resolve the deferred columns from the take by -//! name. +//! A [`LateTakeNode`] is inserted above the join — fed by a projection that +//! keeps only the columns carried past the join — and advertises an output +//! schema of "carried columns plus the deferred columns appended". Its +//! [`UserDefinedLogicalNodeCore::necessary_children_exprs`] reports that it does +//! *not* need the deferred columns from its child, only `_rowaddr`, so +//! DataFusion's stock `OptimizeProjections` rule prunes them from the scan +//! automatically and downstream references resolve the deferred columns from the +//! take by name. //! -//! The node lowers to the existing physical [`super::TakeExec`] via -//! [`LateTakePlanner`]. +//! The node lowers to the physical [`super::TakeExec`] via [`LateTakePlanner`]. 
use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; @@ -48,18 +43,11 @@ use super::TakeExec; use crate::Dataset; use crate::datafusion::dataframe::LanceTableProvider; -/// Width/storage gate mirroring the scanner's late-materialization heuristic -/// ([`crate::dataset::scanner::MaterializationStyle::Heuristic`]): a column is -/// worth deferring only if it is "wide" for the backing storage — a -/// variable-width type (strings, lists, vectors) or a fixed-width type above the -/// per-row byte threshold (1KB on cloud storage, 10 bytes on local). Narrow -/// columns are cheaper to read in the sequential scan than to re-fetch by -/// address. -/// -/// Without a join-cardinality estimate (tracked in #4583) we cannot gate on -/// match selectivity, so we fall back to width alone. This covers the -/// inherently selective backfill case the feature targets; a follow-up can -/// incorporate cardinality once it is available. +/// Width/storage gate: a column is worth deferring only if it is "wide" for the +/// backing storage — a variable-width type (strings, lists, vectors) or a +/// fixed-width type at or above the per-row byte threshold (1KB on cloud +/// storage, 10 bytes on local). Narrow columns are cheaper to read in the +/// sequential scan than to re-fetch by address. fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { if field.is_blob() { return false; @@ -164,9 +152,14 @@ impl LateTakeNode { let input_schema = input.schema(); let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); + // A field is deferred only when both its name and its qualifier match the + // deferred columns' relation; a same-named field from another relation is + // a distinct column and must stay in place. 
let mut qualified_fields: Vec<(Option, Arc)> = input_schema .iter() - .filter(|(_, f)| !deferred_set.contains(f.name().as_str())) + .filter(|(q, f)| { + !(*q == qualifier.as_ref() && deferred_set.contains(f.name().as_str())) + }) .map(|(q, f)| (q.cloned(), f.clone())) .collect(); @@ -260,10 +253,14 @@ impl UserDefinedLogicalNodeCore for LateTakeNode { // Output positions [0..passthrough_len) map back to these child indices, // in order; positions beyond are the appended (fetched) deferred columns. + // The deferred match is qualified: a same-named field from another + // relation is a distinct passthrough column, not a deferred one. let passthrough: Vec = input_schema .iter() .enumerate() - .filter(|(_, (_, f))| !deferred_set.contains(f.name().as_str())) + .filter(|(_, (q, f))| { + !(*q == self.qualifier.as_ref() && deferred_set.contains(f.name().as_str())) + }) .map(|(i, _)| i) .collect(); @@ -869,6 +866,72 @@ mod tests { ); } + /// The deferred-column match is qualified: a same-named column from another + /// relation must not be dropped just because it shares a deferred name. + #[tokio::test] + async fn test_deferred_match_respects_qualifier() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let target = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap(); + // The source relation also exposes a `payload` column. 
+ let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let input = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + .into_unoptimized_plan(); + let input_field_count = input.schema().fields().len(); + + // Defer only `target.payload`; `source.payload` is a distinct column. + let node = LateTakeNode::try_new( + input, + dataset, + vec!["payload".to_string()], + Some(TableReference::bare("target")), + false, + ) + .unwrap(); + + let schema = UserDefinedLogicalNodeCore::schema(&node); + let target_ref = TableReference::bare("target"); + let source_ref = TableReference::bare("source"); + // `source.payload` survives (only the qualified `target.payload` is + // dropped from the passthrough), and `target.payload` is re-appended — + // so the output keeps the same field count as the input. + assert!( + schema + .index_of_column_by_name(Some(&source_ref), "payload") + .is_some(), + "source.payload must not be dropped by qualifier-blind matching" + ); + assert!( + schema + .index_of_column_by_name(Some(&target_ref), "payload") + .is_some(), + "target.payload should be appended by the take" + ); + assert_eq!(schema.fields().len(), input_field_count); + } + #[tokio::test] async fn test_outer_join_marks_deferred_nullable() { let (dataset, _tmp) = test_dataset().await; From 08295debdedfd5cc2212550ca85bd591222d3f82 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 16:51:04 -0700 Subject: [PATCH 5/7] test(merge_insert): cover is_wide_column gate; doc LateTakeNode::try_new Self-review polish: - Add a doc comment to the public `LateTakeNode::try_new` constructor. 
- Note on the `PartialOrd` impl why it orders a subset of the equality fields. - Add a unit test for `is_wide_column` covering the local vs cloud byte thresholds and the blob exclusion (previously only the local variable-width path was exercised). Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/io/exec/late_take.rs | 50 +++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs index 6e2cf97dd3c..f18038b649c 100644 --- a/rust/lance/src/io/exec/late_take.rs +++ b/rust/lance/src/io/exec/late_take.rs @@ -108,6 +108,9 @@ impl std::hash::Hash for LateTakeNode { } impl PartialOrd for LateTakeNode { + // Orders by the only fields that have a natural order (`deferred_columns`, + // then `input`); `dataset`/`qualifier`/`nullable_extra` are part of equality + // but not ordered here, matching the sibling `MergeInsertWriteNode` impl. fn partial_cmp(&self, other: &Self) -> Option { match self.deferred_columns.partial_cmp(&other.deferred_columns) { Some(std::cmp::Ordering::Equal) => self.input.partial_cmp(&other.input), @@ -117,6 +120,14 @@ impl PartialOrd for LateTakeNode { } impl LateTakeNode { + /// Build a node that re-fetches `deferred_columns` (dataset field names, in + /// dataset-schema order) from `dataset` by row address. + /// + /// `qualifier` is the relation the deferred columns came from; the appended + /// output fields carry it so downstream references still resolve. Set + /// `nullable_extra` when the take sits above an outer join on whose optional + /// side the scan lives, so unmatched rows (null row address) yield NULL + /// deferred values. Errors if a deferred name is absent from the dataset. 
pub fn try_new( input: LogicalPlan, dataset: Arc, @@ -932,6 +943,45 @@ mod tests { assert_eq!(schema.fields().len(), input_field_count); } + #[test] + fn test_is_wide_column() { + use lance_core::datatypes::Field as LanceField; + use std::collections::HashMap; + + let lance_field = |arrow: Field| LanceField::try_from(arrow).unwrap(); + + // Variable-width: wide regardless of storage. + let utf8 = lance_field(Field::new("s", DataType::Utf8, true)); + assert!(is_wide_column(&utf8, false)); + assert!(is_wide_column(&utf8, true)); + + // Small fixed-width (4 bytes): narrow on both local and cloud. + let small = lance_field(Field::new("n", DataType::Int32, true)); + assert!(!is_wide_column(&small, false)); + assert!(!is_wide_column(&small, true)); + + // Mid fixed-width (64 bytes): above the 10-byte local threshold but + // below the 1KB cloud threshold — wide locally, narrow on cloud. + let mid = lance_field(Field::new("m", DataType::FixedSizeBinary(64), true)); + assert!(is_wide_column(&mid, false)); + assert!(!is_wide_column(&mid, true)); + + // Large fixed-width (2KB): wide on both. + let large = lance_field(Field::new("l", DataType::FixedSizeBinary(2048), true)); + assert!(is_wide_column(&large, false)); + assert!(is_wide_column(&large, true)); + + // Blob columns are never deferred even though LargeBinary is + // variable-width — they have their own access path. 
+ let blob_meta = + HashMap::from([(lance_arrow::BLOB_META_KEY.to_string(), "true".to_string())]); + let blob = + lance_field(Field::new("b", DataType::LargeBinary, true).with_metadata(blob_meta)); + assert!(blob.is_blob()); + assert!(!is_wide_column(&blob, false)); + assert!(!is_wide_column(&blob, true)); + } + #[tokio::test] async fn test_outer_join_marks_deferred_nullable() { let (dataset, _tmp) = test_dataset().await; From cf54eb6bc3f12c7c9bc7c1ac8412c0c6817ed312 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 17:00:02 -0700 Subject: [PATCH 6/7] test(merge_insert): cover kept-column name conflict and single-input descent Add two tests flagged in self-review: - `test_kept_column_name_conflict_not_deferred`: two non-deferred columns from different relations sharing a name (the kept-vs-kept bail path, distinct from the deferred-name collision already covered). - `test_find_lance_dataset_descends_single_input_only`: the documented invariant that a multi-input join never surfaces a nested scan as its relation. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/io/exec/late_take.rs | 89 +++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs index f18038b649c..c0384fca57b 100644 --- a/rust/lance/src/io/exec/late_take.rs +++ b/rust/lance/src/io/exec/late_take.rs @@ -982,6 +982,95 @@ mod tests { assert!(!is_wide_column(&blob, true)); } + /// Two *non-deferred* columns from different relations that share a name + /// (here `tag`, on both sides) collide in the take's input. This is the + /// kept-vs-kept conflict path — distinct from the duplicate-name test above, + /// which collides a kept column with a *deferred* name. 
+ #[tokio::test] + async fn test_kept_column_name_conflict_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + // Source carries a narrow `tag` (never a defer candidate itself). + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("tag", DataType::Int32, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(Int32Array::from(vec![100, 300])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + // `payload` is wide (deferrable), but both `target.tag` and `source.tag` + // are carried past the join → their shared name would collide in the + // take's input, so the rule must bail. + let plan = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + .select(vec![ + col("target.id"), + col("target.payload"), + col("target.tag"), + col("source.tag"), + ]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!( + !has_late_take(&optimized), + "a kept-column name collision must not be deferred" + ); + } + + /// `find_lance_dataset` must descend only single-input nodes, so a nested + /// join's scan is never mistaken for a scannable relation of an outer join. + #[tokio::test] + async fn test_find_lance_dataset_descends_single_input_only() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + + // A single-input chain (SubqueryAlias -> TableScan) resolves to the dataset. 
+ let scan_plan = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap() + .into_unoptimized_plan(); + assert!(LateMaterializeJoin::find_lance_dataset(&scan_plan).is_some()); + + // A join is multi-input: neither side's scan is surfaced, so the rule + // cannot pick up a nested join's relation by mistake. + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let join_plan = target + .join( + source_df(&ctx, vec![1, 3]), + JoinType::Inner, + &["id"], + &["sid"], + None, + ) + .unwrap() + .into_unoptimized_plan(); + assert!(LateMaterializeJoin::find_lance_dataset(&join_plan).is_none()); + } + #[tokio::test] async fn test_outer_join_marks_deferred_nullable() { let (dataset, _tmp) = test_dataset().await; From c3055c8f589ef17184affe4ea3f84e195f83291b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 18 Jun 2026 17:03:42 -0700 Subject: [PATCH 7/7] refactor(merge_insert): tighten candidate collection in late-take rule - Collect deferral candidates by name and derive `deferred_columns` by iterating the dataset schema in order, dropping the index/sort tuple. - Trim the normalizing-projection comment. Co-Authored-By: Claude Opus 4.8 (1M context) --- rust/lance/src/io/exec/late_take.rs | 87 +++++++++++++---------------- 1 file changed, 39 insertions(+), 48 deletions(-) diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs index c0384fca57b..f9295594532 100644 --- a/rust/lance/src/io/exec/late_take.rs +++ b/rust/lance/src/io/exec/late_take.rs @@ -413,56 +413,47 @@ impl LateMaterializeJoin { let dataset_arrow = ArrowSchema::from(dataset.schema()); // Candidates = scan-side data columns that aren't join keys nor the - // system columns, that are consumed above the join, and are wide - // enough to be worth re-fetching by address. 
- let mut deferred: Vec<(usize, String)> = Vec::new(); - for (col_qualifier, field) in side_schema.iter() { - let name = field.name(); - if name == ROW_ADDR || name == ROW_ID || key_names.contains(name) { - continue; - } - // Only defer columns actually used above the join; a column the - // scan produces but nobody references would just be pruned, so - // re-fetching it would be wasted work. - if !referenced.contains(&(col_qualifier.cloned(), name.clone())) { - continue; - } - let Some(ds_field) = dataset.schema().field(name) else { - continue; - }; - if !is_wide_column(ds_field, is_cloud) { - continue; - } - let Ok(ds_idx) = dataset_arrow.index_of(name) else { - continue; - }; - deferred.push((ds_idx, name.clone())); - } - if deferred.is_empty() { + // system columns, that are consumed above the join (a column nobody + // references would just be pruned, so re-fetching it is wasted work), + // and are wide enough to be worth re-fetching by address. + let candidates: HashSet<&str> = side_schema + .iter() + .filter(|(col_qualifier, field)| { + let name = field.name(); + name != ROW_ADDR + && name != ROW_ID + && !key_names.contains(name) + && referenced.contains(&(col_qualifier.cloned(), name.clone())) + && dataset + .schema() + .field(name) + .is_some_and(|f| is_wide_column(f, is_cloud)) + }) + .map(|(_, field)| field.name().as_str()) + .collect(); + if candidates.is_empty() { continue; } - // Append in dataset-schema order to match TakeExec's output order. - deferred.sort_by_key(|(idx, _)| *idx); - let deferred_columns: Vec = - deferred.into_iter().map(|(_, name)| name).collect(); - - // Insert a normalizing projection between the join and the take that - // keeps only the columns the take must carry: those referenced above - // the join plus the row locator, with the deferred columns dropped - // (they are re-fetched). This serves two purposes: - // * It lets the physical planner give the join a tight output - // projection. 
Without it, the opaque take node forces the - // `HashJoinExec` to emit its full `left ++ right` output, which - // for an equi-join includes both sides' key columns — duplicate - // arrow names that `TakeExec` (which merges by name) cannot - // handle once qualifiers are erased at the physical level. - // * It is where scan narrowing happens: the deferred columns are - // simply absent from the projection, so projection pushdown drops - // them from the scan while keeping `_rowaddr`. - // `referenced` excludes the join's own on-clause columns, so a - // redundant equi-key (used only by the join) is pruned here rather - // than colliding with the surviving key. If two *referenced* columns - // still share a name, the take cannot disambiguate them; bail. + // Order by the dataset schema to match TakeExec's append order. + let deferred_columns: Vec = dataset_arrow + .fields() + .iter() + .map(|f| f.name()) + .filter(|name| candidates.contains(name.as_str())) + .cloned() + .collect(); + + // Insert a normalizing projection between the join and the take, + // keeping only the columns carried past the join (referenced above it + // plus the row locator) and dropping the deferred ones. This gives the + // join a tight output projection — without it the opaque take node + // forces `HashJoinExec` to emit its full `left ++ right`, whose + // equi-keys become duplicate arrow names the take (merging by name) + // can't handle once qualifiers are erased — and it is where scan + // narrowing happens, since the absent columns are pushed out of the + // scan. `referenced` excludes the join's own on-clause, so a redundant + // equi-key is pruned here; if two *referenced* columns still share a + // name the take can't disambiguate them, so bail. let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); let mut kept_exprs: Vec = Vec::new(); // Seed with the deferred names: they are appended to the take output,