diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..075c4551f73 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -58,7 +58,7 @@ use crate::{ }, index::DatasetIndexInternalExt, io::exec::{ - AddRowAddrExec, Planner, TakeExec, project, + AddRowAddrExec, LateMaterializeJoin, LateTakePlanner, Planner, TakeExec, project, scalar_index::{IndexLookup, MapIndexExec}, utils::ReplayExec, }, @@ -72,6 +72,7 @@ use arrow_select::take::take_record_batch; use datafusion::common::NullEquality; use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::error::DataFusionError; +use datafusion::optimizer::{OptimizerContext, OptimizerRule}; use datafusion::{ execution::{ context::{SessionConfig, SessionContext}, @@ -88,7 +89,7 @@ use datafusion::{ stream::RecordBatchStreamAdapter, union::UnionExec, }, - physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, + physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::DataFrame, scalar::ScalarValue, }; @@ -1538,11 +1539,30 @@ impl MergeInsertJob { node: Arc::new(write_node), }); + // First pass: standard optimization. The non-source "fill" columns are + // still referenced (each copies `target.`), so projection pushdown + // keeps them — and `target._rowaddr` — in the target scan. let logical_plan = session_state.optimize(&logical_plan)?; - let planner = - DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(MergeInsertPlanner {})]); - // This method already does the optimization for us. + // Defer reading non-source columns: a partial-schema upsert reads the + // missing columns from the target side of the join only to rewrite full + // rows. This rule inserts a `LateTake` above the join so those wide + // columns are fetched by `_rowaddr` for the matched rows only, instead + // of being scanned for every target row. It is applied only here (not in + // the session-wide optimizer) to bound its blast radius. + let logical_plan = LateMaterializeJoin::new() + .rewrite(logical_plan, &OptimizerContext::default())? + .data; + + // Second pass: the deferred columns are now absent from the `LateTake`'s + // input, so projection pushdown narrows the target scan to drop them + // while keeping `_rowaddr` (which the take forces in). + let logical_plan = session_state.optimize(&logical_plan)?; + + let planner = DefaultPhysicalPlanner::with_extension_planners(vec![ + Arc::new(MergeInsertPlanner {}) as Arc, + Arc::new(LateTakePlanner) as Arc, + ]); let physical_plan = planner .create_physical_plan(&logical_plan, &session_state) .await?; @@ -4521,20 +4541,379 @@ mod tests { "expected HashJoinExec in plan, got: {}", plan ); - // Evidence that the partial-schema fix is active: the target - // side of the join reads the `other` column (which is missing - // from the source) and an explicit projection carries it - // through to the write exec alongside source columns. + // Late materialization is active: the `other` column (missing from + // the source) is *not* read by the target scan. Instead it is + // fetched by a `Take` inserted above the join, so a selective match + // does not scan `other` for every target row. + assert!( + plan.contains("LanceRead") && plan.contains("projection=[key]"), + "target-side scan should only read the join key, not `other`: {}", + plan + ); + assert!( + !plan.contains("projection=[other"), + "deferred `other` column must not be in the target scan projection: {}", + plan + ); + assert!( + plan.contains("Take") && plan.contains("(other)"), + "expected a Take above the join fetching the deferred `other` column: {}", + plan + ); + } + + /// Extract the `output_rows` metric of the first plan node whose + /// (trimmed) display line starts with `node_prefix`, from an + /// `analyze_plan` string. + fn output_rows_for_node(analysis: &str, node_prefix: &str) -> Option { + let line = analysis + .lines() + .find(|l| l.trim_start().starts_with(node_prefix))?; + let start = line.find("output_rows=")? + "output_rows=".len(); + let rest = &line[start..]; + let end = rest + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(rest.len()); + rest[..end].parse().ok() + } + + /// The read-amplification payoff: on a selective partial-schema update, + /// the wide non-source column must be fetched only for the matched rows + /// (via the `Take`), not scanned for every target row. + #[tokio::test] + async fn test_merge_insert_subcols_defers_wide_column_reads() { + // 100 rows across 4 fragments; `other` is a wide (string) column, + // `key`/`value` are narrow. Keys are 0..100 so matches are precise. + let batch = lance_datagen::gen_batch() + .with_seed(Seed::from(1)) + .col("other", array::rand_utf8(64.into(), false)) + .col("value", array::step::()) + .col("key", array::step_custom::(0, 1)) + .into_batch_rows(RowCount::from(100)) + .unwrap(); + let schema = batch.schema(); + let ds = Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 25, + ..Default::default() + }), + ) + .await + .unwrap(); + let ds = Arc::new(ds); + + // Partial-schema source (key, value) updating only 5 of 100 keys. + let update_schema = Arc::new(schema.project(&[2, 1]).unwrap()); + let new_data = RecordBatch::try_new( + update_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4])), + Arc::new(UInt32Array::from(vec![1000u32, 1001, 1002, 1003, 1004])), + ], + ) + .unwrap(); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + + let source = reader_to_stream(Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + ))); + let analysis = job.analyze_plan(source).await.unwrap(); + + // The target scan reads every row but only the narrow key column... + let lance_read = analysis + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node:\n{}", analysis)); + assert!( + lance_read.contains("projection=[key]"), + "target scan must defer the wide `other` column: {}", + lance_read + ); + assert_eq!( + output_rows_for_node(&analysis, "LanceRead"), + Some(100), + "target scan should still visit all rows:\n{}", + analysis + ); + + // ...and `other` is materialized only for the 5 matched rows. + assert_eq!( + output_rows_for_node(&analysis, "Take"), + Some(5), + "wide column should be taken for only the matched rows:\n{}", + analysis + ); + } + + /// A partial-schema `UpdateIf` whose condition references the deferred + /// (non-source) target column must still evaluate correctly: the column + /// is fetched once by the `Take` and read from there by the action + /// expression, not double-fetched or lost. + #[tokio::test] + async fn test_merge_insert_subcols_update_if_on_deferred_column() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("other", DataType::Utf8, true), + ])); + // `other` gates the update; keys 0,2,4 are "keep", 1,3,5 are "skip". + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4, 5])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3, 4, 5])), + Arc::new(StringArray::from(vec![ + "keep", "skip", "keep", "skip", "keep", "skip", + ])), + ], + ) + .unwrap(); + let ds = Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 3, // two fragments + ..Default::default() + }), + ) + .await + .unwrap(); + let ds = Arc::new(ds); + + // Partial-schema source (key, value) matching keys 0..=3. + let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let new_data = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![100u32, 100, 100, 100])), + ], + ) + .unwrap(); + let reader = Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + )); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::update_if(&ds, "target.other = 'keep'").unwrap()) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + let (updated, _stats) = job.execute_reader(reader).await.unwrap(); + + let batch = updated.scan().try_into_batch().await.unwrap(); + let keys = batch["key"].as_any().downcast_ref::().unwrap(); + let values = batch["value"] + .as_any() + .downcast_ref::() + .unwrap(); + let others = batch["other"] + .as_any() + .downcast_ref::() + .unwrap(); + let by_key = (0..batch.num_rows()) + .map(|i| { + ( + keys.value(i), + (values.value(i), others.value(i).to_string()), + ) + }) + .collect::>(); + + // Matched + condition true (other == "keep"): value updated to 100. + assert_eq!(by_key[&0], (100, "keep".to_string())); + assert_eq!(by_key[&2], (100, "keep".to_string())); + // Matched + condition false (other == "skip"): value unchanged. + assert_eq!(by_key[&1], (1, "skip".to_string())); + assert_eq!(by_key[&3], (3, "skip".to_string())); + // Unmatched rows untouched. + assert_eq!(by_key[&4], (4, "keep".to_string())); + assert_eq!(by_key[&5], (5, "skip".to_string())); + } + + /// The width gate's negative branch: a *narrow* missing column must NOT + /// be deferred — a sequential scan of it is cheaper than a per-row take, + /// so it stays in the target scan and no `Take` is introduced. + #[tokio::test] + async fn test_merge_insert_subcols_narrow_column_not_deferred() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("small", DataType::UInt32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![10u32, 11, 12, 13])), + ], + ) + .unwrap(); + let ds = Arc::new( + Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + None, + ) + .await + .unwrap(), + ); + + // Source omits the narrow `small` column. + let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + let plan = job.explain_plan(Some(&source_schema), false).await.unwrap(); + assert!( - plan.contains("LanceRead") && plan.contains("projection=[other"), - "target-side scan should include the filled `other` column: {}", + !plan.contains("Take"), + "a narrow column must not be deferred via a Take: {}", plan ); + let lance_read = plan + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node: {}", plan)); + assert!( + lance_read.contains("small"), + "narrow `small` should be read directly by the target scan: {}", + lance_read + ); + } + + /// Deferral must remain correct when more than one wide column is + /// dropped from the scan: this exercises the multi-column index remap of + /// the join output and the name-based re-index of the parent projection. + #[tokio::test] + async fn test_merge_insert_subcols_defers_multiple_wide_columns() { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + Field::new("wide_a", DataType::Utf8, true), + Field::new("wide_b", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(UInt32Array::from(vec![0u32, 1, 2, 3])), + Arc::new(StringArray::from(vec!["a0", "a1", "a2", "a3"])), + Arc::new(StringArray::from(vec!["b0", "b1", "b2", "b3"])), + ], + ) + .unwrap(); + let ds = Arc::new( + Dataset::write( + RecordBatchIterator::new([Ok(batch)], schema.clone()), + "memory://", + Some(WriteParams { + max_rows_per_file: 2, // two fragments + ..Default::default() + }), + ) + .await + .unwrap(), + ); + + // Source omits both wide columns; updates keys 0 and 1. + let source_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::UInt32, false), + Field::new("value", DataType::UInt32, true), + ])); + let new_data = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(UInt32Array::from(vec![0u32, 1])), + Arc::new(UInt32Array::from(vec![100u32, 100])), + ], + ) + .unwrap(); + + let job = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + .unwrap() + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::DoNothing) + .try_build() + .unwrap(); + + // Both wide columns are deferred to the take, not the scan. + let plan = job + .explain_plan(Some(&new_data.schema().as_ref().clone()), false) + .await + .unwrap(); + let lance_read = plan + .lines() + .find(|l| l.trim_start().starts_with("LanceRead")) + .unwrap_or_else(|| panic!("no LanceRead node: {}", plan)); + assert!( + !lance_read.contains("wide_a") && !lance_read.contains("wide_b"), + "both wide columns must be deferred out of the scan: {}", + lance_read + ); assert!( - plan.contains("other@0 as other"), - "expected post-join projection to carry `other` from the target side: {}", + plan.contains("(wide_a)") && plan.contains("(wide_b)"), + "the take must fetch both deferred columns: {}", plan ); + + // And the result is correct: matched rows updated, wide columns preserved. + let reader = Box::new(RecordBatchIterator::new( + [Ok(new_data.clone())], + new_data.schema(), + )); + let (updated, _stats) = job.execute_reader(reader).await.unwrap(); + let batch = updated.scan().try_into_batch().await.unwrap(); + let keys = batch["key"].as_any().downcast_ref::().unwrap(); + let values = batch["value"] + .as_any() + .downcast_ref::() + .unwrap(); + let wide_a = batch["wide_a"] + .as_any() + .downcast_ref::() + .unwrap(); + let wide_b = batch["wide_b"] + .as_any() + .downcast_ref::() + .unwrap(); + let by_key = (0..batch.num_rows()) + .map(|i| { + ( + keys.value(i), + ( + values.value(i), + wide_a.value(i).to_string(), + wide_b.value(i).to_string(), + ), + ) + }) + .collect::>(); + assert_eq!(by_key[&0], (100, "a0".to_string(), "b0".to_string())); + assert_eq!(by_key[&1], (100, "a1".to_string(), "b1".to_string())); + assert_eq!(by_key[&2], (2, "a2".to_string(), "b2".to_string())); + assert_eq!(by_key[&3], (3, "a3".to_string(), "b3".to_string())); } /// Partial-schema upserts with `insert_not_matched=InsertAll` must diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index a477d60d56d..428362d83d8 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -15,6 +15,7 @@ pub mod filtered_read; pub mod filtered_read_proto; pub mod fts; pub(crate) mod knn; +mod late_take; mod optimizer; mod projection; mod pushdown_scan; @@ -32,6 +33,7 @@ pub use filter::LanceFilterExec; pub use knn::{ANNIvfPartitionExec, ANNIvfSubIndexExec, KNNVectorDistanceExec}; pub use lance_datafusion::planner::Planner; pub use lance_index::scalar::expression::FilterPlan; +pub use late_take::{LateMaterializeJoin, LateTakeNode, LateTakePlanner}; pub use optimizer::get_physical_optimizer; pub use projection::project; pub use pushdown_scan::{LancePushdownScanExec, ScanConfig}; diff --git a/rust/lance/src/io/exec/late_take.rs b/rust/lance/src/io/exec/late_take.rs new file mode 100644 index 00000000000..f9295594532 --- /dev/null +++ b/rust/lance/src/io/exec/late_take.rs @@ -0,0 +1,1212 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Late-materialization *logical* optimizer rule. +//! +//! Defers reading wide data columns that a row-reducing join only carries +//! through, fetching them by `_rowaddr` *after* the row count has shrunk. +//! +//! A [`LateTakeNode`] is inserted above the join — fed by a projection that +//! keeps only the columns carried past the join — and advertises an output +//! schema of "carried columns plus the deferred columns appended". Its +//! [`UserDefinedLogicalNodeCore::necessary_children_exprs`] reports that it does +//! *not* need the deferred columns from its child, only `_rowaddr`, so +//! DataFusion's stock `OptimizeProjections` rule prunes them from the scan +//! automatically and downstream references resolve the deferred columns from the +//! take by name. +//! +//! The node lowers to the physical [`super::TakeExec`] via [`LateTakePlanner`]. + +use std::collections::{BTreeSet, HashSet}; +use std::sync::Arc; + +use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; +use async_trait::async_trait; +use datafusion::{ + common::{ + Column, DFSchema, DFSchemaRef, Result as DFResult, TableReference, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, + }, + datasource::DefaultTableSource, + execution::SessionState, + logical_expr::{Expr, Extension, Join, JoinType, LogicalPlan, Projection}, + optimizer::{OptimizerConfig, OptimizerRule}, + physical_plan::ExecutionPlan, + physical_planner::{ExtensionPlanner, PhysicalPlanner}, +}; +use datafusion_expr::{UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; +use lance_arrow::DataTypeExt; +use lance_core::datatypes::OnMissing; +use lance_core::{ROW_ADDR, ROW_ID}; + +use super::TakeExec; +use crate::Dataset; +use crate::datafusion::dataframe::LanceTableProvider; + +/// Width/storage gate: a column is worth deferring only if it is "wide" for the +/// backing storage — a variable-width type (strings, lists, vectors) or a +/// fixed-width type at or above the per-row byte threshold (1KB on cloud +/// storage, 10 bytes on local). Narrow columns are cheaper to read in the +/// sequential scan than to re-fetch by address. +fn is_wide_column(field: &lance_core::datatypes::Field, is_cloud: bool) -> bool { + if field.is_blob() { + return false; + } + let byte_width = field.data_type().byte_width_opt(); + if is_cloud { + byte_width.is_none_or(|bw| bw >= 1000) + } else { + byte_width.is_none_or(|bw| bw >= 10) + } +} + +/// Logical plan node that re-fetches `deferred_columns` from `dataset` by +/// `_rowaddr` after a row-reducing operator. +/// +/// Output schema = the input's columns with `deferred_columns` removed, +/// followed by the deferred columns appended in dataset-schema order (mirroring +/// the physical [`TakeExec`], which appends taken columns). Constructing the +/// schema this way makes it invariant under projection pushdown: whether or not +/// the child still produces a deferred column, the node advertises the same +/// output, so the rule can be inserted before pushdown prunes the scan. +#[derive(Debug)] +pub struct LateTakeNode { + input: LogicalPlan, + dataset: Arc, + /// Dataset field names to re-fetch by address, in dataset-schema order. + deferred_columns: Vec, + /// Qualifier for the appended deferred fields (e.g. `target`), matching the + /// relation the scanned columns came from. + qualifier: Option, + /// When true the deferred fields are nullable in the output even if the + /// dataset declares them non-null. Set above an outer join where the scan + /// side can be null-extended (its `_rowaddr` is null → NULL deferred value). + nullable_extra: bool, + schema: DFSchemaRef, +} + +impl PartialEq for LateTakeNode { + fn eq(&self, other: &Self) -> bool { + self.dataset.base == other.dataset.base + && self.deferred_columns == other.deferred_columns + && self.qualifier == other.qualifier + && self.nullable_extra == other.nullable_extra + && self.input == other.input + } +} + +impl Eq for LateTakeNode {} + +impl std::hash::Hash for LateTakeNode { + fn hash(&self, state: &mut H) { + self.dataset.base.hash(state); + self.deferred_columns.hash(state); + self.qualifier.hash(state); + self.nullable_extra.hash(state); + self.input.hash(state); + } +} + +impl PartialOrd for LateTakeNode { + // Orders by the only fields that have a natural order (`deferred_columns`, + // then `input`); `dataset`/`qualifier`/`nullable_extra` are part of equality + // but not ordered here, matching the sibling `MergeInsertWriteNode` impl. + fn partial_cmp(&self, other: &Self) -> Option { + match self.deferred_columns.partial_cmp(&other.deferred_columns) { + Some(std::cmp::Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + } + } +} + +impl LateTakeNode { + /// Build a node that re-fetches `deferred_columns` (dataset field names, in + /// dataset-schema order) from `dataset` by row address. + /// + /// `qualifier` is the relation the deferred columns came from; the appended + /// output fields carry it so downstream references still resolve. Set + /// `nullable_extra` when the take sits above an outer join on whose optional + /// side the scan lives, so unmatched rows (null row address) yield NULL + /// deferred values. Errors if a deferred name is absent from the dataset. + pub fn try_new( + input: LogicalPlan, + dataset: Arc, + deferred_columns: Vec, + qualifier: Option, + nullable_extra: bool, + ) -> DFResult { + let schema = Self::build_output_schema( + &input, + &dataset, + &deferred_columns, + &qualifier, + nullable_extra, + )?; + Ok(Self { + input, + dataset, + deferred_columns, + qualifier, + nullable_extra, + schema, + }) + } + + /// Build `input columns (minus deferred) ++ deferred columns appended`. + fn build_output_schema( + input: &LogicalPlan, + dataset: &Dataset, + deferred_columns: &[String], + qualifier: &Option, + nullable_extra: bool, + ) -> DFResult { + let input_schema = input.schema(); + let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); + + // A field is deferred only when both its name and its qualifier match the + // deferred columns' relation; a same-named field from another relation is + // a distinct column and must stay in place. + let mut qualified_fields: Vec<(Option, Arc)> = input_schema + .iter() + .filter(|(q, f)| { + !(*q == qualifier.as_ref() && deferred_set.contains(f.name().as_str())) + }) + .map(|(q, f)| (q.cloned(), f.clone())) + .collect(); + + let dataset_arrow = ArrowSchema::from(dataset.schema()); + for name in deferred_columns { + let field = dataset_arrow.field_with_name(name).map_err(|e| { + datafusion::error::DataFusionError::Plan(format!( + "late-materialization: deferred column '{name}' not found in dataset schema: {e}" + )) + })?; + let field = if nullable_extra && !field.is_nullable() { + field.clone().with_nullable(true) + } else { + field.clone() + }; + qualified_fields.push((qualifier.clone(), Arc::new(field))); + } + + Ok(Arc::new(DFSchema::new_with_metadata( + qualified_fields, + input_schema.metadata().clone(), + )?)) + } + + /// Index of the row-address (or, failing that, row-id) column in the child. + fn row_locator_index(&self) -> Option { + let input_schema = self.input.schema(); + input_schema + .index_of_column_by_name(self.qualifier.as_ref(), ROW_ADDR) + .or_else(|| input_schema.index_of_column_by_name(self.qualifier.as_ref(), ROW_ID)) + } +} + +impl UserDefinedLogicalNodeCore for LateTakeNode { + fn name(&self) -> &str { + "LateTake" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "LateTake: deferred=[{}], nullable_extra={}", + self.deferred_columns.join(", "), + self.nullable_extra + ) + } + + fn with_exprs_and_inputs( + &self, + exprs: Vec, + mut inputs: Vec, + ) -> DFResult { + if !exprs.is_empty() { + return Err(datafusion::error::DataFusionError::Internal( + "LateTakeNode does not accept expressions".to_string(), + )); + } + if inputs.len() != 1 { + return Err(datafusion::error::DataFusionError::Internal( + "LateTakeNode requires exactly one input".to_string(), + )); + } + Self::try_new( + inputs.remove(0), + self.dataset.clone(), + self.deferred_columns.clone(), + self.qualifier.clone(), + self.nullable_extra, + ) + } + + /// Drive projection pushdown: the deferred columns are produced by this + /// node (fetched by address), so they are never requested from the child; + /// `_rowaddr` is always required so the fetch remains possible. + fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option>> { + let input_schema = self.input.schema(); + let deferred_set: HashSet<&str> = + self.deferred_columns.iter().map(|s| s.as_str()).collect(); + + // Output positions [0..passthrough_len) map back to these child indices, + // in order; positions beyond are the appended (fetched) deferred columns. + // The deferred match is qualified: a same-named field from another + // relation is a distinct passthrough column, not a deferred one. + let passthrough: Vec = input_schema + .iter() + .enumerate() + .filter(|(_, (q, f))| { + !(*q == self.qualifier.as_ref() && deferred_set.contains(f.name().as_str())) + }) + .map(|(i, _)| i) + .collect(); + + let row_locator = self.row_locator_index()?; + + let mut needed = BTreeSet::new(); + for &oc in output_columns { + if let Some(child_idx) = passthrough.get(oc) { + needed.insert(*child_idx); + } + } + needed.insert(row_locator); + Some(vec![needed.into_iter().collect()]) + } +} + +/// Logical optimizer rule that inserts a [`LateTakeNode`] above a join when a +/// wide column from a Lance table relation is only carried through the join. +/// +/// Detection is qualifier-driven, not side-specific: it inspects both join +/// inputs for a [`LanceTableProvider`] scan that emits `_rowaddr`, so it keeps +/// working if build/probe sides are swapped. The actual scan narrowing is left +/// to `OptimizeProjections`, which must run after this rule. +#[derive(Debug, Default)] +pub struct LateMaterializeJoin; + +impl LateMaterializeJoin { + pub fn new() -> Self { + Self + } + + /// Recover the Lance dataset backing a join input, descending only through + /// single-input nodes (alias/projection/filter) so a nested join's scan is + /// never picked up by mistake. + fn find_lance_dataset(plan: &LogicalPlan) -> Option> { + if let LogicalPlan::TableScan(scan) = plan { + let source = scan.source.as_any().downcast_ref::()?; + let provider = source + .table_provider + .as_any() + .downcast_ref::()?; + return Some(provider.dataset()); + } + let inputs = plan.inputs(); + if inputs.len() == 1 { + Self::find_lance_dataset(inputs[0]) + } else { + None + } + } + + /// Names of the join's equi-keys on `side` (the columns that must stay in + /// the scan because the join reads them). + fn join_key_names(join: &Join, side: JoinSide) -> HashSet { + join.on + .iter() + .filter_map(|(left, right)| { + let expr = match side { + JoinSide::Left => left, + JoinSide::Right => right, + }; + match expr { + Expr::Column(col) => Some(col.name.clone()), + _ => None, + } + }) + .collect() + } + + /// Collect every `(qualifier, name)` column reference that appears in an + /// expression anywhere in `plan`. Used to tell which scan-side columns are + /// actually consumed *above* a join (and so worth re-fetching) rather than + /// merely produced by the scan. + fn collect_referenced_columns( + plan: &LogicalPlan, + ) -> DFResult, String)>> { + let mut referenced = HashSet::new(); + plan.apply(|node| { + // A join's own on-clause / filter columns are consumed by the join + // itself, not by an operator above it, so they must not count as + // "used above the join". (Equi-keys are the common case: both sides' + // keys share a name, which would otherwise look like a duplicate + // column flowing into the take.) Children are still visited. + if matches!(node, LogicalPlan::Join(_)) { + return Ok(TreeNodeRecursion::Continue); + } + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::Column(col) = e { + referenced.insert((col.relation.clone(), col.name.clone())); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(referenced) + } + + fn try_defer_join( + join: &Join, + referenced: &HashSet<(Option, String)>, + ) -> DFResult> { + // Only column-preserving joins; a join filter may reference a column we + // would defer, so bail rather than reason about it. + if !matches!( + join.join_type, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) || join.filter.is_some() + { + return Ok(None); + } + + for side in [JoinSide::Left, JoinSide::Right] { + let side_plan: &LogicalPlan = match side { + JoinSide::Left => &join.left, + JoinSide::Right => &join.right, + }; + let Some(dataset) = Self::find_lance_dataset(side_plan) else { + continue; + }; + let side_schema = side_plan.schema(); + + // The scan side must emit `_rowaddr` so deferred columns stay + // fetchable by address after the join. + if side_schema + .index_of_column_by_name(None, ROW_ADDR) + .is_none() + { + continue; + } + let qualifier = side_schema + .iter() + .find(|(_, f)| f.name() == ROW_ADDR) + .and_then(|(q, _)| q.cloned()); + + let key_names = Self::join_key_names(join, side); + let is_cloud = dataset.object_store.is_cloud(); + let dataset_arrow = ArrowSchema::from(dataset.schema()); + + // Candidates = scan-side data columns that aren't join keys nor the + // system columns, that are consumed above the join (a column nobody + // references would just be pruned, so re-fetching it is wasted work), + // and are wide enough to be worth re-fetching by address. + let candidates: HashSet<&str> = side_schema + .iter() + .filter(|(col_qualifier, field)| { + let name = field.name(); + name != ROW_ADDR + && name != ROW_ID + && !key_names.contains(name) + && referenced.contains(&(col_qualifier.cloned(), name.clone())) + && dataset + .schema() + .field(name) + .is_some_and(|f| is_wide_column(f, is_cloud)) + }) + .map(|(_, field)| field.name().as_str()) + .collect(); + if candidates.is_empty() { + continue; + } + // Order by the dataset schema to match TakeExec's append order. + let deferred_columns: Vec = dataset_arrow + .fields() + .iter() + .map(|f| f.name()) + .filter(|name| candidates.contains(name.as_str())) + .cloned() + .collect(); + + // Insert a normalizing projection between the join and the take, + // keeping only the columns carried past the join (referenced above it + // plus the row locator) and dropping the deferred ones. This gives the + // join a tight output projection — without it the opaque take node + // forces `HashJoinExec` to emit its full `left ++ right`, whose + // equi-keys become duplicate arrow names the take (merging by name) + // can't handle once qualifiers are erased — and it is where scan + // narrowing happens, since the absent columns are pushed out of the + // scan. `referenced` excludes the join's own on-clause, so a redundant + // equi-key is pruned here; if two *referenced* columns still share a + // name the take can't disambiguate them, so bail. + let deferred_set: HashSet<&str> = deferred_columns.iter().map(|s| s.as_str()).collect(); + let mut kept_exprs: Vec = Vec::new(); + // Seed with the deferred names: they are appended to the take output, + // so a kept column sharing one of those names also conflicts. + let mut kept_names: HashSet = deferred_columns.iter().cloned().collect(); + let mut name_conflict = false; + for (col_qualifier, field) in join.schema.iter() { + let name = field.name(); + // Drop the deferred columns: they are re-fetched and appended by + // the take, not carried through its input. + if col_qualifier == qualifier.as_ref() && deferred_set.contains(name.as_str()) { + continue; + } + let kept = name == ROW_ADDR + || name == ROW_ID + || referenced.contains(&(col_qualifier.cloned(), name.clone())); + if !kept { + continue; + } + if !kept_names.insert(name.clone()) { + name_conflict = true; + break; + } + kept_exprs.push(Expr::Column(Column::new(col_qualifier.cloned(), name))); + } + if name_conflict { + continue; + } + + let take_input = LogicalPlan::Projection(Projection::try_new( + kept_exprs, + Arc::new(LogicalPlan::Join(join.clone())), + )?); + + // The scan side is null-extended when it is the optional side of an + // outer join; its unmatched rows then have a null `_rowaddr` and + // must yield NULL deferred values. + let nullable_extra = matches!( + (side, join.join_type), + (JoinSide::Left, JoinType::Right) + | (JoinSide::Left, JoinType::Full) + | (JoinSide::Right, JoinType::Left) + | (JoinSide::Right, JoinType::Full) + ); + + let node = LateTakeNode::try_new( + take_input, + dataset, + deferred_columns, + qualifier, + nullable_extra, + )?; + return Ok(Some(LogicalPlan::Extension(Extension { + node: Arc::new(node), + }))); + } + + Ok(None) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum JoinSide { + Left, + Right, +} + +impl OptimizerRule for LateMaterializeJoin { + fn name(&self) -> &str { + "late_materialize_join" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> DFResult> { + let referenced = Self::collect_referenced_columns(&plan)?; + plan.transform_down(|node| { + // Never descend into an already-deferred subtree, otherwise the + // inner join would be wrapped again on this or a later pass. + if let LogicalPlan::Extension(ext) = &node + && ext.node.as_any().is::() + { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); + } + if let LogicalPlan::Join(join) = &node + && let Some(wrapped) = Self::try_defer_join(join, &referenced)? + { + // Jump: the freshly wrapped join must not be revisited. + return Ok(Transformed::new(wrapped, true, TreeNodeRecursion::Jump)); + } + Ok(Transformed::no(node)) + }) + } +} + +/// Lowers a [`LateTakeNode`] to the physical [`TakeExec`]. +#[derive(Debug)] +pub struct LateTakePlanner; + +#[async_trait] +impl ExtensionPlanner for LateTakePlanner { + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + _logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc], + _session_state: &SessionState, + ) -> DFResult>> { + let Some(take_node) = node.as_any().downcast_ref::() else { + return Ok(None); + }; + assert_eq!(physical_inputs.len(), 1, "LateTake requires one input"); + let input = physical_inputs[0].clone(); + + let mut projection = take_node.dataset.empty_projection(); + for name in &take_node.deferred_columns { + projection = projection.union_column(name, OnMissing::Error)?; + } + + let take = if take_node.nullable_extra { + TakeExec::try_new_nullable_extra(take_node.dataset.clone(), input.clone(), projection)? + } else { + TakeExec::try_new(take_node.dataset.clone(), input.clone(), projection)? + }; + + // `try_new` returns None when no extra columns are needed; fall back to + // the raw input so the plan still lowers. + Ok(Some(match take { + Some(take) => Arc::new(take) as Arc, + None => input, + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_array::{Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; + use datafusion::execution::{SessionStateBuilder, TaskContext}; + use datafusion::logical_expr::TableScan; + use datafusion::optimizer::{ + Optimizer, OptimizerContext, optimize_projections::OptimizeProjections, + }; + use datafusion::physical_plan::{collect, displayable}; + use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; + use datafusion::prelude::*; + use lance_core::utils::tempfile::TempStrDir; + + use crate::datafusion::dataframe::SessionContextExt; + use crate::dataset::WriteParams; + + /// Write a `{id: Int32 (key), payload: Utf8 (wide), tag: Int32 (narrow)}` + /// dataset to a temp dir and open it. + async fn test_dataset() -> (Arc, TempStrDir) { + let schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + Field::new("tag", DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(StringArray::from(vec![ + "alpha", "bravo", "charlie", "delta", "echo", + ])), + Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50])), + ], + ) + .unwrap(); + + let tmp = TempStrDir::default(); + let reader = RecordBatchIterator::new(vec![Ok(batch)].into_iter(), schema.clone()); + Dataset::write(reader, tmp.as_str(), Some(WriteParams::default())) + .await + .unwrap(); + let dataset = Arc::new(Dataset::open(tmp.as_str()).await.unwrap()); + (dataset, tmp) + } + + /// A source DataFrame `{sid}` aliased "source". A distinct key name (vs the + /// target's `id`) keeps the join output free of duplicate names so the take + /// can sit above it. + fn source_df(ctx: &SessionContext, ids: Vec) -> DataFrame { + let schema: SchemaRef = Arc::new(ArrowSchema::new(vec![Field::new( + "sid", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap(); + ctx.read_batch(batch).unwrap().alias("source").unwrap() + } + + fn find_table_scan(plan: &LogicalPlan) -> Option<&TableScan> { + if let LogicalPlan::TableScan(scan) = plan { + return Some(scan); + } + plan.inputs().into_iter().find_map(find_table_scan) + } + + fn has_late_take(plan: &LogicalPlan) -> bool { + if let LogicalPlan::Extension(ext) = plan + && ext.node.as_any().is::() + { + return true; + } + plan.inputs().iter().any(|p| has_late_take(p)) + } + + fn scan_column_names(plan: &LogicalPlan) -> Vec { + find_table_scan(plan) + .unwrap() + .projected_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + } + + fn run_rule_and_pushdown(plan: LogicalPlan) -> LogicalPlan { + let optimizer = Optimizer::with_rules(vec![ + Arc::new(LateMaterializeJoin::new()), + Arc::new(OptimizeProjections::new()), + ]); + optimizer + .optimize(plan, &OptimizerContext::new(), |_, _| {}) + .unwrap() + } + + /// Build `Projection(select cols) <- Join(target, source) <- scan`. + async fn join_plan( + ctx: &SessionContext, + dataset: Arc, + join_type: JoinType, + select: &[&str], + source_ids: Vec, + ) -> LogicalPlan { + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source = source_df(ctx, source_ids); + let exprs = select.iter().map(|c| col(*c)).collect::>(); + target + .join(source, join_type, &["id"], &["sid"], None) + .unwrap() + .select(exprs) + .unwrap() + .into_unoptimized_plan() + } + + #[tokio::test] + async fn test_wide_column_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let plan = join_plan( + &ctx, + dataset, + JoinType::Inner, + &["target.id", "target.payload", "target.tag"], + vec![1, 3], + ) + .await; + + let before = plan.schema().clone(); + let optimized = run_rule_and_pushdown(plan); + + assert!( + has_late_take(&optimized), + "expected a LateTake node:\n{}", + optimized.display_indent() + ); + + let scan_cols = scan_column_names(&optimized); + assert!(scan_cols.contains(&"id".to_string()), "scan: {scan_cols:?}"); + assert!( + scan_cols.contains(&ROW_ADDR.to_string()), + "scan: {scan_cols:?}" + ); + // wide column dropped from the scan; narrow `tag` (used above the join) + // is not deferred and stays in the scan. + assert!( + !scan_cols.contains(&"payload".to_string()), + "payload should be deferred, scan: {scan_cols:?}" + ); + assert!( + scan_cols.contains(&"tag".to_string()), + "scan: {scan_cols:?}" + ); + + // The plan's output schema is unchanged by the rewrite. + assert_eq!(before.fields(), optimized.schema().fields()); + } + + #[tokio::test] + async fn test_no_wide_columns_no_take() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Only narrow columns used above the join → nothing to defer. + let plan = join_plan( + &ctx, + dataset, + JoinType::Inner, + &["target.id", "target.tag"], + vec![1, 3], + ) + .await; + let optimized = run_rule_and_pushdown(plan); + + assert!(!has_late_take(&optimized), "no take expected"); + let scan_cols = scan_column_names(&optimized); + assert!( + scan_cols.contains(&"tag".to_string()), + "scan: {scan_cols:?}" + ); + } + + #[tokio::test] + async fn test_join_key_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Join on the wide `payload` column: as a key it must stay in the scan, + // and there is no other wide column to defer. + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![Field::new( + "spayload", + DataType::Utf8, + true, + )])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![Arc::new(StringArray::from(vec!["alpha", "charlie"]))], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let plan = target + .join(source, JoinType::Inner, &["payload"], &["spayload"], None) + .unwrap() + .select(vec![col("target.id"), col("target.payload")]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!(!has_late_take(&optimized), "join key must not be deferred"); + let scan_cols = scan_column_names(&optimized); + assert!( + scan_cols.contains(&"payload".to_string()), + "join key stays in scan: {scan_cols:?}" + ); + } + + #[tokio::test] + async fn test_duplicate_join_output_names_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // Both sides expose `payload` and both are consumed above the join, so + // the take (which merges appended columns by name) would produce a + // duplicate `payload` and must not be inserted. + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let plan = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + // Reference both sides' `payload` above the join: deferring + // `target.payload` would collide with the kept `source.payload`. + .select(vec![ + col("target.id"), + col("target.payload"), + col("source.payload"), + ]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!( + !has_late_take(&optimized), + "duplicate output names must not be deferred" + ); + } + + /// The deferred-column match is qualified: a same-named column from another + /// relation must not be dropped just because it shares a deferred name. + #[tokio::test] + async fn test_deferred_match_respects_qualifier() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let target = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap(); + // The source relation also exposes a `payload` column. + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("payload", DataType::Utf8, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + let input = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + .into_unoptimized_plan(); + let input_field_count = input.schema().fields().len(); + + // Defer only `target.payload`; `source.payload` is a distinct column. + let node = LateTakeNode::try_new( + input, + dataset, + vec!["payload".to_string()], + Some(TableReference::bare("target")), + false, + ) + .unwrap(); + + let schema = UserDefinedLogicalNodeCore::schema(&node); + let target_ref = TableReference::bare("target"); + let source_ref = TableReference::bare("source"); + // `source.payload` survives (only the qualified `target.payload` is + // dropped from the passthrough), and `target.payload` is re-appended — + // so the output keeps the same field count as the input. + assert!( + schema + .index_of_column_by_name(Some(&source_ref), "payload") + .is_some(), + "source.payload must not be dropped by qualifier-blind matching" + ); + assert!( + schema + .index_of_column_by_name(Some(&target_ref), "payload") + .is_some(), + "target.payload should be appended by the take" + ); + assert_eq!(schema.fields().len(), input_field_count); + } + + #[test] + fn test_is_wide_column() { + use lance_core::datatypes::Field as LanceField; + use std::collections::HashMap; + + let lance_field = |arrow: Field| LanceField::try_from(arrow).unwrap(); + + // Variable-width: wide regardless of storage. + let utf8 = lance_field(Field::new("s", DataType::Utf8, true)); + assert!(is_wide_column(&utf8, false)); + assert!(is_wide_column(&utf8, true)); + + // Small fixed-width (4 bytes): narrow on both local and cloud. + let small = lance_field(Field::new("n", DataType::Int32, true)); + assert!(!is_wide_column(&small, false)); + assert!(!is_wide_column(&small, true)); + + // Mid fixed-width (64 bytes): above the 10-byte local threshold but + // below the 1KB cloud threshold — wide locally, narrow on cloud. + let mid = lance_field(Field::new("m", DataType::FixedSizeBinary(64), true)); + assert!(is_wide_column(&mid, false)); + assert!(!is_wide_column(&mid, true)); + + // Large fixed-width (2KB): wide on both. + let large = lance_field(Field::new("l", DataType::FixedSizeBinary(2048), true)); + assert!(is_wide_column(&large, false)); + assert!(is_wide_column(&large, true)); + + // Blob columns are never deferred even though LargeBinary is + // variable-width — they have their own access path. + let blob_meta = + HashMap::from([(lance_arrow::BLOB_META_KEY.to_string(), "true".to_string())]); + let blob = + lance_field(Field::new("b", DataType::LargeBinary, true).with_metadata(blob_meta)); + assert!(blob.is_blob()); + assert!(!is_wide_column(&blob, false)); + assert!(!is_wide_column(&blob, true)); + } + + /// Two *non-deferred* columns from different relations that share a name + /// (here `tag`, on both sides) collide in the take's input. This is the + /// kept-vs-kept conflict path — distinct from the duplicate-name test above, + /// which collides a kept column with a *deferred* name. + #[tokio::test] + async fn test_kept_column_name_conflict_not_deferred() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + // Source carries a narrow `tag` (never a defer candidate itself). + let source_schema: SchemaRef = Arc::new(ArrowSchema::new(vec![ + Field::new("sid", DataType::Int32, false), + Field::new("tag", DataType::Int32, true), + ])); + let source_batch = RecordBatch::try_new( + source_schema, + vec![ + Arc::new(Int32Array::from(vec![1, 3])), + Arc::new(Int32Array::from(vec![100, 300])), + ], + ) + .unwrap(); + let source = ctx + .read_batch(source_batch) + .unwrap() + .alias("source") + .unwrap(); + // `payload` is wide (deferrable), but both `target.tag` and `source.tag` + // are carried past the join → their shared name would collide in the + // take's input, so the rule must bail. + let plan = target + .join(source, JoinType::Inner, &["id"], &["sid"], None) + .unwrap() + .select(vec![ + col("target.id"), + col("target.payload"), + col("target.tag"), + col("source.tag"), + ]) + .unwrap() + .into_unoptimized_plan(); + + let optimized = run_rule_and_pushdown(plan); + assert!( + !has_late_take(&optimized), + "a kept-column name collision must not be deferred" + ); + } + + /// `find_lance_dataset` must descend only single-input nodes, so a nested + /// join's scan is never mistaken for a scannable relation of an outer join. + #[tokio::test] + async fn test_find_lance_dataset_descends_single_input_only() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + + // A single-input chain (SubqueryAlias -> TableScan) resolves to the dataset. + let scan_plan = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap() + .into_unoptimized_plan(); + assert!(LateMaterializeJoin::find_lance_dataset(&scan_plan).is_some()); + + // A join is multi-input: neither side's scan is surfaced, so the rule + // cannot pick up a nested join's relation by mistake. + let target = ctx + .read_lance_unordered(dataset, false, true) + .unwrap() + .alias("target") + .unwrap(); + let join_plan = target + .join( + source_df(&ctx, vec![1, 3]), + JoinType::Inner, + &["id"], + &["sid"], + None, + ) + .unwrap() + .into_unoptimized_plan(); + assert!(LateMaterializeJoin::find_lance_dataset(&join_plan).is_none()); + } + + #[tokio::test] + async fn test_outer_join_marks_deferred_nullable() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + // RIGHT join: target (left input) is the null-extended side. + let plan = join_plan( + &ctx, + dataset, + JoinType::Right, + &["target.id", "target.payload"], + vec![1, 3], + ) + .await; + let optimized = run_rule_and_pushdown(plan); + assert!(has_late_take(&optimized)); + + // Locate the node and confirm the deferred field is nullable. + fn find_node(plan: &LogicalPlan) -> Option<&LateTakeNode> { + if let LogicalPlan::Extension(ext) = plan + && let Some(n) = ext.node.as_any().downcast_ref::() + { + return Some(n); + } + plan.inputs().into_iter().find_map(find_node) + } + let node = find_node(&optimized).unwrap(); + assert!(node.nullable_extra); + let payload = UserDefinedLogicalNodeCore::schema(node) + .field_with_unqualified_name("payload") + .unwrap(); + assert!(payload.is_nullable()); + } + + #[tokio::test] + async fn test_necessary_children_exprs() { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let input = ctx + .read_lance_unordered(dataset.clone(), false, true) + .unwrap() + .alias("target") + .unwrap() + .into_unoptimized_plan(); + // child schema order: id(0), payload(1), tag(2), _rowaddr(3) + let rowaddr_idx = input + .schema() + .index_of_column_by_name(None, ROW_ADDR) + .unwrap(); + let node = LateTakeNode::try_new( + input, + dataset, + vec!["payload".to_string()], + Some(TableReference::bare("target")), + false, + ) + .unwrap(); + + // Output: id(0), tag(1), _rowaddr(2), payload(3, appended/fetched). + // All outputs requested → child needs id, tag, _rowaddr (not payload). + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[0, 1, 2, 3]), + Some(vec![vec![0, 2, rowaddr_idx]]) + ); + // Only the deferred column requested → child still only needs _rowaddr. + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[3]), + Some(vec![vec![rowaddr_idx]]) + ); + // Nothing requested → _rowaddr is still forced in. + assert_eq!( + UserDefinedLogicalNodeCore::necessary_children_exprs(&node, &[]), + Some(vec![vec![rowaddr_idx]]) + ); + } + + /// Lower the rewritten plan to physical and confirm it (a) inserts a + /// `TakeExec` and (b) produces exactly the same rows as the un-deferred plan, + /// including NULL deferred values for source-only rows of an outer join. + async fn assert_execution_parity(join_type: JoinType, source_ids: Vec) { + let (dataset, _tmp) = test_dataset().await; + let ctx = SessionContext::new(); + let state = SessionStateBuilder::new().with_default_features().build(); + + let plan = join_plan( + &ctx, + dataset, + join_type, + &["target.id", "target.payload"], + source_ids, + ) + .await; + + // Baseline: standard optimization + default planner. + let baseline_logical = state.optimize(&plan).unwrap(); + let baseline = state.create_physical_plan(&baseline_logical).await.unwrap(); + let baseline_rows = collect(baseline, Arc::new(TaskContext::default())) + .await + .unwrap(); + + // Deferred: insert the take, prune, then lower with our planner. + let optimized = run_rule_and_pushdown(plan); + let planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(LateTakePlanner)]); + let physical = planner + .create_physical_plan(&optimized, &state) + .await + .unwrap(); + let rendered = displayable(physical.as_ref()).indent(true).to_string(); + assert!( + rendered.contains("Take"), + "expected a TakeExec in the plan:\n{rendered}" + ); + let deferred_rows = collect(physical, Arc::new(TaskContext::default())) + .await + .unwrap(); + + // Compare row sets (order-independent). + let sort = |batches: &[RecordBatch]| { + let mut rows: Vec<(Option, Option)> = Vec::new(); + for b in batches { + let ids = b.column(0).as_any().downcast_ref::().unwrap(); + let payloads = b.column(1).as_any().downcast_ref::().unwrap(); + for i in 0..b.num_rows() { + rows.push(( + (!ids.is_null(i)).then(|| ids.value(i)), + (!payloads.is_null(i)).then(|| payloads.value(i).to_string()), + )); + } + } + rows.sort(); + rows + }; + assert_eq!(sort(&baseline_rows), sort(&deferred_rows)); + } + + #[tokio::test] + async fn test_execution_parity_inner() { + assert_execution_parity(JoinType::Inner, vec![1, 3]).await; + } + + #[tokio::test] + async fn test_execution_parity_outer_null_rowaddr() { + // RIGHT join with a source-only id (99) that has no target match: that + // row's `target._rowaddr` is null, so the take must scatter a NULL + // payload for it — matching the un-deferred plan. + assert_execution_parity(JoinType::Right, vec![1, 99]).await; + } +} diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index c3642cdb043..9c947be9e08 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -219,11 +219,33 @@ impl TakeStream { let row_addrs = row_addrs_arr.as_primitive::(); - debug_assert!( - row_addrs.null_count() == 0, - "{} nulls in row addresses", - row_addrs.null_count() - ); + // Rows with a null address have no row in the dataset to fetch. This + // happens for insert (source-only) rows in an outer-join take, where + // the target side contributes no `_rowaddr`. We fetch only the + // non-null addresses and, after reading, scatter the fetched columns + // back into the original row positions, leaving the taken columns NULL + // for the null-address rows. + let null_count = row_addrs.null_count(); + let (row_addrs_arr, scatter_indices): (Arc, Option) = + if null_count > 0 { + let mut compacted = Vec::with_capacity(row_addrs.len() - null_count); + let mut scatter = Vec::with_capacity(row_addrs.len()); + for addr in row_addrs.iter() { + if let Some(addr) = addr { + scatter.push(Some(compacted.len() as u32)); + compacted.push(addr); + } else { + scatter.push(None); + } + } + ( + Arc::new(UInt64Array::from(compacted)), + Some(UInt32Array::from(scatter)), + ) + } else { + (row_addrs_arr, None) + }; + let row_addrs = row_addrs_arr.as_primitive::(); // Fast path: check if addresses are already sorted with no duplicates (common case). // This avoids all sorting, dedup, and permutation overhead. @@ -319,7 +341,19 @@ impl TakeStream { let batches = futures.try_collect::>().await?; if batches.is_empty() { - return Ok(RecordBatch::new_empty(self.output_schema.clone())); + match scatter_indices { + // Genuinely empty input (no rows at all). + None => return Ok(RecordBatch::new_empty(self.output_schema.clone())), + // Every address was null (e.g. an all-insert take): emit the + // input rows with NULL taken columns. + Some(scatter) => { + let empty = RecordBatch::new_empty(Arc::new(ArrowSchema::from( + self.fields_to_take.as_ref(), + ))); + let new_data = scatter_taken(&empty, &scatter)?; + return Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?); + } + } } let _compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer(); @@ -355,6 +389,14 @@ impl TakeStream { (None, None) => {} } + // Scatter the fetched columns back to the caller's original row count, + // inserting NULLs for rows whose address was null. + let new_data = if let Some(scatter) = scatter_indices { + scatter_taken(&new_data, &scatter)? + } else { + new_data + }; + Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?) } @@ -399,6 +441,32 @@ impl TakeStream { } } +/// Expand a batch of fetched ("taken") columns back to the caller's original +/// row count using `scatter`: a map from each original row index to the +/// position of that row within the fetched set, or null for rows that had no +/// address (and therefore were not fetched). +/// +/// The taken columns are returned as nullable, since null-address rows yield +/// NULL values that a non-nullable field could not hold. +fn scatter_taken(taken: &RecordBatch, scatter: &UInt32Array) -> DataFusionResult { + // Take per-column rather than via `take_record_batch` so we can attach a + // nullable schema before constructing the batch; otherwise `RecordBatch` + // validation would reject the new NULLs in non-nullable columns. + let columns = taken + .columns() + .iter() + .map(|c| arrow::compute::take(c, scatter, None)) + .collect::>>()?; + let fields = taken + .schema() + .fields() + .iter() + .map(|f| Arc::new(f.as_ref().clone().with_nullable(true))) + .collect::>(); + let schema = Arc::new(ArrowSchema::new(fields)); + Ok(RecordBatch::try_new(schema, columns)?) +} + #[derive(Debug)] pub struct TakeExec { // The dataset to take from @@ -415,6 +483,11 @@ pub struct TakeExec { input: Arc, properties: Arc, metrics: ExecutionPlanMetricsSet, + // When true, the taken (extra) fields are marked nullable in the output + // schema even if the dataset declares them non-null. This is used when the + // take sits above an outer join, where rows with a null address (inserts) + // legitimately produce NULL taken values. + nullable_extra_fields: bool, } impl DisplayAs for TakeExec { @@ -462,6 +535,27 @@ impl TakeExec { dataset: Arc, input: Arc, projection: Projection, + ) -> Result> { + Self::try_new_impl(dataset, input, projection, false) + } + + /// Like [`TakeExec::try_new`], but marks the taken (extra) fields nullable + /// in the output schema. Use this when the take sits above an outer join, + /// where rows with a null address (inserts) yield NULL taken values that a + /// non-nullable field could not hold. + pub fn try_new_nullable_extra( + dataset: Arc, + input: Arc, + projection: Projection, + ) -> Result> { + Self::try_new_impl(dataset, input, projection, true) + } + + fn try_new_impl( + dataset: Arc, + input: Arc, + projection: Projection, + nullable_extra_fields: bool, ) -> Result> { let original_projection = projection.clone(); let projection = @@ -492,7 +586,30 @@ impl TakeExec { &input.schema(), &projection, )); - let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref())); + let schema_to_take = projection.into_schema_ref(); + let output_arrow = ArrowSchema::from(output_schema.as_ref()); + let output_arrow = if nullable_extra_fields { + let extra_names = schema_to_take + .fields + .iter() + .map(|f| f.name.as_str()) + .collect::>(); + let fields = output_arrow + .fields() + .iter() + .map(|f| { + if extra_names.contains(f.name().as_str()) && !f.is_nullable() { + Arc::new(f.as_ref().clone().with_nullable(true)) + } else { + f.clone() + } + }) + .collect::>(); + ArrowSchema::new_with_metadata(fields, output_arrow.metadata().clone()) + } else { + output_arrow + }; + let output_arrow = Arc::new(output_arrow); let properties = Arc::new( input .properties() @@ -504,11 +621,12 @@ impl TakeExec { Ok(Some(Self { dataset, output_projection: original_projection, - schema_to_take: projection.into_schema_ref(), + schema_to_take, input, output_schema: output_arrow, properties, metrics: ExecutionPlanMetricsSet::new(), + nullable_extra_fields, })) } @@ -609,7 +727,12 @@ impl ExecutionPlan for TakeExec { let projection = self.output_projection.clone(); - let plan = Self::try_new(self.dataset.clone(), children[0].clone(), projection)?; + let plan = Self::try_new_impl( + self.dataset.clone(), + children[0].clone(), + projection, + self.nullable_extra_fields, + )?; if let Some(plan) = plan { Ok(Arc::new(plan)) @@ -1094,6 +1217,94 @@ mod tests { assert_eq!(s_col.value(1), s_col.value(3)); // both row 0 } + /// A null row address means "no row to fetch" (e.g. an insert row in an + /// outer-join take). The taken column must come back NULL for that row and + /// the row must be preserved (not dropped). + #[tokio::test] + async fn test_take_with_null_row_addrs() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + let row_addrs = UInt64Array::from(vec![Some(0u64), None, Some(2), None, Some(1)]); + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + ROW_ADDR, + DataType::UInt64, + true, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(row_addrs)]).unwrap(); + let stream = futures::stream::iter(vec![Ok(batch)]); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + let input = Arc::new(OneShotExec::new(stream)); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let all = concat_batches(&batches[0].schema(), &batches).unwrap(); + assert_eq!(all.num_rows(), 5); + + let s = all + .column_by_name("s") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(s.value(0), "str-0"); + assert!(s.is_null(1)); + assert_eq!(s.value(2), "str-2"); + assert!(s.is_null(3)); + assert_eq!(s.value(4), "str-1"); + } + + /// When *every* address is null (e.g. an all-insert take), all input rows + /// must be preserved with NULL taken columns. + #[tokio::test] + async fn test_take_all_null_row_addrs() { + let TestFixture { + dataset, + _tmp_dir_guard, + } = test_fixture().await; + + let row_addrs = UInt64Array::from(vec![None, None, None]); + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + ROW_ADDR, + DataType::UInt64, + true, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(row_addrs)]).unwrap(); + let stream = futures::stream::iter(vec![Ok(batch)]); + let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + let input = Arc::new(OneShotExec::new(stream)); + + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let take_exec = TakeExec::try_new(dataset, input, projection) + .unwrap() + .unwrap(); + + let stream = take_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + let all = concat_batches(&batches[0].schema(), &batches).unwrap(); + assert_eq!(all.num_rows(), 3); + + let s = all.column_by_name("s").unwrap(); + assert_eq!(s.null_count(), 3); + } + #[tokio::test] async fn test_take_struct() { // When taking fields into an existing struct, the field order should be maintained