From 4675c43305cc3b4f98f067713271176903a5ff6c Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 19:57:04 +0530 Subject: [PATCH 01/23] Push TopK (Sort with fetch) through outer joins --- datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 + .../src/push_down_topk_through_join.rs | 405 ++++++++++++++++++ .../push_down_topk_through_join.slt | 176 ++++++++ 4 files changed, 584 insertions(+) create mode 100644 datafusion/optimizer/src/push_down_topk_through_join.rs create mode 100644 datafusion/sqllogictest/test_files/push_down_topk_through_join.slt diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e610091824092..e8309a3ceb028 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,6 +65,7 @@ pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; +pub mod push_down_topk_through_join; pub mod replace_distinct_aggregate; pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index bdea6a83072cd..1f9d1de863239 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; +use crate::push_down_topk_through_join::PushDownTopKThroughJoin; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ -296,6 +297,7 @@ impl Optimizer { Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), + Arc::new(PushDownTopKThroughJoin::new()), 
Arc::new(PushDownFilter::new()), Arc::new(SingleDistinctToGroupBy::new()), // The previous optimizations added expressions and projections, diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs new file mode 100644 index 0000000000000..d8f18d9a9ec30 --- /dev/null +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -0,0 +1,405 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! +//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! expressions come from the **preserved** side, this rule inserts a copy +//! of the `Sort(fetch)` on that input to reduce the number of rows +//! entering the join. +//! +//! This is correct because: +//! - A LEFT JOIN preserves every left row (each appears at least once in the +//! output). The final top-N by left-side columns must come from the top-N +//! left rows. +//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side +//! columns. +//! +//! The top-level sort is kept for correctness since a 1-to-many join can +//! produce more than N output rows from N input rows. 
+//! +//! ## Example +//! +//! Before: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Scan: t1 ← scans ALL rows +//! Scan: t2 +//! ``` +//! +//! After: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Sort: t1.b ASC, fetch=3 ← pushed down +//! Scan: t1 +//! Scan: t2 +//! ``` + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use crate::utils::{has_all_column_refs, schema_columns}; +use datafusion_common::Result; +use datafusion_common::tree_node::Transformed; +use datafusion_expr::logical_plan::{JoinType, LogicalPlan, Sort as SortPlan}; + +/// Optimization rule that pushes TopK (Sort with fetch) through +/// LEFT / RIGHT outer joins when all sort expressions come from +/// the preserved side. +/// +/// See module-level documentation for details. +#[derive(Default, Debug)] +pub struct PushDownTopKThroughJoin; + +impl PushDownTopKThroughJoin { + #[expect(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownTopKThroughJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // Match Sort with fetch (TopK) + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch else { + return Ok(Transformed::no(plan)); + }; + + // Check if the child is a Join (look through Projection) + let (has_projection, join) = match sort.input.as_ref() { + LogicalPlan::Join(join) => (false, join), + LogicalPlan::Projection(proj) => match proj.input.as_ref() { + LogicalPlan::Join(join) => (true, join), + _ => return Ok(Transformed::no(plan)), + }, + _ => return Ok(Transformed::no(plan)), + }; + + // Only LEFT or RIGHT, no non-equijoin filter + let preserved_is_left = match join.join_type { + JoinType::Left => true, + JoinType::Right => false, + _ => return Ok(Transformed::no(plan)), + }; + if 
join.filter.is_some() { + return Ok(Transformed::no(plan)); + } + + // Check all sort expression columns come from the preserved side + let preserved_schema = if preserved_is_left { + join.left.schema() + } else { + join.right.schema() + }; + let preserved_cols = schema_columns(preserved_schema); + + let all_from_preserved = sort + .expr + .iter() + .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); + if !all_from_preserved { + return Ok(Transformed::no(plan)); + } + + // Don't push if preserved child is already a Sort (redundant) + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + if matches!(preserved_child.as_ref(), LogicalPlan::Sort(_)) { + return Ok(Transformed::no(plan)); + } + + // Create the new Sort(fetch) on the preserved child + let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })); + + // Reconstruct the join with the new child + let mut new_join = join.clone(); + if preserved_is_left { + new_join.left = new_child_sort; + } else { + new_join.right = new_child_sort; + } + + // Rebuild the tree: join → optional projection → top-level sort + let new_join_plan = LogicalPlan::Join(new_join); + let new_sort_input = if has_projection { + // Reconstruct the Projection with the new join + let LogicalPlan::Projection(proj) = sort.input.as_ref() else { + unreachable!() + }; + let mut new_proj = proj.clone(); + new_proj.input = Arc::new(new_join_plan); + Arc::new(LogicalPlan::Projection(new_proj)) + } else { + Arc::new(new_join_plan) + }; + + Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: new_sort_input, + fetch: sort.fetch, + }))) + } + + fn name(&self) -> &str { + "push_down_topk_through_join" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::OptimizerContext; + use 
crate::assert_optimized_plan_eq_snapshot; + use crate::test::*; + + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownTopKThroughJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + + /// TopK on left-side columns above a LEFT JOIN → pushed to left child. + #[test] + fn topk_pushed_to_left_of_left_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// TopK on right-side columns above a RIGHT JOIN → pushed to right child. + #[test] + fn topk_pushed_to_right_of_right_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Right, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(5))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=5 + Right Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=5 + TableScan: t2 + " + ) + } + + /// TopK pushed through a Projection between Sort and Join. 
+ #[test] + fn topk_pushed_through_projection() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .project(vec![col("t1.a"), col("t1.b"), col("t2.c")])? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Projection: t1.a, t1.b, t2.c + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// INNER JOIN → no pushdown. + #[test] + fn topk_not_pushed_for_inner_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Inner, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// LEFT JOIN but sort on right-side columns → no pushdown. + #[test] + fn topk_not_pushed_for_wrong_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Join with a non-equijoin filter → no pushdown (conservative). 
+ #[test] + fn topk_not_pushed_with_join_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + vec![col("t1.a").eq(col("t2.a"))], + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: Filter: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Sort without fetch (unbounded) → no pushdown. + #[test] + fn topk_not_pushed_without_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort(vec![col("t1.b").sort(true, false)])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } +} \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt new file mode 100644 index 0000000000000..ef6858c406b8f --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for pushing TopK (Sort with fetch) through outer joins + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = true; + +# Create test tables +statement ok +CREATE TABLE t1 (a INT, b INT, c VARCHAR) AS VALUES + (1, 10, 'one'), + (2, 20, 'two'), + (3, 30, 'three'), + (4, 40, 'four'), + (5, 50, 'five'); + +statement ok +CREATE TABLE t2 (x INT, y INT, z VARCHAR) AS VALUES + (1, 100, 'alpha'), + (2, 200, 'beta'), + (3, 300, 'gamma'), + (6, 600, 'delta'), + (7, 700, 'epsilon'); + +### +### Positive cases — TopK should be pushed down +### + +# LEFT JOIN: TopK on left-side columns pushed to left child +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness of the above query +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# RIGHT JOIN: TopK on right-side columns pushed to right child +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--Right Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# Verify correctness 
+query III +SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +1 1 100 +2 2 200 +3 3 300 + +### +### Negative cases — TopK should NOT be pushed down +### + +# INNER JOIN: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x +FROM t1 INNER JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Projection: t1.a, t2.x +02)--Sort: t1.b ASC NULLS LAST, fetch=3 +03)----Projection: t1.a, t2.x, t1.b +04)------Inner Join: t1.a = t2.x +05)--------TableScan: t1 projection=[a, b] +06)--------TableScan: t2 projection=[x] + +# LEFT JOIN but sort on right-side columns: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----TableScan: t2 projection=[x, y] + +# FULL OUTER JOIN: no pushdown +query TT +EXPLAIN SELECT t1.a, t2.x +FROM t1 FULL OUTER JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Projection: t1.a, t2.x +02)--Sort: t1.b ASC NULLS LAST, fetch=3 +03)----Projection: t1.a, t2.x, t1.b +04)------Full Join: t1.a = t2.x +05)--------TableScan: t1 projection=[a, b] +06)--------TableScan: t2 projection=[x] + +# LEFT JOIN with non-equijoin filter: no pushdown (conservative) +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Projection: t1.a, t1.b, t2.x +03)----Left Join: t1.a = t2.x Filter: t1.b > t2.y +04)------TableScan: t1 projection=[a, b] +05)------TableScan: t2 projection=[x, y] + +# Sort without LIMIT: no pushdown +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 
projection=[x] + +### +### Config reset +### + +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +reset datafusion.explain.logical_plan_only; + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; \ No newline at end of file From 9aede677a207a38d56eacb942573d61629546313 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 19:58:49 +0530 Subject: [PATCH 02/23] lint fix --- datafusion/optimizer/src/push_down_topk_through_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index d8f18d9a9ec30..24977b215c400 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -402,4 +402,4 @@ mod test { " ) } -} \ No newline at end of file +} From 19b0edc4e4bfe188924e7c14cdc27065202007eb Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 14 Apr 2026 20:48:00 +0530 Subject: [PATCH 03/23] fix build failure --- datafusion/sqllogictest/test_files/explain.slt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 467afe7b6c2ba..3628f6a70ccd1 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -193,6 +193,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -217,6 +218,7 @@ logical_plan after 
propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -565,6 +567,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE @@ -589,6 +592,7 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE From baf25ef47f6339eef88f33c559f0cefcc0367327 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 15:08:48 +0530 Subject: [PATCH 04/23] Handle edge cases --- .../src/push_down_topk_through_join.rs | 348 +++++++++++++++++- .../push_down_topk_through_join.slt | 219 ++++++++++- 2 files changed, 551 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs 
index 24977b215c400..cd42cfd00797b 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -57,9 +57,12 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use crate::utils::{has_all_column_refs, schema_columns}; -use datafusion_common::Result; -use datafusion_common::tree_node::Transformed; -use datafusion_expr::logical_plan::{JoinType, LogicalPlan, Sort as SortPlan}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, Result}; +use datafusion_expr::logical_plan::{ + JoinType, LogicalPlan, Projection, Sort as SortPlan, +}; +use datafusion_expr::{Expr, SortExpr}; /// Optimization rule that pushes TopK (Sort with fetch) through /// LEFT / RIGHT outer joins when all sort expressions come from @@ -104,17 +107,29 @@ impl OptimizerRule for PushDownTopKThroughJoin { _ => return Ok(Transformed::no(plan)), }; - // Only LEFT or RIGHT, no non-equijoin filter + // Only outer/semi/anti joins where the preserved side is known. + // No non-equijoin filter (conservative — filter may change row count). let preserved_is_left = match join.join_type { - JoinType::Left => true, - JoinType::Right => false, + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => true, + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => false, _ => return Ok(Transformed::no(plan)), }; if join.filter.is_some() { return Ok(Transformed::no(plan)); } - // Check all sort expression columns come from the preserved side + // Check all sort expression columns come from the preserved side. + // When there's a projection, resolve sort expressions through it first + // since the sort references post-projection columns. + let resolved_sort_exprs = if has_projection { + let LogicalPlan::Projection(proj) = sort.input.as_ref() else { + unreachable!() + }; + resolve_sort_exprs_through_projection(&sort.expr, proj)? 
+ } else { + sort.expr.clone() + }; + let preserved_schema = if preserved_is_left { join.left.schema() } else { @@ -122,28 +137,65 @@ impl OptimizerRule for PushDownTopKThroughJoin { }; let preserved_cols = schema_columns(preserved_schema); - let all_from_preserved = sort - .expr + let all_from_preserved = resolved_sort_exprs .iter() .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); if !all_from_preserved { return Ok(Transformed::no(plan)); } - // Don't push if preserved child is already a Sort (redundant) + // Push through when the preserved child has no Sort, or has a Sort + // with a larger/no fetch limit (our tighter limit reduces data further). + // + // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Child limits to 10, our tighter fetch=5 reduces data further. + // + // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC) + // Child has no fetch (full sort), adding fetch=5 limits early. + // + // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) + // Child already limits to 3 rows, pushing fetch=5 won't help. let preserved_child = if preserved_is_left { &join.left } else { &join.right }; - if matches!(preserved_child.as_ref(), LogicalPlan::Sort(_)) { + if let LogicalPlan::Sort(child_sort) = preserved_child.as_ref() { + // Compare using resolved expressions since the parent sort may + // reference post-projection column names while the child uses + // pre-projection expressions. + let same_exprs = child_sort.expr == resolved_sort_exprs; + let child_fetch_tighter = match child_sort.fetch { + Some(child_fetch) => child_fetch <= fetch, + None => false, + }; + if same_exprs && child_fetch_tighter { + return Ok(Transformed::no(plan)); + } + } + + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. 
+ if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { return Ok(Transformed::no(plan)); } - // Create the new Sort(fetch) on the preserved child + // Create the new Sort(fetch) on the preserved child. + // Use the resolved expressions (pre-projection) for the pushed Sort. + // + // If the child is already a Sort with the same expressions but a larger + // fetch, tighten its fetch in-place instead of stacking a redundant Sort + // on top. + let (sort_input, sort_exprs) = match preserved_child.as_ref() { + LogicalPlan::Sort(child_sort) if child_sort.expr == resolved_sort_exprs => { + (Arc::clone(&child_sort.input), child_sort.expr.clone()) + } + _ => (Arc::clone(preserved_child), resolved_sort_exprs), + }; let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { - expr: sort.expr.clone(), - input: Arc::clone(preserved_child), + expr: sort_exprs, + input: sort_input, fetch: Some(fetch), })); @@ -185,6 +237,63 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// For example, if sort expr is `b ASC` and projection has `-t1.b AS b`, +/// the resolved sort expr becomes `-t1.b ASC`. 
+/// +/// Before: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// t1 +/// t2 +/// ``` +/// +/// After resolving, the pushed Sort uses pre-projection expressions: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// Sort: -t1.b ASC, fetch=3 ← resolved through projection +/// t1 +/// t2 +/// ``` +fn resolve_sort_exprs_through_projection( + sort_exprs: &[SortExpr], + projection: &Projection, +) -> Result> { + // Build map: output column name → underlying expression + let replace_map: std::collections::HashMap = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect(); + + sort_exprs + .iter() + .map(|sort_expr| { + let new_expr = sort_expr.expr.clone().transform(|expr| { + let replacement = match &expr { + Expr::Column(col) => replace_map.get(&col.flat_name()).cloned(), + _ => None, + }; + Ok(replacement.map_or_else(|| Transformed::no(expr), Transformed::yes)) + })?; + Ok(SortExpr { + expr: new_expr.data, + ..*sort_expr + }) + }) + .collect() +} + #[cfg(test)] mod test { use super::*; @@ -192,7 +301,8 @@ mod test { use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; - use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_expr::col; + use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; macro_rules! assert_optimized_plan_equal { ( @@ -402,4 +512,212 @@ mod test { " ) } + + /// LEFT SEMI JOIN: TopK on left-side columns → pushed to left child. + #[test] + fn topk_pushed_for_left_semi_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::LeftSemi, + (vec!["a"], vec!["a"]), + None, + )? 
+ .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + LeftSemi Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// LEFT ANTI JOIN: TopK on left-side columns → pushed to left child. + #[test] + fn topk_pushed_for_left_anti_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::LeftAnti, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + LeftAnti Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// RIGHT SEMI JOIN: TopK on right-side columns → pushed to right child. + #[test] + fn topk_pushed_for_right_semi_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::RightSemi, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + RightSemi Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + + /// RIGHT ANTI JOIN: TopK on right-side columns → pushed to right child. + #[test] + fn topk_pushed_for_right_anti_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::RightAnti, + (vec!["a"], vec!["a"]), + None, + )? 
+ .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + RightAnti Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + + /// Multi-column sort with columns from both sides → no pushdown. + #[test] + fn topk_not_pushed_for_mixed_side_sort() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit( + vec![col("t1.b").sort(true, false), col("t2.b").sort(true, false)], + Some(3), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, t2.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Preserved child has a larger fetch → push our tighter limit. + #[test] + fn topk_pushed_when_child_has_larger_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Child already has Sort(b ASC, fetch=10); our outer Sort has fetch=3 (tighter). + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(10))? + .build()?; + + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Preserved child already has a tighter fetch → skip pushdown. 
+ #[test] + fn topk_not_pushed_when_child_has_smaller_fetch() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Child already has Sort(b ASC, fetch=2); our outer Sort has fetch=5 (looser). + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(2))? + .build()?; + + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(5))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=5 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=2 + TableScan: t1 + TableScan: t2 + " + ) + } } diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index ef6858c406b8f..b3b8f987aa2e6 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -159,6 +159,223 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +### +### Sort child cases — push vs skip based on existing child Sort +### + +# Child has larger fetch: push our tighter limit +# The inner Sort(fetch=5) has a larger limit than our outer Sort(fetch=2), +# so pushing fetch=2 to the preserved child reduces data further. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=5 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Child has smaller fetch with same sort: skip (already tighter) +# The inner Sort(fetch=2) already has a tighter limit than our outer Sort(fetch=5), +# so pushing fetch=5 would be redundant. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 5; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=5 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 5; +---- +1 10 1 +2 20 2 + +### +### Semi/Anti join cases — pushdown supported +### + +# LEFT SEMI JOIN: push to left child +query TT +EXPLAIN SELECT t1.a, t1.b +FROM t1 LEFT SEMI JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--LeftSemi Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# LEFT ANTI JOIN: push to left child +query TT 
+EXPLAIN SELECT t1.a, t1.b +FROM t1 LEFT ANTI JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--LeftAnti Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# RIGHT SEMI JOIN: push to right child +query TT +EXPLAIN SELECT t2.x, t2.y +FROM t1 RIGHT SEMI JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--RightSemi Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# RIGHT ANTI JOIN: push to right child +query TT +EXPLAIN SELECT t2.x, t2.y +FROM t1 RIGHT ANTI JOIN t2 ON t1.a = t2.x +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--RightAnti Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +### +### Multi-column sort and OFFSET cases +### + +# ORDER BY columns from both sides: no pushdown +# Sort uses t1.b (left) and t2.y (right) — not all from preserved side +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC, t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, t2.y ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x, y] + +# Verify correctness +query IIII +SELECT t1.a, t1.b, t2.x, t2.y +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC, t2.y ASC LIMIT 3; +---- +1 10 1 100 +2 20 2 200 +3 30 3 300 + +# LIMIT with OFFSET: pushdown still applies (Sort fetch = limit + offset = 3) +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 2 OFFSET 1; +---- +logical_plan +01)Limit: skip=1, fetch=2 +02)--Sort: t1.b ASC NULLS LAST, fetch=3 
+03)----Left Join: t1.a = t2.x +04)------Sort: t1.b ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Verify correctness: skip 1, take 2 +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC LIMIT 2 OFFSET 1; +---- +2 20 2 +3 30 3 + +### +### Projection expression resolution cases +### + +# Sort on a projected expression: the pushed Sort should use the +# pre-projection expression, not the aliased column name. +# ORDER BY neg_b (which is -t1.b) should push Sort(-t1.b) below the join. +query TT +EXPLAIN SELECT -t1.b AS neg_b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY neg_b ASC LIMIT 3; +---- +logical_plan +01)Sort: neg_b ASC NULLS LAST, fetch=3 +02)--Projection: (- t1.b) AS neg_b, t2.x +03)----Left Join: t1.a = t2.x +04)------Sort: (- t1.b) ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Verify correctness: -b ascending means largest b first +query II +SELECT -t1.b AS neg_b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY neg_b ASC LIMIT 3; +---- +-50 NULL +-40 NULL +-30 3 + +# Non-deterministic sort expression (random()): no pushdown +# Duplicating random() would produce different values at each evaluation point. 
+query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b + random() ASC LIMIT 3; +---- +logical_plan +01)Sort: CAST(t1.b AS Float64) + random() ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] + ### ### Config reset ### @@ -173,4 +390,4 @@ statement ok DROP TABLE t1; statement ok -DROP TABLE t2; \ No newline at end of file +DROP TABLE t2; From 67f92658b9f21740cf90938a9bd89ff1fc7dd661 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 15:16:04 +0530 Subject: [PATCH 05/23] Handle volatile expr early --- .../optimizer/src/push_down_topk_through_join.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index cd42cfd00797b..22711f5aba54b 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -97,6 +97,13 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); }; + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. + if sort.expr.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + // Check if the child is a Join (look through Projection) let (has_projection, join) = match sort.input.as_ref() { LogicalPlan::Join(join) => (false, join), @@ -174,13 +181,6 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } - // Don't push if any sort expression is non-deterministic (e.g. random()). - // Duplicating such expressions would produce different values at each - // evaluation point, potentially changing the result. 
- if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { - return Ok(Transformed::no(plan)); - } - // Create the new Sort(fetch) on the preserved child. // Use the resolved expressions (pre-projection) for the pushed Sort. // From d12aefa983f4af9d2b62520b06bd0cf87546e09d Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Fri, 17 Apr 2026 17:25:49 +0530 Subject: [PATCH 06/23] Fix build failure --- .../src/push_down_topk_through_join.rs | 37 +++++++++---------- .../push_down_topk_through_join.slt | 28 +++++++------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 22711f5aba54b..fd13f864390c8 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -114,11 +114,14 @@ impl OptimizerRule for PushDownTopKThroughJoin { _ => return Ok(Transformed::no(plan)), }; - // Only outer/semi/anti joins where the preserved side is known. + // Only outer joins where the preserved side is known. + // Semi/Anti joins are excluded: not all preserved-side rows appear in + // the output (only matched/unmatched rows do), so pushing fetch=N to + // the preserved child can drop rows that would have survived the filter. // No non-equijoin filter (conservative — filter may change row count). let preserved_is_left = match join.join_type { - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => true, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => false, + JoinType::Left => true, + JoinType::Right => false, _ => return Ok(Transformed::no(plan)), }; if join.filter.is_some() { @@ -513,9 +516,9 @@ mod test { ) } - /// LEFT SEMI JOIN: TopK on left-side columns → pushed to left child. + /// LEFT SEMI JOIN: pushing fetch is unsafe (not all left rows appear in output). 
#[test] - fn topk_pushed_for_left_semi_join() -> Result<()> { + fn topk_not_pushed_for_left_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -534,16 +537,15 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 LeftSemi Join: t1.a = t2.a - Sort: t1.b ASC NULLS LAST, fetch=3 - TableScan: t1 + TableScan: t1 TableScan: t2 " ) } - /// LEFT ANTI JOIN: TopK on left-side columns → pushed to left child. + /// LEFT ANTI JOIN: pushing fetch is unsafe (not all left rows appear in output). #[test] - fn topk_pushed_for_left_anti_join() -> Result<()> { + fn topk_not_pushed_for_left_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -562,16 +564,15 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 LeftAnti Join: t1.a = t2.a - Sort: t1.b ASC NULLS LAST, fetch=3 - TableScan: t1 + TableScan: t1 TableScan: t2 " ) } - /// RIGHT SEMI JOIN: TopK on right-side columns → pushed to right child. + /// RIGHT SEMI JOIN: pushing fetch is unsafe (not all right rows appear in output). #[test] - fn topk_pushed_for_right_semi_join() -> Result<()> { + fn topk_not_pushed_for_right_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -591,15 +592,14 @@ mod test { Sort: t2.b ASC NULLS LAST, fetch=3 RightSemi Join: t1.a = t2.a TableScan: t1 - Sort: t2.b ASC NULLS LAST, fetch=3 - TableScan: t2 + TableScan: t2 " ) } - /// RIGHT ANTI JOIN: TopK on right-side columns → pushed to right child. + /// RIGHT ANTI JOIN: pushing fetch is unsafe (not all right rows appear in output). 
#[test] - fn topk_pushed_for_right_anti_join() -> Result<()> { + fn topk_not_pushed_for_right_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -619,8 +619,7 @@ mod test { Sort: t2.b ASC NULLS LAST, fetch=3 RightAnti Join: t1.a = t2.a TableScan: t1 - Sort: t2.b ASC NULLS LAST, fetch=3 - TableScan: t2 + TableScan: t2 " ) } diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index b3b8f987aa2e6..1b1aebeec4355 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -228,10 +228,12 @@ ORDER BY b ASC LIMIT 5; 2 20 2 ### -### Semi/Anti join cases — pushdown supported +### Semi/Anti join cases — pushdown NOT supported +### (not all preserved-side rows appear in output, so pushing fetch +### could drop rows that would have survived the semi/anti filter) ### -# LEFT SEMI JOIN: push to left child +# LEFT SEMI JOIN: no pushdown query TT EXPLAIN SELECT t1.a, t1.b FROM t1 LEFT SEMI JOIN t2 ON t1.a = t2.x @@ -240,11 +242,10 @@ ORDER BY t1.b ASC LIMIT 3; logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--LeftSemi Join: t1.a = t2.x -03)----Sort: t1.b ASC NULLS LAST, fetch=3 -04)------TableScan: t1 projection=[a, b] -05)----TableScan: t2 projection=[x] +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] -# LEFT ANTI JOIN: push to left child +# LEFT ANTI JOIN: no pushdown query TT EXPLAIN SELECT t1.a, t1.b FROM t1 LEFT ANTI JOIN t2 ON t1.a = t2.x @@ -253,11 +254,10 @@ ORDER BY t1.b ASC LIMIT 3; logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--LeftAnti Join: t1.a = t2.x -03)----Sort: t1.b ASC NULLS LAST, fetch=3 -04)------TableScan: t1 projection=[a, b] -05)----TableScan: t2 projection=[x] +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] -# RIGHT SEMI JOIN: 
push to right child +# RIGHT SEMI JOIN: no pushdown query TT EXPLAIN SELECT t2.x, t2.y FROM t1 RIGHT SEMI JOIN t2 ON t1.a = t2.x @@ -267,10 +267,9 @@ logical_plan 01)Sort: t2.y ASC NULLS LAST, fetch=3 02)--RightSemi Join: t1.a = t2.x 03)----TableScan: t1 projection=[a] -04)----Sort: t2.y ASC NULLS LAST, fetch=3 -05)------TableScan: t2 projection=[x, y] +04)----TableScan: t2 projection=[x, y] -# RIGHT ANTI JOIN: push to right child +# RIGHT ANTI JOIN: no pushdown query TT EXPLAIN SELECT t2.x, t2.y FROM t1 RIGHT ANTI JOIN t2 ON t1.a = t2.x @@ -280,8 +279,7 @@ logical_plan 01)Sort: t2.y ASC NULLS LAST, fetch=3 02)--RightAnti Join: t1.a = t2.x 03)----TableScan: t1 projection=[a] -04)----Sort: t2.y ASC NULLS LAST, fetch=3 -05)------TableScan: t2 projection=[x, y] +04)----TableScan: t2 projection=[x, y] ### ### Multi-column sort and OFFSET cases From 902ef770174196a2cf99dc454aba3fda68f7438a Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 19:12:57 +0530 Subject: [PATCH 07/23] Handle subquery alias --- .../src/push_down_topk_through_join.rs | 297 +++++++++++++----- .../push_down_topk_through_join.slt | 274 +++++++++++++++- 2 files changed, 486 insertions(+), 85 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index fd13f864390c8..d1c4d9c32e9f6 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -60,7 +60,7 @@ use crate::utils::{has_all_column_refs, schema_columns}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, Result}; use datafusion_expr::logical_plan::{ - JoinType, LogicalPlan, Projection, Sort as SortPlan, + JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, }; use datafusion_expr::{Expr, SortExpr}; @@ -104,41 +104,61 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } - // Check if 
the child is a Join (look through Projection) - let (has_projection, join) = match sort.input.as_ref() { - LogicalPlan::Join(join) => (false, join), - LogicalPlan::Projection(proj) => match proj.input.as_ref() { - LogicalPlan::Join(join) => (true, join), + // Peel through transparent nodes (SubqueryAlias, Projection) to find + // the Join. Track intermediate nodes so we can reconstruct the tree + // and resolve sort expressions through them. + let mut current = sort.input.as_ref(); + let mut intermediates: Vec<&LogicalPlan> = Vec::new(); + let join = loop { + match current { + LogicalPlan::Join(join) => break join, + LogicalPlan::Projection(proj) => { + intermediates.push(current); + current = proj.input.as_ref(); + } + LogicalPlan::SubqueryAlias(sq) => { + intermediates.push(current); + current = sq.input.as_ref(); + } _ => return Ok(Transformed::no(plan)), - }, - _ => return Ok(Transformed::no(plan)), + } }; // Only outer joins where the preserved side is known. // Semi/Anti joins are excluded: not all preserved-side rows appear in // the output (only matched/unmatched rows do), so pushing fetch=N to // the preserved child can drop rows that would have survived the filter. - // No non-equijoin filter (conservative — filter may change row count). + // + // Non-equijoin filters in the ON clause are safe: outer joins guarantee + // all preserved-side rows appear in the output regardless of the filter. + // The filter only controls matching (which non-preserved rows pair up), + // not which preserved rows survive. let preserved_is_left = match join.join_type { JoinType::Left => true, JoinType::Right => false, _ => return Ok(Transformed::no(plan)), }; - if join.filter.is_some() { - return Ok(Transformed::no(plan)); - } - // Check all sort expression columns come from the preserved side. - // When there's a projection, resolve sort expressions through it first - // since the sort references post-projection columns. 
- let resolved_sort_exprs = if has_projection { - let LogicalPlan::Projection(proj) = sort.input.as_ref() else { - unreachable!() - }; - resolve_sort_exprs_through_projection(&sort.expr, proj)? - } else { - sort.expr.clone() - }; + // Resolve sort expressions through all intermediate nodes (Projection, + // SubqueryAlias) so that column references match the join's schema. + let mut resolved_sort_exprs = sort.expr.clone(); + for node in &intermediates { + match node { + LogicalPlan::Projection(proj) => { + resolved_sort_exprs = resolve_sort_exprs_through_projection( + &resolved_sort_exprs, + proj, + )?; + } + LogicalPlan::SubqueryAlias(sq) => { + resolved_sort_exprs = resolve_sort_exprs_through_subquery_alias( + &resolved_sort_exprs, + sq, + )?; + } + _ => unreachable!(), + } + } let preserved_schema = if preserved_is_left { join.left.schema() @@ -154,6 +174,42 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + + // Resolve sort exprs further through any SubqueryAlias wrapping the + // preserved child, so we can compare with the inner Sort's expressions. + // + // After intermediate resolution, resolved_sort_exprs = [t1.b ASC]. + // The inner Sort uses [orders.b ASC]. This step maps t1.b → orders.b. 
+ // + // ```text + // Sort(sub.b ASC, fetch=2) + // SubqueryAlias(sub) ← intermediate, already resolved + // Left Join + // SubqueryAlias(t1) ← preserved child, resolve here + // Sort(orders.b ASC, fetch=5) + // TableScan: orders + // ``` + let (inner_child, child_resolved_exprs) = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => { + let exprs = + resolve_sort_exprs_through_subquery_alias(&resolved_sort_exprs, sq)?; + (sq.input.as_ref(), exprs) + } + _ => (preserved_child.as_ref(), resolved_sort_exprs.clone()), + }; + + // If the inner child is a Limit (PushDownLimit hasn't merged it with + // the Sort yet), skip this iteration. PushDownLimit will merge + // Limit → Sort in the next pass, then our rule will tighten the Sort. + if matches!(inner_child, LogicalPlan::Limit(_)) { + return Ok(Transformed::no(plan)); + } + // Push through when the preserved child has no Sort, or has a Sort // with a larger/no fetch limit (our tighter limit reduces data further). // @@ -165,16 +221,8 @@ impl OptimizerRule for PushDownTopKThroughJoin { // // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) // Child already limits to 3 rows, pushing fetch=5 won't help. - let preserved_child = if preserved_is_left { - &join.left - } else { - &join.right - }; - if let LogicalPlan::Sort(child_sort) = preserved_child.as_ref() { - // Compare using resolved expressions since the parent sort may - // reference post-projection column names while the child uses - // pre-projection expressions. - let same_exprs = child_sort.expr == resolved_sort_exprs; + if let LogicalPlan::Sort(child_sort) = inner_child { + let same_exprs = sort_exprs_equal(&child_sort.expr, &child_resolved_exprs); let child_fetch_tighter = match child_sort.fetch { Some(child_fetch) => child_fetch <= fetch, None => false, @@ -185,44 +233,67 @@ impl OptimizerRule for PushDownTopKThroughJoin { } // Create the new Sort(fetch) on the preserved child. 
- // Use the resolved expressions (pre-projection) for the pushed Sort. + // Use the resolved expressions for the pushed Sort. + // + // If the inner child is already a Sort with the same expressions but a + // larger fetch, tighten its fetch in-place instead of stacking a + // redundant Sort on top. // - // If the child is already a Sort with the same expressions but a larger - // fetch, tighten its fetch in-place instead of stacking a redundant Sort - // on top. - let (sort_input, sort_exprs) = match preserved_child.as_ref() { - LogicalPlan::Sort(child_sort) if child_sort.expr == resolved_sort_exprs => { - (Arc::clone(&child_sort.input), child_sort.expr.clone()) + // When the preserved child is wrapped in SubqueryAlias, the new Sort + // must sit INSIDE the SubqueryAlias (between it and its input), using + // inner-schema column names. + let inner_input: &Arc<LogicalPlan> = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => &sq.input, + _ => preserved_child, + }; + let new_inner_child = match inner_child { + LogicalPlan::Sort(child_sort) + if sort_exprs_equal(&child_sort.expr, &child_resolved_exprs) => + { + Arc::new(LogicalPlan::Sort(SortPlan { + expr: child_sort.expr.clone(), + input: Arc::clone(&child_sort.input), + fetch: Some(fetch), + })) } - _ => (Arc::clone(preserved_child), resolved_sort_exprs), + _ => Arc::new(LogicalPlan::Sort(SortPlan { + expr: child_resolved_exprs, + input: Arc::clone(inner_input), + fetch: Some(fetch), + })), + }; + + // Wrap the new Sort back in SubqueryAlias if the preserved child had one.
+ let new_preserved_child = match preserved_child.as_ref() { + LogicalPlan::SubqueryAlias(sq) => Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_inner_child, sq.alias.clone())?, + )), + _ => new_inner_child, }; - let new_child_sort = Arc::new(LogicalPlan::Sort(SortPlan { - expr: sort_exprs, - input: sort_input, - fetch: Some(fetch), - })); // Reconstruct the join with the new child let mut new_join = join.clone(); if preserved_is_left { - new_join.left = new_child_sort; + new_join.left = new_preserved_child; } else { - new_join.right = new_child_sort; + new_join.right = new_preserved_child; } - // Rebuild the tree: join → optional projection → top-level sort - let new_join_plan = LogicalPlan::Join(new_join); - let new_sort_input = if has_projection { - // Reconstruct the Projection with the new join - let LogicalPlan::Projection(proj) = sort.input.as_ref() else { - unreachable!() - }; - let mut new_proj = proj.clone(); - new_proj.input = Arc::new(new_join_plan); - Arc::new(LogicalPlan::Projection(new_proj)) - } else { - Arc::new(new_join_plan) - }; + // Rebuild the tree: join → intermediate nodes → top-level sort + let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); + for node in intermediates.into_iter().rev() { + new_sort_input = Arc::new(match node { + LogicalPlan::Projection(proj) => { + let mut new_proj = proj.clone(); + new_proj.input = new_sort_input; + LogicalPlan::Projection(new_proj) + } + LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, + ), + _ => unreachable!(), + }); + } Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { expr: sort.expr.clone(), @@ -264,21 +335,21 @@ impl OptimizerRule for PushDownTopKThroughJoin { /// t1 /// t2 /// ``` -fn resolve_sort_exprs_through_projection( +/// Replace column references in sort expressions using a name→expr map. 
+/// Uses `transform()` for deep replacement (handles nested expressions +/// like `-t1.b` where the column is inside a Negative wrapper). +/// +/// Example with `replace_map = {"sub.b" → Column(t1.b)}`: +/// +/// ```text +/// Input: [sub.b ASC] → Output: [t1.b ASC] (simple column) +/// Input: [(- sub.b) ASC] → Output: [(- t1.b) ASC] (nested column) +/// Input: [sub.a ASC, sub.b ASC] → Output: [t1.a ASC, t1.b ASC] (multiple) +/// ``` +fn replace_columns_in_sort_exprs( sort_exprs: &[SortExpr], - projection: &Projection, + replace_map: &std::collections::HashMap<String, Expr>, ) -> Result<Vec<SortExpr>> { - // Build map: output column name → underlying expression - let replace_map: std::collections::HashMap<String, Expr> = projection - .schema - .iter() - .zip(projection.expr.iter()) - .map(|((qualifier, field), expr)| { - let key = Column::from((qualifier, field)).flat_name(); - (key, expr.clone().unalias()) - }) - .collect(); - sort_exprs .iter() .map(|sort_expr| { @@ -297,6 +368,75 @@ fn resolve_sort_exprs_through_projection( .collect() } +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// Example: sort expr is `neg_b ASC` and projection has `-t1.b AS neg_b`: +/// +/// ```text +/// Input sort exprs: [neg_b ASC] +/// Output sort exprs: [(- t1.b) ASC] +/// ``` +fn resolve_sort_exprs_through_projection( + sort_exprs: &[SortExpr], + projection: &Projection, +) -> Result<Vec<SortExpr>> { + let replace_map: std::collections::HashMap<String, Expr> = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + +/// Compare two slices of `SortExpr` by each expression's rendered string (`to_string()`). +/// +/// `Column` derives `PartialEq` which compares the `relation` field +/// (`Option<TableReference>`) structurally.
A `Bare("t1")` and +/// `Full { catalog, schema, table: "t1" }` are NOT equal even though they +/// refer to the same column. After resolving through SubqueryAlias the +/// variant may differ, so we compare the rendered expression strings instead. +fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { + a.len() == b.len() + && a.iter().zip(b.iter()).all(|(left, right)| { + left.asc == right.asc + && left.nulls_first == right.nulls_first + && left.expr.to_string() == right.expr.to_string() + }) +} + +/// Resolve sort expressions through a SubqueryAlias by replacing the alias +/// qualifier with the input schema's qualifier. +/// +/// Example: SubqueryAlias is `sub` wrapping a join whose left input is `t1`: +/// +/// ```text +/// Input sort exprs: [sub.b ASC] +/// Output sort exprs: [t1.b ASC] +/// ``` +fn resolve_sort_exprs_through_subquery_alias( + sort_exprs: &[SortExpr], + subquery_alias: &SubqueryAlias, +) -> Result<Vec<SortExpr>> { + let replace_map: std::collections::HashMap<String, Expr> = subquery_alias + .schema + .iter() + .zip(subquery_alias.input.schema().iter()) + .map(|((alias_qual, alias_field), (input_qual, input_field))| { + let alias_col = Column::from((alias_qual, alias_field)); + let input_col = Column::from((input_qual, input_field)); + (alias_col.flat_name(), Expr::Column(input_col)) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + #[cfg(test)] mod test { use super::*; @@ -463,9 +603,11 @@ mod test { ) } - /// Join with a non-equijoin filter → no pushdown (conservative). + /// Join with a non-equijoin filter → pushdown still happens. + /// Outer joins preserve all rows from the preserved side regardless + /// of the ON filter.
#[test] - fn topk_not_pushed_with_join_filter() -> Result<()> { + fn topk_pushed_with_join_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; @@ -483,7 +625,8 @@ mod test { @r" Sort: t1.b ASC NULLS LAST, fetch=3 Left Join: Filter: t1.a = t2.a - TableScan: t1 + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 TableScan: t2 " ) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 1b1aebeec4355..630fa85472328 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -134,7 +134,9 @@ logical_plan 05)--------TableScan: t1 projection=[a, b] 06)--------TableScan: t2 projection=[x] -# LEFT JOIN with non-equijoin filter: no pushdown (conservative) +# LEFT JOIN with non-equijoin filter on BOTH sides: pushdown OK +# Filter t1.b > t2.y is in the ON clause — it only controls matching, not +# which preserved (left) rows appear. All left rows are preserved. 
query TT EXPLAIN SELECT t1.a, t1.b, t2.x FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y @@ -144,8 +146,69 @@ logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--Projection: t1.a, t1.b, t2.x 03)----Left Join: t1.a = t2.x Filter: t1.b > t2.y +04)------Sort: t1.b ASC NULLS LAST, fetch=3 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x, y] + +# Verify correctness: all left rows appear, filter only affects matching +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > t2.y +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 NULL +3 30 NULL + +# LEFT JOIN with non-equijoin filter on non-preserved side only: pushdown OK +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t2.y > 100 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 04)------TableScan: t1 projection=[a, b] -05)------TableScan: t2 projection=[x, y] +05)----Projection: t2.x +06)------Filter: t2.y > Int32(100) +07)--------TableScan: t2 projection=[x, y] + +# LEFT JOIN with preserved-side-only filter: pushdown OK +# Filter t1.b > 20 prevents matching for left rows with b <= 20, +# but those rows still appear with NULL-filled right columns. 
+query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > 20 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x Filter: t1.b > Int32(20) +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness: rows with b <= 20 get NULL right columns +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t1.b > 20 +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 NULL +3 30 3 + +# Verify correctness: non-preserved side filter +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x AND t2.y > 100 +ORDER BY t1.b ASC LIMIT 3; +---- +1 10 NULL +2 20 2 +3 30 3 # Sort without LIMIT: no pushdown query TT @@ -164,8 +227,7 @@ logical_plan ### # Child has larger fetch: push our tighter limit -# The inner Sort(fetch=5) has a larger limit than our outer Sort(fetch=2), -# so pushing fetch=2 to the preserved child reduces data further. +# The inner Sort(fetch=5) is tightened to fetch=2 in-place. query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x @@ -179,7 +241,7 @@ logical_plan 02)--SubqueryAlias: sub 03)----Left Join: t1.a = t2.x 04)------SubqueryAlias: t1 -05)--------Sort: t1.b ASC NULLS LAST, fetch=5 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 06)----------TableScan: t1 projection=[a, b] 07)------TableScan: t2 projection=[x] @@ -195,9 +257,9 @@ ORDER BY b ASC LIMIT 2; 1 10 1 2 20 2 -# Child has smaller fetch with same sort: skip (already tighter) -# The inner Sort(fetch=2) already has a tighter limit than our outer Sort(fetch=5), -# so pushing fetch=5 would be redundant. +# Child has smaller fetch with same sort: our rule skips (already tighter). +# PushDownLimit inserts a Sort(fetch=5) that gets collapsed with the inner +# Sort(fetch=2) to Sort(fetch=2) by stacked-sort merging. 
query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x @@ -374,6 +436,202 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +### +### SubqueryAlias edge cases +### + +# SubqueryAlias without inner Sort: push new Sort through SubqueryAlias +# Preserved child is SubqueryAlias(t1, TableScan) — no existing Sort to tighten, +# so a new Sort(fetch=2) is inserted inside the SubqueryAlias. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# RIGHT JOIN with SubqueryAlias on preserved (right) side +# Inner Sort(fetch=10) is tightened to fetch=3 via stacked-sort merging. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t2.x, t2.y + FROM t1 + RIGHT JOIN (SELECT * FROM t2 ORDER BY y ASC LIMIT 10) t2 + ON t1.a = t2.x +) sub +ORDER BY y ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.y ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Right Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------SubqueryAlias: t2 +06)--------Sort: t2.y ASC NULLS LAST, fetch=3 +07)----------TableScan: t2 projection=[x, y] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t2.x, t2.y + FROM t1 + RIGHT JOIN (SELECT * FROM t2 ORDER BY y ASC LIMIT 10) t2 + ON t1.a = t2.x +) sub +ORDER BY y ASC LIMIT 3; +---- +1 1 100 +2 2 200 +3 3 300 + +# SubqueryAlias with different alias name (foo ≠ t1) +# Column resolution: foo.b → t1.b through SubqueryAlias renaming. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT foo.a, foo.b, t2.x + FROM (SELECT * FROM t1) foo + LEFT JOIN t2 ON foo.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: foo.a = t2.x +04)------SubqueryAlias: foo +05)--------Sort: t1.b ASC NULLS LAST, fetch=3 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT foo.a, foo.b, t2.x + FROM (SELECT * FROM t1) foo + LEFT JOIN t2 ON foo.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# Sort on non-preserved side column through SubqueryAlias: no pushdown +# ORDER BY t2.x is from the non-preserved (right) side of a LEFT JOIN. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY x ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.x ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# INNER JOIN through SubqueryAlias: no pushdown (only LEFT/RIGHT) +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + INNER JOIN t2 ON t1.a = t2.x +) sub +ORDER BY b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Inner Join: t1.a = t2.x +04)------SubqueryAlias: t1 +05)--------TableScan: t1 projection=[a, b] +06)------TableScan: t2 projection=[x] + +# Multiple sort columns from preserved side through SubqueryAlias +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY a ASC, b ASC LIMIT 3; +---- +logical_plan +01)Sort: sub.a ASC NULLS LAST, sub.b ASC NULLS LAST, fetch=3 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x 
+04)------SubqueryAlias: t1 +05)--------Sort: t1.a ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=3 +06)----------TableScan: t1 projection=[a, b] +07)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.b, t2.x + FROM (SELECT * FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY a ASC, b ASC LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# WHERE filter on preserved side: pushdown still happens +# PushDownFilter pushes the WHERE filter below the Join first, +# then our rule sees Sort → Join (no Filter in between) and pushes TopK. +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +WHERE t1.b > 10 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------Filter: t1.b > Int32(10) +05)--------TableScan: t1 projection=[a, b] +06)----TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +WHERE t1.b > 10 +ORDER BY t1.b ASC LIMIT 3; +---- +2 20 2 +3 30 3 +4 40 NULL + ### ### Config reset ### From 254b224b5b4c5192a645f2f778a87b899984fbd7 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 19:28:54 +0530 Subject: [PATCH 08/23] Update comment --- .../sqllogictest/test_files/push_down_topk_through_join.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 630fa85472328..00a1f74b743c1 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -473,7 +473,7 @@ ORDER BY b ASC LIMIT 2; 2 20 2 # RIGHT JOIN with SubqueryAlias on preserved (right) side -# Inner Sort(fetch=10) is tightened to fetch=3 via stacked-sort merging. 
+# Inner Sort(fetch=10) is tightened to fetch=3 query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t2.x, t2.y From c648e71de92eb60a30dd9dd3d1b996458ba62c9f Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 18 Apr 2026 20:09:40 +0530 Subject: [PATCH 09/23] Doc fix --- .../sqllogictest/test_files/push_down_topk_through_join.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 00a1f74b743c1..153da8a7d3054 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -259,7 +259,7 @@ ORDER BY b ASC LIMIT 2; # Child has smaller fetch with same sort: our rule skips (already tighter). # PushDownLimit inserts a Sort(fetch=5) that gets collapsed with the inner -# Sort(fetch=2) to Sort(fetch=2) by stacked-sort merging. +# Sort(fetch=2) to Sort(fetch=2) query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x From 03f64999e467319789d7e31e77f24e73294bc539 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 10:29:36 +0530 Subject: [PATCH 10/23] Handle volatile expr in projection --- .../src/push_down_topk_through_join.rs | 12 ++++++++++-- .../test_files/push_down_topk_through_join.slt | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index d1c4d9c32e9f6..3286319c1de3e 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -160,6 +160,14 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } + // After resolving through projections, the sort expressions may now + // contain volatile functions (e.g. `random() AS col`). 
Duplicating + // volatile expressions in the pushed Sort would produce different + // values, changing results. + if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + let preserved_schema = if preserved_is_left { join.left.schema() } else { @@ -394,13 +402,13 @@ fn resolve_sort_exprs_through_projection( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Compare two slices of `SortExpr` using `flat_name()` for column identity. +/// Compare two slices of `SortExpr` using `Expr::to_string()` for column identity. /// /// `Column` derives `PartialEq` which compares the `relation` field /// (`Option`) structurally. A `Bare("t1")` and /// `Full { catalog, schema, table: "t1" }` are NOT equal even though they /// refer to the same column. After resolving through SubqueryAlias the -/// variant may differ, so we compare by flat_name() instead. +/// variant may differ, so we compare by display string instead. fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 153da8a7d3054..ee52c59124a20 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -436,6 +436,23 @@ logical_plan 03)----TableScan: t1 projection=[a, b] 04)----TableScan: t2 projection=[x] +# Non-deterministic projected expression (random() AS col): no pushdown +# Sort references a column that resolves to random() through the projection. 
+query TT +EXPLAIN SELECT rand_col, t2.x +FROM ( + SELECT random() AS rand_col, t1.a, t2.x + FROM t1 LEFT JOIN t2 ON t1.a = t2.x +) +ORDER BY rand_col ASC LIMIT 3; +---- +logical_plan +01)Sort: rand_col ASC NULLS LAST, fetch=3 +02)--Projection: random() AS rand_col, t2.x +03)----Left Join: t1.a = t2.x +04)------TableScan: t1 projection=[a] +05)------TableScan: t2 projection=[x] + ### ### SubqueryAlias edge cases ### From 051868a28332926e9bd47da08c34181c28cbc9fa Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 17:02:41 +0530 Subject: [PATCH 11/23] use structural equality --- .../optimizer/src/push_down_topk_through_join.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 3286319c1de3e..a9e4995437950 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -402,19 +402,14 @@ fn resolve_sort_exprs_through_projection( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Compare two slices of `SortExpr` using `Expr::to_string()` for column identity. -/// -/// `Column` derives `PartialEq` which compares the `relation` field -/// (`Option`) structurally. A `Bare("t1")` and -/// `Full { catalog, schema, table: "t1" }` are NOT equal even though they -/// refer to the same column. After resolving through SubqueryAlias the -/// variant may differ, so we compare by display string instead. +/// Compare two slices of `SortExpr` for equality. 
+/// Uses structural equality on the sort expressions fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { left.asc == right.asc && left.nulls_first == right.nulls_first - && left.expr.to_string() == right.expr.to_string() + && left.expr == right.expr }) } From 1cfeb76c07ca316446845e85cef46237ae5b37c4 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Tue, 21 Apr 2026 18:53:40 +0530 Subject: [PATCH 12/23] Adds UT --- .../src/push_down_topk_through_join.rs | 204 +++++++++++++++++- 1 file changed, 203 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index a9e4995437950..4a024731f1899 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -403,7 +403,9 @@ fn resolve_sort_exprs_through_projection( } /// Compare two slices of `SortExpr` for equality. -/// Uses structural equality on the sort expressions +/// +/// Uses structural equality on the sort expressions (direction, nulls_first, +/// and the expression tree). fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { @@ -865,4 +867,204 @@ mod test { " ) } + + // --------------------------------------------------------------- + // Unit tests for resolve_sort_exprs_through_projection + // --------------------------------------------------------------- + + /// Simple passthrough: sort on a column that projection passes through. + /// Projection: [t1.a, t1.b] → sort on t1.b resolves to t1.b + #[test] + fn resolve_through_projection_passthrough() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![col("t1.a"), col("t1.b")])? 
+ .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = vec![col("t1.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + assert!(resolved[0].asc); + Ok(()) + } + + /// Aliased expression: sort on neg_b resolves to (- t1.b) + #[test] + fn resolve_through_projection_alias() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![ + col("t1.a"), + (Expr::Negative(Box::new(col("t1.b")))).alias("neg_b"), + ])? + .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = vec![col("neg_b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); + Ok(()) + } + + /// Multiple columns through projection: sort on (t1.a, t1.b) + #[test] + fn resolve_through_projection_multi_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .project(vec![col("t1.a"), col("t1.b"), col("t1.c")])? + .build()?; + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + + let sort_exprs = + vec![col("t1.a").sort(true, false), col("t1.b").sort(false, true)]; + let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].expr.to_string(), "t1.a"); + assert!(resolved[0].asc); + assert_eq!(resolved[1].expr.to_string(), "t1.b"); + assert!(!resolved[1].asc); + assert!(resolved[1].nulls_first); + Ok(()) + } + + /// Projection + SubqueryAlias stacked: sort resolves through both. 
+ /// neg_b → (- sub.b) through Projection → (- t1.b) through SubqueryAlias + #[test] + fn resolve_through_projection_and_subquery_alias() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1) + .alias("sub")? + .project(vec![ + col("sub.a"), + (Expr::Negative(Box::new(col("sub.b")))).alias("neg_b"), + ])? + .build()?; + + // Peel: Projection then SubqueryAlias + let LogicalPlan::Projection(proj) = &plan else { + panic!("expected Projection"); + }; + let LogicalPlan::SubqueryAlias(sq) = proj.input.as_ref() else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("neg_b").sort(true, false)]; + + // Resolve through Projection: neg_b → (- sub.b) + let after_proj = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; + assert_eq!(after_proj[0].expr.to_string(), "(- sub.b)"); + + // Resolve through SubqueryAlias: (- sub.b) → (- t1.b) + let after_sq = resolve_sort_exprs_through_subquery_alias(&after_proj, sq)?; + assert_eq!(after_sq[0].expr.to_string(), "(- t1.b)"); + assert!(after_sq[0].asc); + assert!(!after_sq[0].nulls_first); + + Ok(()) + } + + // --------------------------------------------------------------- + // Unit tests for resolve_sort_exprs_through_subquery_alias + // --------------------------------------------------------------- + + /// Simple column rename: sub.b → t1.b + #[test] + fn resolve_through_subquery_alias_simple() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("sub.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + assert!(resolved[0].asc); + assert!(!resolved[0].nulls_first); + Ok(()) + } + + /// Multiple sort columns: sub.a ASC, 
sub.b DESC → t1.a ASC, t1.b DESC + #[test] + fn resolve_through_subquery_alias_multi_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![ + col("sub.a").sort(true, false), + col("sub.b").sort(false, true), + ]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].expr.to_string(), "t1.a"); + assert!(resolved[0].asc); + assert_eq!(resolved[1].expr.to_string(), "t1.b"); + assert!(!resolved[1].asc); + assert!(resolved[1].nulls_first); + Ok(()) + } + + /// Alias name differs from table name: foo.b → t1.b + #[test] + fn resolve_through_subquery_alias_different_name() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("foo")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![col("foo.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + Ok(()) + } + + /// Nested expression: (- sub.b) ASC → (- t1.b) ASC + #[test] + fn resolve_through_subquery_alias_nested_expr() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + + let sort_exprs = vec![SortExpr { + expr: Expr::Negative(Box::new(col("sub.b"))), + asc: true, + nulls_first: false, + }]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); + Ok(()) + } } From 
0371004e881f1cbc8addfc461a5be68119df9390 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Wed, 6 May 2026 23:49:34 +0530 Subject: [PATCH 13/23] Fix UT --- datafusion/core/src/optimizer_rule_reference.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md index fcbb200c71624..a5eef4965c77e 100644 --- a/datafusion/core/src/optimizer_rule_reference.md +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -54,13 +54,14 @@ Rule order matters. The default pipeline may change between releases. | 15 | `filter_null_join_keys` | Adds `IS NOT NULL` filters to nullable equijoin keys that can never match. | | 16 | `eliminate_outer_join` | Rewrites outer joins to inner joins when later filters reject the NULL-extended rows. | | 17 | `push_down_limit` | Moves literal limits closer to scans and unions and merges adjacent limits. | -| 18 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | -| 19 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | -| 20 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | -| 21 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | -| 22 | `extract_leaf_expressions` | Pulls cheap leaf expressions closer to data sources so later pruning and filter rules can act earlier. | -| 23 | `push_down_leaf_projections` | Pushes the helper projections created by leaf extraction toward leaf inputs. | -| 24 | `optimize_projections` | Prunes unused columns and removes unnecessary logical projections. | +| 18 | `push_down_topk_through_join` | Pushes Sort with LIMIT through joins when sort columns come from the preserved side. 
| +| 19 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | +| 20 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | +| 21 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | +| 22 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | +| 23 | `extract_leaf_expressions` | Pulls cheap leaf expressions closer to data sources so later pruning and filter rules can act earlier. | +| 24 | `push_down_leaf_projections` | Pushes the helper projections created by leaf extraction toward leaf inputs. | +| 25 | `optimize_projections` | Prunes unused columns and removes unnecessary logical projections. | ### Physical Optimizer Rules From 38bbf876e46e1f24492d958cb1c33762e1b84b0f Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 9 May 2026 13:11:47 +0530 Subject: [PATCH 14/23] Resolve comment --- .../core/src/optimizer_rule_reference.md | 4 +- datafusion/optimizer/src/optimizer.rs | 2 +- .../src/push_down_topk_through_join.rs | 317 ++++++------ .../push_down_topk_through_join.slt | 475 +++++++++++++++++- 4 files changed, 634 insertions(+), 164 deletions(-) diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md index a5eef4965c77e..6902402c4571c 100644 --- a/datafusion/core/src/optimizer_rule_reference.md +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -54,8 +54,8 @@ Rule order matters. The default pipeline may change between releases. | 15 | `filter_null_join_keys` | Adds `IS NOT NULL` filters to nullable equijoin keys that can never match. | | 16 | `eliminate_outer_join` | Rewrites outer joins to inner joins when later filters reject the NULL-extended rows. | | 17 | `push_down_limit` | Moves literal limits closer to scans and unions and merges adjacent limits. 
| -| 18 | `push_down_topk_through_join` | Pushes Sort with LIMIT through joins when sort columns come from the preserved side. | -| 19 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | +| 18 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | +| 19 | `push_down_topk_through_join` | Pushes Sort with LIMIT through joins when sort columns come from the preserved side. | | 20 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | | 21 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | | 22 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index f00583d5a9d46..c759d998e343b 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -297,8 +297,8 @@ impl Optimizer { Arc::new(EliminateOuterJoin::new()), // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), - Arc::new(PushDownTopKThroughJoin::new()), Arc::new(PushDownFilter::new()), + Arc::new(PushDownTopKThroughJoin::new()), Arc::new(SingleDistinctToGroupBy::new()), // The previous optimizations added expressions and projections, // that might benefit from the following rules diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 4a024731f1899..47fa7c92364e5 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -188,28 +188,31 @@ impl OptimizerRule for PushDownTopKThroughJoin { &join.right }; - // Resolve sort exprs further through any SubqueryAlias wrapping the - // preserved child, so we can 
compare with the inner Sort's expressions. - // - // After intermediate resolution, resolved_sort_exprs = [t1.b ASC]. - // The inner Sort uses [orders.b ASC]. This step maps t1.b → orders.b. - // - // ```text - // Sort(sub.b ASC, fetch=2) - // SubqueryAlias(sub) ← intermediate, already resolved - // Left Join - // SubqueryAlias(t1) ← preserved child, resolve here - // Sort(orders.b ASC, fetch=5) - // TableScan: orders - // ``` - let (inner_child, child_resolved_exprs) = match preserved_child.as_ref() { - LogicalPlan::SubqueryAlias(sq) => { - let exprs = - resolve_sort_exprs_through_subquery_alias(&resolved_sort_exprs, sq)?; - (sq.input.as_ref(), exprs) + // Scan deep inside the preserved child (through SubqueryAlias and + // Projection layers) to find an existing Sort. If found with same + // exprs, tighten its fetch in-place. Otherwise, insert a new Sort + // directly below the join as the preserved child's wrapper. + let mut inner_child = preserved_child.as_ref(); + let mut deep_resolved_exprs = resolved_sort_exprs.clone(); + loop { + match inner_child { + LogicalPlan::SubqueryAlias(sq) => { + deep_resolved_exprs = resolve_sort_exprs_through_subquery_alias( + &deep_resolved_exprs, + sq, + )?; + inner_child = sq.input.as_ref(); + } + LogicalPlan::Projection(proj) => { + deep_resolved_exprs = resolve_sort_exprs_through_projection( + &deep_resolved_exprs, + proj, + )?; + inner_child = proj.input.as_ref(); + } + _ => break, } - _ => (preserved_child.as_ref(), resolved_sort_exprs.clone()), - }; + } // If the inner child is a Limit (PushDownLimit hasn't merged it with // the Sort yet), skip this iteration. PushDownLimit will merge @@ -218,19 +221,24 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } - // Push through when the preserved child has no Sort, or has a Sort - // with a larger/no fetch limit (our tighter limit reduces data further). 
+ // Determine action based on existing inner Sort: + // - Same exprs, tighter fetch → skip (already optimal) + // - Same exprs, larger/no fetch → tighten in-place + // - Different exprs or no Sort → insert new Sort below the join // - // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) - // Child limits to 10, our tighter fetch=5 reduces data further. + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Child limits to 10, our tighter fetch=5 tightens it in-place. // - // Example (push): Sort(a ASC, fetch=5) → Join → Sort(a ASC) - // Child has no fetch (full sort), adding fetch=5 limits early. + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC) + // Child has no fetch (full sort), tighten to fetch=5. // // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) // Child already limits to 3 rows, pushing fetch=5 won't help. - if let LogicalPlan::Sort(child_sort) = inner_child { - let same_exprs = sort_exprs_equal(&child_sort.expr, &child_resolved_exprs); + // + // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Different exprs, insert Sort(b, fetch=5) above preserved child. + let new_preserved_child = if let LogicalPlan::Sort(child_sort) = inner_child { + let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); let child_fetch_tighter = match child_sort.fetch { Some(child_fetch) => child_fetch <= fetch, None => false, @@ -238,45 +246,28 @@ impl OptimizerRule for PushDownTopKThroughJoin { if same_exprs && child_fetch_tighter { return Ok(Transformed::no(plan)); } - } - - // Create the new Sort(fetch) on the preserved child. - // Use the resolved expressions for the pushed Sort. - // - // If the inner child is already a Sort with the same expressions but a - // larger fetch, tighten its fetch in-place instead of stacking a - // redundant Sort on top. 
- // - // When the preserved child is wrapped in SubqueryAlias, the new Sort - // must sit INSIDE the SubqueryAlias (between it and its input), using - // inner-schema column names. - let inner_input: &Arc = match preserved_child.as_ref() { - LogicalPlan::SubqueryAlias(sq) => &sq.input, - _ => preserved_child, - }; - let new_inner_child = match inner_child { - LogicalPlan::Sort(child_sort) - if sort_exprs_equal(&child_sort.expr, &child_resolved_exprs) => - { + if same_exprs { + // Tighten existing Sort in-place by rebuilding the path + // from preserved child down to the Sort. + rebuild_with_tightened_sort(preserved_child.as_ref(), child_sort, fetch)? + } else { + // Different exprs — insert new Sort above the preserved child. + // If the inner Sort has no fetch, our pushed Sort is the only + // row reduction. If it has a fetch, re-sorting a small set is + // cheap and still reduces rows entering the join. Arc::new(LogicalPlan::Sort(SortPlan { - expr: child_sort.expr.clone(), - input: Arc::clone(&child_sort.input), + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), fetch: Some(fetch), })) } - _ => Arc::new(LogicalPlan::Sort(SortPlan { - expr: child_resolved_exprs, - input: Arc::clone(inner_input), + } else { + // No existing Sort — insert new Sort below the join. + Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), fetch: Some(fetch), - })), - }; - - // Wrap the new Sort back in SubqueryAlias if the preserved child had one. 
- let new_preserved_child = match preserved_child.as_ref() { - LogicalPlan::SubqueryAlias(sq) => Arc::new(LogicalPlan::SubqueryAlias( - SubqueryAlias::try_new(new_inner_child, sq.alias.clone())?, - )), - _ => new_inner_child, + })) }; // Reconstruct the join with the new child @@ -442,6 +433,56 @@ fn resolve_sort_exprs_through_subquery_alias( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } +/// Rebuild the tree from `root` down to an existing Sort, tightening the +/// Sort's fetch to `new_fetch`. The path from `root` to the target Sort +/// may contain Projections and SubqueryAliases. +/// +/// Before (new_fetch=2): +/// ```text +/// SubqueryAlias(t1) +/// Projection(a, b AS renamed_b) +/// Sort(t1.b ASC, fetch=10) ← target, fetch too large +/// TableScan: t1 +/// ``` +/// +/// After: +/// ```text +/// SubqueryAlias(t1) ← rebuilt +/// Projection(a, b AS renamed_b) ← rebuilt +/// Sort(t1.b ASC, fetch=2) ← tightened +/// TableScan: t1 +/// ``` +fn rebuild_with_tightened_sort( + root: &LogicalPlan, + target_sort: &SortPlan, + new_fetch: usize, +) -> Result> { + match root { + LogicalPlan::Sort(s) if std::ptr::eq(s, target_sort) => { + Ok(Arc::new(LogicalPlan::Sort(SortPlan { + expr: s.expr.clone(), + input: Arc::clone(&s.input), + fetch: Some(new_fetch), + }))) + } + LogicalPlan::Projection(proj) => { + let new_input = + rebuild_with_tightened_sort(proj.input.as_ref(), target_sort, new_fetch)?; + let mut new_proj = proj.clone(); + new_proj.input = new_input; + Ok(Arc::new(LogicalPlan::Projection(new_proj))) + } + LogicalPlan::SubqueryAlias(sq) => { + let new_input = + rebuild_with_tightened_sort(sq.input.as_ref(), target_sort, new_fetch)?; + Ok(Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_input, sq.alias.clone())?, + ))) + } + _ => unreachable!("rebuild_with_tightened_sort: unexpected node"), + } +} + #[cfg(test)] mod test { use super::*; @@ -868,37 +909,28 @@ mod test { ) } - // 
--------------------------------------------------------------- - // Unit tests for resolve_sort_exprs_through_projection - // --------------------------------------------------------------- - - /// Simple passthrough: sort on a column that projection passes through. - /// Projection: [t1.a, t1.b] → sort on t1.b resolves to t1.b + /// Projection passthrough: sort expr matches a projected column directly. #[test] fn resolve_through_projection_passthrough() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; - let plan = LogicalPlanBuilder::from(t1) .project(vec![col("t1.a"), col("t1.b")])? .build()?; let LogicalPlan::Projection(proj) = &plan else { panic!("expected Projection"); }; - let sort_exprs = vec![col("t1.b").sort(true, false)]; let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; - assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].expr.to_string(), "t1.b"); assert!(resolved[0].asc); Ok(()) } - /// Aliased expression: sort on neg_b resolves to (- t1.b) + /// Projection alias: sort expr references an alias that maps to a negation. #[test] fn resolve_through_projection_alias() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; - let plan = LogicalPlanBuilder::from(t1) .project(vec![ col("t1.a"), @@ -908,31 +940,26 @@ mod test { let LogicalPlan::Projection(proj) = &plan else { panic!("expected Projection"); }; - let sort_exprs = vec![col("neg_b").sort(true, false)]; let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; - assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); Ok(()) } - /// Multiple columns through projection: sort on (t1.a, t1.b) + /// Multi-column resolution preserves direction and nulls_first per column. #[test] fn resolve_through_projection_multi_column() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; - let plan = LogicalPlanBuilder::from(t1) .project(vec![col("t1.a"), col("t1.b"), col("t1.c")])? 
.build()?; let LogicalPlan::Projection(proj) = &plan else { panic!("expected Projection"); }; - let sort_exprs = vec![col("t1.a").sort(true, false), col("t1.b").sort(false, true)]; let resolved = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; - assert_eq!(resolved.len(), 2); assert_eq!(resolved[0].expr.to_string(), "t1.a"); assert!(resolved[0].asc); @@ -942,12 +969,10 @@ mod test { Ok(()) } - /// Projection + SubqueryAlias stacked: sort resolves through both. - /// neg_b → (- sub.b) through Projection → (- t1.b) through SubqueryAlias + /// Stacked Projection + SubqueryAlias: resolve through both layers. #[test] fn resolve_through_projection_and_subquery_alias() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; - let plan = LogicalPlanBuilder::from(t1) .alias("sub")? .project(vec![ @@ -955,116 +980,104 @@ mod test { (Expr::Negative(Box::new(col("sub.b")))).alias("neg_b"), ])? .build()?; - - // Peel: Projection then SubqueryAlias let LogicalPlan::Projection(proj) = &plan else { panic!("expected Projection"); }; let LogicalPlan::SubqueryAlias(sq) = proj.input.as_ref() else { panic!("expected SubqueryAlias"); }; - let sort_exprs = vec![col("neg_b").sort(true, false)]; - - // Resolve through Projection: neg_b → (- sub.b) let after_proj = resolve_sort_exprs_through_projection(&sort_exprs, proj)?; assert_eq!(after_proj[0].expr.to_string(), "(- sub.b)"); - - // Resolve through SubqueryAlias: (- sub.b) → (- t1.b) let after_sq = resolve_sort_exprs_through_subquery_alias(&after_proj, sq)?; assert_eq!(after_sq[0].expr.to_string(), "(- t1.b)"); assert!(after_sq[0].asc); assert!(!after_sq[0].nulls_first); - Ok(()) } - // --------------------------------------------------------------- - // Unit tests for resolve_sort_exprs_through_subquery_alias - // --------------------------------------------------------------- - - /// Simple column rename: sub.b → t1.b + /// Simple SubqueryAlias resolution: sub.b → t1.b. 
#[test] fn resolve_through_subquery_alias_simple() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; - let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; let LogicalPlan::SubqueryAlias(sq) = &plan else { panic!("expected SubqueryAlias"); }; - let sort_exprs = vec![col("sub.b").sort(true, false)]; let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; - assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].expr.to_string(), "t1.b"); - assert!(resolved[0].asc); - assert!(!resolved[0].nulls_first); - Ok(()) - } - - /// Multiple sort columns: sub.a ASC, sub.b DESC → t1.a ASC, t1.b DESC - #[test] - fn resolve_through_subquery_alias_multi_column() -> Result<()> { - let t1 = test_table_scan_with_name("t1")?; - - let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; - let LogicalPlan::SubqueryAlias(sq) = &plan else { - panic!("expected SubqueryAlias"); - }; - - let sort_exprs = vec![ - col("sub.a").sort(true, false), - col("sub.b").sort(false, true), - ]; - let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; - - assert_eq!(resolved.len(), 2); - assert_eq!(resolved[0].expr.to_string(), "t1.a"); - assert!(resolved[0].asc); - assert_eq!(resolved[1].expr.to_string(), "t1.b"); - assert!(!resolved[1].asc); - assert!(resolved[1].nulls_first); Ok(()) } - /// Alias name differs from table name: foo.b → t1.b + /// Inner Sort has different exprs WITH fetch → stacked sorts. + /// Sort(b, fetch=2) is inserted above Sort(a, fetch=5). Re-sorting + /// 5 rows is cheap and reduces join input from 5 to 2. 
#[test] - fn resolve_through_subquery_alias_different_name() -> Result<()> { + fn topk_stacked_when_child_has_different_exprs_with_fetch() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; - let plan = LogicalPlanBuilder::from(t1).alias("foo")?.build()?; - let LogicalPlan::SubqueryAlias(sq) = &plan else { - panic!("expected SubqueryAlias"); - }; + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort_with_limit(vec![col("t1.a").sort(true, false)], Some(5))? + .build()?; - let sort_exprs = vec![col("foo.b").sort(true, false)]; - let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(2))? + .build()?; - assert_eq!(resolved.len(), 1); - assert_eq!(resolved[0].expr.to_string(), "t1.b"); - Ok(()) + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=2 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=2 + Sort: t1.a ASC NULLS LAST, fetch=5 + TableScan: t1 + TableScan: t2 + " + ) } - /// Nested expression: (- sub.b) ASC → (- t1.b) ASC + /// Inner Sort has different exprs WITHOUT fetch → stacked sorts. + /// Full sort doesn't limit rows, so pushed Sort(fetch=2) is the + /// only row reduction before the join. #[test] - fn resolve_through_subquery_alias_nested_expr() -> Result<()> { + fn topk_stacked_when_child_has_different_exprs_no_fetch() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; - let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; - let LogicalPlan::SubqueryAlias(sq) = &plan else { - panic!("expected SubqueryAlias"); - }; + let t1_with_sort = LogicalPlanBuilder::from(t1) + .sort(vec![col("t1.a").sort(true, false)])? 
+ .build()?; - let sort_exprs = vec![SortExpr { - expr: Expr::Negative(Box::new(col("sub.b"))), - asc: true, - nulls_first: false, - }]; - let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + let plan = LogicalPlanBuilder::from(t1_with_sort) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Left, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(2))? + .build()?; - assert_eq!(resolved.len(), 1); - assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); - Ok(()) + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=2 + Left Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=2 + Sort: t1.a ASC NULLS LAST + TableScan: t1 + TableScan: t2 + " + ) } } diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index ee52c59124a20..1a18ea70e990d 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -457,9 +457,9 @@ logical_plan ### SubqueryAlias edge cases ### -# SubqueryAlias without inner Sort: push new Sort through SubqueryAlias +# SubqueryAlias without inner Sort: push new Sort below the join. # Preserved child is SubqueryAlias(t1, TableScan) — no existing Sort to tighten, -# so a new Sort(fetch=2) is inserted inside the SubqueryAlias. +# so a new Sort(fetch=2) is inserted above the SubqueryAlias. 
query TT EXPLAIN SELECT * FROM ( SELECT t1.a, t1.b, t2.x @@ -472,8 +472,8 @@ logical_plan 01)Sort: sub.b ASC NULLS LAST, fetch=2 02)--SubqueryAlias: sub 03)----Left Join: t1.a = t2.x -04)------SubqueryAlias: t1 -05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +04)------Sort: t1.b ASC NULLS LAST, fetch=2 +05)--------SubqueryAlias: t1 06)----------TableScan: t1 projection=[a, b] 07)------TableScan: t2 projection=[x] @@ -537,8 +537,8 @@ logical_plan 01)Sort: sub.b ASC NULLS LAST, fetch=3 02)--SubqueryAlias: sub 03)----Left Join: foo.a = t2.x -04)------SubqueryAlias: foo -05)--------Sort: t1.b ASC NULLS LAST, fetch=3 +04)------Sort: foo.b ASC NULLS LAST, fetch=3 +05)--------SubqueryAlias: foo 06)----------TableScan: t1 projection=[a, b] 07)------TableScan: t2 projection=[x] @@ -603,8 +603,8 @@ logical_plan 01)Sort: sub.a ASC NULLS LAST, sub.b ASC NULLS LAST, fetch=3 02)--SubqueryAlias: sub 03)----Left Join: t1.a = t2.x -04)------SubqueryAlias: t1 -05)--------Sort: t1.a ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=3 +04)------Sort: t1.a ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=3 +05)--------SubqueryAlias: t1 06)----------TableScan: t1 projection=[a, b] 07)------TableScan: t2 projection=[x] @@ -649,6 +649,463 @@ ORDER BY t1.b ASC LIMIT 3; 3 30 3 4 40 NULL +### +### Descending order and NULLS FIRST cases +### + +# LEFT JOIN: TopK with DESC sort pushed to left child +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b DESC LIMIT 3; +---- +logical_plan +01)Sort: t1.b DESC NULLS FIRST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b DESC NULLS FIRST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b DESC LIMIT 3; +---- +5 50 NULL +4 40 NULL +3 30 3 + +# LEFT JOIN: TopK with ASC NULLS FIRST pushed to left child +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a 
= t2.x +ORDER BY t1.b ASC NULLS FIRST LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS FIRST, fetch=3 +02)--Left Join: t1.a = t2.x +03)----Sort: t1.b ASC NULLS FIRST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT t1.a, t1.b, t2.x +FROM t1 LEFT JOIN t2 ON t1.a = t2.x +ORDER BY t1.b ASC NULLS FIRST LIMIT 3; +---- +1 10 1 +2 20 2 +3 30 3 + +# RIGHT JOIN: TopK with DESC NULLS LAST pushed to right child +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y DESC NULLS LAST LIMIT 3; +---- +logical_plan +01)Sort: t2.y DESC NULLS LAST, fetch=3 +02)--Right Join: t1.a = t2.x +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y DESC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# Verify correctness +query III +SELECT t1.a, t2.x, t2.y +FROM t1 RIGHT JOIN t2 ON t1.a = t2.x +ORDER BY t2.y DESC NULLS LAST LIMIT 3; +---- +NULL 7 700 +NULL 6 600 +3 3 300 + +### +### CROSS JOIN — no pushdown +### + +# CROSS JOIN: no pushdown (no preserved side) +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 CROSS JOIN t2 +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Cross Join: +03)----TableScan: t1 projection=[a, b] +04)----TableScan: t2 projection=[x] + +### +### Multi-level outer joins +### + +# Chained LEFT JOINs: TopK pushed to leftmost preserved child +statement ok +CREATE TABLE t3 (p INT, q INT) AS VALUES + (1, 1000), + (2, 2000), + (3, 3000); + +query TT +EXPLAIN SELECT t1.a, t1.b, t2.x, t3.p +FROM t1 +LEFT JOIN t2 ON t1.a = t2.x +LEFT JOIN t3 ON t1.a = t3.p +ORDER BY t1.b ASC LIMIT 2; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=2 +02)--Left Join: t1.a = t3.p +03)----Sort: t1.b ASC NULLS LAST, fetch=2 +04)------Left Join: t1.a = t2.x +05)--------Sort: t1.b ASC NULLS LAST, fetch=2 +06)----------TableScan: t1 projection=[a, b] +07)--------TableScan: t2 projection=[x] 
+08)----TableScan: t3 projection=[p] + +# Verify correctness +query IIII +SELECT t1.a, t1.b, t2.x, t3.p +FROM t1 +LEFT JOIN t2 ON t1.a = t2.x +LEFT JOIN t3 ON t1.a = t3.p +ORDER BY t1.b ASC LIMIT 2; +---- +1 10 1 1 +2 20 2 2 + +statement ok +DROP TABLE t3; + +### +### Tied sort key scenarios +### + +# Tied sort keys: pushdown is safe, all tied rows from preserved side appear +statement ok +CREATE TABLE t_tied (a INT, b INT) AS VALUES + (1, 10), + (2, 10), + (3, 10), + (4, 20), + (5, 30); + +statement ok +CREATE TABLE t_other (x INT) AS VALUES (1), (2), (3); + +query TT +EXPLAIN SELECT t_tied.a, t_tied.b, t_other.x +FROM t_tied LEFT JOIN t_other ON t_tied.a = t_other.x +ORDER BY t_tied.b ASC, t_tied.a ASC LIMIT 3; +---- +logical_plan +01)Sort: t_tied.b ASC NULLS LAST, t_tied.a ASC NULLS LAST, fetch=3 +02)--Left Join: t_tied.a = t_other.x +03)----Sort: t_tied.b ASC NULLS LAST, t_tied.a ASC NULLS LAST, fetch=3 +04)------TableScan: t_tied projection=[a, b] +05)----TableScan: t_other projection=[x] + +# Verify correctness: 3 rows with b=10, tied but only 3 emitted by LIMIT +query III +SELECT t_tied.a, t_tied.b, t_other.x +FROM t_tied LEFT JOIN t_other ON t_tied.a = t_other.x +ORDER BY t_tied.b ASC, t_tied.a ASC LIMIT 3; +---- +1 10 1 +2 10 2 +3 10 3 + +statement ok +DROP TABLE t_tied; + +statement ok +DROP TABLE t_other; + +### +### Nested SubqueryAlias cases +### + +# Nested SubqueryAlias: resolve through multiple alias layers +query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: outer_sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------Sort: inner_sub.b ASC NULLS LAST, fetch=2 +05)--------SubqueryAlias: inner_sub +06)----------SubqueryAlias: inner_alias +07)------------TableScan: t1 projection=[a, b] 
+08)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Nested SubqueryAlias with existing inner Sort: tighten fetch in-place. +# The inner Sort(fetch=5) is behind two SubqueryAlias layers. The deep +# resolution finds it, confirms same sort exprs, and tightens to fetch=2. +query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: outer_sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------SubqueryAlias: inner_sub +05)--------SubqueryAlias: inner_alias +06)----------Sort: t1.b ASC NULLS LAST, fetch=2 +07)------------TableScan: t1 projection=[a, b] +08)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Nested SubqueryAlias with tighter existing inner Sort: skip (already tighter). +# Inner Sort(fetch=2) is tighter than outer fetch=5, so rule skips. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 5; +---- +logical_plan +01)Sort: outer_sub.b ASC NULLS LAST, fetch=5 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------SubqueryAlias: inner_sub +05)--------SubqueryAlias: inner_alias +06)----------Sort: t1.b ASC NULLS LAST, fetch=2 +07)------------TableScan: t1 projection=[a, b] +08)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 2) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 5; +---- +1 10 1 +2 20 2 + +# Inner Sort with DIFFERENT exprs and fetch: pushdown still happens. +# Inner Sort(a, fetch=5) already limits to 5 rows. Pushed Sort(b, fetch=2) +# re-sorts those 5 rows (cheap) and reduces to 2 rows entering the join. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY a ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: outer_sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------Sort: inner_sub.b ASC NULLS LAST, fetch=2 +05)--------SubqueryAlias: inner_sub +06)----------SubqueryAlias: inner_alias +07)------------Sort: t1.a ASC NULLS LAST, fetch=5 +08)--------------TableScan: t1 projection=[a, b] +09)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY a ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Inner Sort with different exprs but NO fetch (full sort): pushdown OK. +# Our rule inserts Sort(b, fetch=2) above Sort(a, no fetch). The inner full +# sort is then eliminated by other optimizer rules since it's redundant. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY a ASC) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +logical_plan +01)Sort: outer_sub.b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------Sort: inner_sub.b ASC NULLS LAST, fetch=2 +05)--------SubqueryAlias: inner_sub +06)----------SubqueryAlias: inner_alias +07)------------TableScan: t1 projection=[a, b] +08)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.b, t2.x + FROM ( + SELECT * FROM (SELECT * FROM t1 ORDER BY a ASC) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Tighten Sort through SubqueryAlias + Projection + SubqueryAlias. +# The inner Sort(fetch=5) is behind SubqueryAlias(inner_sub) → Projection(rename) → SubqueryAlias(inner_alias). +# The Projection renames b → renamed_b, so it survives as a plan node. +# Deep resolution looks through the Projection to find the Sort and tightens it to fetch=2. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT inner_sub.a, inner_sub.renamed_b, t2.x + FROM ( + SELECT a, b AS renamed_b FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY renamed_b ASC LIMIT 2; +---- +logical_plan +01)Sort: outer_sub.renamed_b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: outer_sub +03)----Left Join: inner_sub.a = t2.x +04)------SubqueryAlias: inner_sub +05)--------Projection: inner_alias.a, inner_alias.b AS renamed_b +06)----------SubqueryAlias: inner_alias +07)------------Sort: t1.b ASC NULLS LAST, fetch=2 +08)--------------TableScan: t1 projection=[a, b] +09)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT inner_sub.a, inner_sub.renamed_b, t2.x + FROM ( + SELECT a, b AS renamed_b FROM (SELECT * FROM t1 ORDER BY b ASC LIMIT 5) inner_alias + ) inner_sub + LEFT JOIN t2 ON inner_sub.a = t2.x +) outer_sub +ORDER BY renamed_b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Sort above Projection only needs to sort the projected column subset, +# which is more efficient than sorting all pre-projection columns. +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.a, t1.renamed_b, t2.x + FROM (SELECT a, b AS renamed_b FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY renamed_b ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.renamed_b ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Left Join: t1.a = t2.x +04)------Sort: t1.renamed_b ASC NULLS LAST, fetch=2 +05)--------SubqueryAlias: t1 +06)----------Projection: t1.a, t1.b AS renamed_b +07)------------TableScan: t1 projection=[a, b] +08)------TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT * FROM ( + SELECT t1.a, t1.renamed_b, t2.x + FROM (SELECT a, b AS renamed_b FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY renamed_b ASC LIMIT 2; +---- +1 10 1 +2 20 2 + +# Volatile expression inside Projection of preserved child: pushdown IS safe. 
+# The Projection computes random() once and names it rand_col. The pushed +# Sort only reorders by the pre-computed rand_col column — it does NOT +# re-evaluate random(). This is different from having random() directly +# in the sort expression (which IS blocked by the volatility check). +query TT +EXPLAIN SELECT * FROM ( + SELECT t1.rand_col, t2.x + FROM (SELECT random() AS rand_col, a FROM t1) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY rand_col ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.rand_col ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Projection: t1.rand_col, t2.x +04)------Left Join: t1.a = t2.x +05)--------Sort: t1.rand_col ASC NULLS LAST, fetch=2 +06)----------SubqueryAlias: t1 +07)------------Projection: random() AS rand_col, t1.a +08)--------------TableScan: t1 projection=[a] +09)--------TableScan: t2 projection=[x] + ### ### Config reset ### @@ -663,4 +1120,4 @@ statement ok DROP TABLE t1; statement ok -DROP TABLE t2; +DROP TABLE t2; \ No newline at end of file From 8d4fc2ccf9eb33a06b1810908b9d1c5d27aad530 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sat, 9 May 2026 14:35:51 +0530 Subject: [PATCH 15/23] Adds back missing UT --- .../src/push_down_topk_through_join.rs | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 47fa7c92364e5..146b5d6219cc5 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -1008,6 +1008,61 @@ mod test { let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; assert_eq!(resolved.len(), 1); assert_eq!(resolved[0].expr.to_string(), "t1.b"); + assert!(resolved[0].asc); + assert!(!resolved[0].nulls_first); + Ok(()) + } + + /// Multi-column SubqueryAlias resolution preserves direction per column. 
+ #[test] + fn resolve_through_subquery_alias_multi_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + let sort_exprs = vec![ + col("sub.a").sort(true, false), + col("sub.b").sort(false, true), + ]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + assert_eq!(resolved.len(), 2); + assert_eq!(resolved[0].expr.to_string(), "t1.a"); + assert!(resolved[0].asc); + assert_eq!(resolved[1].expr.to_string(), "t1.b"); + assert!(!resolved[1].asc); + assert!(resolved[1].nulls_first); + Ok(()) + } + + /// SubqueryAlias with a different alias name (foo ≠ t1). + #[test] + fn resolve_through_subquery_alias_different_name() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let plan = LogicalPlanBuilder::from(t1).alias("foo")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + let sort_exprs = vec![col("foo.b").sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "t1.b"); + Ok(()) + } + + /// SubqueryAlias with nested expression: (- sub.b) → (- t1.b). 
+ #[test] + fn resolve_through_subquery_alias_nested_expr() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let plan = LogicalPlanBuilder::from(t1).alias("sub")?.build()?; + let LogicalPlan::SubqueryAlias(sq) = &plan else { + panic!("expected SubqueryAlias"); + }; + let sort_exprs = vec![Expr::Negative(Box::new(col("sub.b"))).sort(true, false)]; + let resolved = resolve_sort_exprs_through_subquery_alias(&sort_exprs, sq)?; + assert_eq!(resolved.len(), 1); + assert_eq!(resolved[0].expr.to_string(), "(- t1.b)"); + assert!(resolved[0].asc); Ok(()) } From 80f84bcc73e52e53f0919663993627a08819df16 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Sun, 10 May 2026 10:31:46 +0530 Subject: [PATCH 16/23] Fix explain slt test --- datafusion/sqllogictest/test_files/explain.slt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index fd69bb46599ab..0726b943ef713 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -193,8 +193,8 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -218,8 +218,8 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after 
push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -570,8 +570,8 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -595,8 +595,8 @@ logical_plan after propagate_empty_relation SAME TEXT AS ABOVE logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE From 1efd76355a28ccd6555a6d95a326ae7d8bd3c3d6 Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Sat, 16 May 2026 10:34:59 +0530 Subject: [PATCH 17/23] Fix sort expr comparision --- .../src/push_down_topk_through_join.rs | 72 
+++++++++---------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index 146b5d6219cc5..f34ba2bc12b7d 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -58,7 +58,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::utils::{has_all_column_refs, schema_columns}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{Column, Result}; +use datafusion_common::{Column, Result, internal_err}; use datafusion_expr::logical_plan::{ JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, }; @@ -73,7 +73,7 @@ use datafusion_expr::{Expr, SortExpr}; pub struct PushDownTopKThroughJoin; impl PushDownTopKThroughJoin { - #[expect(missing_docs)] + /// Create a new `PushDownTopKThroughJoin` rule. pub fn new() -> Self { Self {} } @@ -156,7 +156,12 @@ impl OptimizerRule for PushDownTopKThroughJoin { sq, )?; } - _ => unreachable!(), + _ => { + return internal_err!( + "PushDownTopKThroughJoin: unexpected intermediate node: {}", + node.display() + ); + } } } @@ -249,7 +254,11 @@ impl OptimizerRule for PushDownTopKThroughJoin { if same_exprs { // Tighten existing Sort in-place by rebuilding the path // from preserved child down to the Sort. - rebuild_with_tightened_sort(preserved_child.as_ref(), child_sort, fetch)? + rebuild_with_tightened_sort( + preserved_child.as_ref(), + &deep_resolved_exprs, + fetch, + )? } else { // Different exprs — insert new Sort above the preserved child. 
// If the inner Sort has no fetch, our pushed Sort is the only @@ -290,7 +299,12 @@ impl OptimizerRule for PushDownTopKThroughJoin { LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, ), - _ => unreachable!(), + _ => { + return internal_err!( + "PushDownTopKThroughJoin: unexpected intermediate node: {}", + node.display() + ); + } }); } @@ -310,30 +324,6 @@ impl OptimizerRule for PushDownTopKThroughJoin { } } -/// Resolve sort expressions through a projection by replacing column -/// references with the underlying projection expressions. -/// -/// For example, if sort expr is `b ASC` and projection has `-t1.b AS b`, -/// the resolved sort expr becomes `-t1.b ASC`. -/// -/// Before: -/// ```text -/// Sort: b ASC, fetch=3 -/// Projection: -t1.b AS b -/// Join -/// t1 -/// t2 -/// ``` -/// -/// After resolving, the pushed Sort uses pre-projection expressions: -/// ```text -/// Sort: b ASC, fetch=3 -/// Projection: -t1.b AS b -/// Join -/// Sort: -t1.b ASC, fetch=3 ← resolved through projection -/// t1 -/// t2 -/// ``` /// Replace column references in sort expressions using a name→expr map. /// Uses `transform()` for deep replacement (handles nested expressions /// like `-t1.b` where the column is inside a Negative wrapper). @@ -433,9 +423,9 @@ fn resolve_sort_exprs_through_subquery_alias( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Rebuild the tree from `root` down to an existing Sort, tightening the -/// Sort's fetch to `new_fetch`. The path from `root` to the target Sort -/// may contain Projections and SubqueryAliases. +/// Rebuild the tree from `root` down to an existing Sort whose expressions +/// match `target_exprs`, tightening its fetch to `new_fetch`. The path from +/// `root` to the target Sort may contain Projections and SubqueryAliases. 
/// /// Before (new_fetch=2): /// ```text @@ -454,11 +444,11 @@ fn resolve_sort_exprs_through_subquery_alias( /// ``` fn rebuild_with_tightened_sort( root: &LogicalPlan, - target_sort: &SortPlan, + target_exprs: &[SortExpr], new_fetch: usize, ) -> Result> { match root { - LogicalPlan::Sort(s) if std::ptr::eq(s, target_sort) => { + LogicalPlan::Sort(s) if sort_exprs_equal(&s.expr, target_exprs) => { Ok(Arc::new(LogicalPlan::Sort(SortPlan { expr: s.expr.clone(), input: Arc::clone(&s.input), @@ -466,20 +456,26 @@ fn rebuild_with_tightened_sort( }))) } LogicalPlan::Projection(proj) => { - let new_input = - rebuild_with_tightened_sort(proj.input.as_ref(), target_sort, new_fetch)?; + let new_input = rebuild_with_tightened_sort( + proj.input.as_ref(), + target_exprs, + new_fetch, + )?; let mut new_proj = proj.clone(); new_proj.input = new_input; Ok(Arc::new(LogicalPlan::Projection(new_proj))) } LogicalPlan::SubqueryAlias(sq) => { let new_input = - rebuild_with_tightened_sort(sq.input.as_ref(), target_sort, new_fetch)?; + rebuild_with_tightened_sort(sq.input.as_ref(), target_exprs, new_fetch)?; Ok(Arc::new(LogicalPlan::SubqueryAlias( SubqueryAlias::try_new(new_input, sq.alias.clone())?, ))) } - _ => unreachable!("rebuild_with_tightened_sort: unexpected node"), + _ => internal_err!( + "rebuild_with_tightened_sort: unexpected node: {}", + root.display() + ), } } From 93997ebcdf0294b098156116b93cd8af7a02b762 Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Tue, 19 May 2026 21:27:23 +0530 Subject: [PATCH 18/23] Add cross, left and right mark join --- .../src/push_down_topk_through_join.rs | 256 +++++++++++++++--- .../push_down_topk_through_join.slt | 70 ++++- 2 files changed, 287 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index f34ba2bc12b7d..a2c63494ecebd 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ 
b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through joins +//! whose preserved side is known. //! -//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! When a `Sort` with a fetch limit sits above such a join and all sort //! expressions come from the **preserved** side, this rule inserts a copy //! of the `Sort(fetch)` on that input to reduce the number of rows //! entering the join. @@ -28,6 +29,13 @@ //! left rows. //! - The same reasoning applies symmetrically for RIGHT JOIN and right-side //! columns. +//! - A CROSS JOIN preserves every row from both sides (Cartesian product). +//! The top-N by one side's columns must come from the top-N rows of that +//! side, since each surviving row is duplicated by the other side's row +//! count. +//! - LEFT MARK / RIGHT MARK joins emit exactly one record per row of the +//! marked side (with an extra mark column), so that side is fully +//! preserved and pushdown applies symmetrically to LEFT / RIGHT joins. //! //! The top-level sort is kept for correctness since a 1-to-many join can //! produce more than N output rows from N input rows. @@ -64,9 +72,17 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::{Expr, SortExpr}; -/// Optimization rule that pushes TopK (Sort with fetch) through -/// LEFT / RIGHT outer joins when all sort expressions come from -/// the preserved side. +/// Which child of a join is being treated as the preserved side. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Side { + Left, + Right, +} + +/// Optimization rule that pushes TopK (Sort with fetch) through joins +/// that have a known preserved side (LEFT / RIGHT outer, +/// LEFT MARK / RIGHT MARK, or CROSS) when all sort expressions come +/// from a preserved side. /// /// See module-level documentation for details. #[derive(Default, Debug)] @@ -124,18 +140,30 @@ impl OptimizerRule for PushDownTopKThroughJoin { } }; - // Only outer joins where the preserved side is known. - // Semi/Anti joins are excluded: not all preserved-side rows appear in - // the output (only matched/unmatched rows do), so pushing fetch=N to - // the preserved child can drop rows that would have survived the filter. + // Determine which side(s) of the join are preserved. + // + // - LEFT / LeftMark: only the left side is preserved (the output schema is + // left + right columns for LEFT, or left + mark column for LeftMark). + // - RIGHT / RightMark: symmetric. + // - CROSS JOIN (represented as Inner with no `on` keys and no filter): + // every row from both sides appears in the output (Cartesian + // product), so we can push to whichever side has all the sort cols. // - // Non-equijoin filters in the ON clause are safe: outer joins guarantee - // all preserved-side rows appear in the output regardless of the filter. - // The filter only controls matching (which non-preserved rows pair up), - // not which preserved rows survive. - let preserved_is_left = match join.join_type { - JoinType::Left => true, - JoinType::Right => false, + // For LEFT/RIGHT, non-equijoin filters in the ON clause are safe: + // outer joins guarantee all preserved-side rows appear in the output + // regardless of the filter, and the non-preserved side never appears + // as a standalone unmatched row. + // + // For Inner joins (cross-join detection), the filter check is strict + // (`filter.is_none()`).
When an Inner join has a filter, that filter + // can drop rows from either side, so pushing fetch=N may select rows + // that get filtered out while discarding rows that would have matched. + let preserved_candidates: &[Side] = match join.join_type { + JoinType::Left | JoinType::LeftMark => &[Side::Left], + JoinType::Right | JoinType::RightMark => &[Side::Right], + JoinType::Inner if join.on.is_empty() && join.filter.is_none() => { + &[Side::Left, Side::Right] + } _ => return Ok(Transformed::no(plan)), }; @@ -173,24 +201,25 @@ impl OptimizerRule for PushDownTopKThroughJoin { return Ok(Transformed::no(plan)); } - let preserved_schema = if preserved_is_left { - join.left.schema() - } else { - join.right.schema() - }; - let preserved_cols = schema_columns(preserved_schema); - - let all_from_preserved = resolved_sort_exprs - .iter() - .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); - if !all_from_preserved { + // Pick the first preserved-side candidate whose schema contains all + // referenced sort columns. For LEFT/RIGHT this is the fixed side; + // for CROSS we try both. 
+ let Some(preserved_side) = preserved_candidates.iter().copied().find(|&side| { + let schema = match side { + Side::Left => join.left.schema(), + Side::Right => join.right.schema(), + }; + let cols = schema_columns(schema); + resolved_sort_exprs + .iter() + .all(|se| has_all_column_refs(&se.expr, &cols)) + }) else { return Ok(Transformed::no(plan)); - } + }; - let preserved_child = if preserved_is_left { - &join.left - } else { - &join.right + let preserved_child = match preserved_side { + Side::Left => &join.left, + Side::Right => &join.right, }; // Scan deep inside the preserved child (through SubqueryAlias and @@ -281,10 +310,9 @@ impl OptimizerRule for PushDownTopKThroughJoin { // Reconstruct the join with the new child let mut new_join = join.clone(); - if preserved_is_left { - new_join.left = new_preserved_child; - } else { - new_join.right = new_preserved_child; + match preserved_side { + Side::Left => new_join.left = new_preserved_child, + Side::Right => new_join.right = new_preserved_child, } // Rebuild the tree: join → intermediate nodes → top-level sort @@ -618,6 +646,162 @@ mod test { ) } + /// CROSS JOIN sorted by left-side columns → pushed to left child. + #[test] + fn topk_pushed_to_left_of_cross_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .cross_join(LogicalPlanBuilder::from(t2).build()?)? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Cross Join: + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// CROSS JOIN sorted by right-side columns → pushed to right child. 
+ #[test] + fn topk_pushed_to_right_of_cross_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .cross_join(LogicalPlanBuilder::from(t2).build()?)? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + Cross Join: + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + + /// CROSS JOIN sorted by columns from both sides → no pushdown. + #[test] + fn topk_not_pushed_for_cross_join_mixed_side_sort() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .cross_join(LogicalPlanBuilder::from(t2).build()?)? + .sort_with_limit( + vec![(col("t1.b") + col("t2.b")).sort(true, false)], + Some(3), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b + t2.b ASC NULLS LAST, fetch=3 + Cross Join: + TableScan: t1 + TableScan: t2 + " + ) + } + + /// Inner join with no equi-keys but a non-empty filter: the filter can + /// drop rows from either side, so pushing fetch=N can produce fewer + /// output rows than the unpushed plan. + #[test] + fn topk_not_pushed_for_inner_with_filter_no_on() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlanBuilder::from(t2).build()?, + JoinType::Inner, + vec![col("t1.b").gt(col("t2.b"))], + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + Inner Join: Filter: t1.b > t2.b + TableScan: t1 + TableScan: t2 + " + ) + } + + /// LEFT MARK join: one record per left row (with extra mark column), + /// so left is fully preserved → pushdown to left. 
+ #[test] + fn topk_pushed_to_left_of_left_mark_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::LeftMark, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t1.b ASC NULLS LAST, fetch=3 + LeftMark Join: t1.a = t2.a + Sort: t1.b ASC NULLS LAST, fetch=3 + TableScan: t1 + TableScan: t2 + " + ) + } + + /// RIGHT MARK join: symmetric to LeftMark. + #[test] + fn topk_pushed_to_right_of_right_mark_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).build()?, + JoinType::RightMark, + (vec!["a"], vec!["a"]), + None, + )? + .sort_with_limit(vec![col("t2.b").sort(true, false)], Some(3))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Sort: t2.b ASC NULLS LAST, fetch=3 + RightMark Join: t1.a = t2.a + TableScan: t1 + Sort: t2.b ASC NULLS LAST, fetch=3 + TableScan: t2 + " + ) + } + /// LEFT JOIN but sort on right-side columns → no pushdown. #[test] fn topk_not_pushed_for_wrong_side() -> Result<()> { diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 1a18ea70e990d..59c570d1963f7 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -723,10 +723,12 @@ NULL 6 600 3 3 300 ### -### CROSS JOIN — no pushdown +### CROSS JOIN — pushdown to whichever side has the sort columns ### -# CROSS JOIN: no pushdown (no preserved side) +# CROSS JOIN: TopK on left-side columns pushed to left child. 
+# Every left row appears |t2| times in the output, so the top-N by +# left columns must come from the top-N left rows. query TT EXPLAIN SELECT t1.a, t1.b, t2.x FROM t1 CROSS JOIN t2 @@ -735,8 +737,70 @@ ORDER BY t1.b ASC LIMIT 3; logical_plan 01)Sort: t1.b ASC NULLS LAST, fetch=3 02)--Cross Join: +03)----Sort: t1.b ASC NULLS LAST, fetch=3 +04)------TableScan: t1 projection=[a, b] +05)----TableScan: t2 projection=[x] + +# Verify correctness +query III +SELECT t1.a, t1.b, t2.x +FROM t1 CROSS JOIN t2 +ORDER BY t1.b ASC, t2.x ASC LIMIT 3; +---- +1 10 1 +1 10 2 +1 10 3 + +# CROSS JOIN: TopK on right-side columns pushed to right child. +query TT +EXPLAIN SELECT t1.a, t2.x, t2.y +FROM t1 CROSS JOIN t2 +ORDER BY t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t2.y ASC NULLS LAST, fetch=3 +02)--Cross Join: +03)----TableScan: t1 projection=[a] +04)----Sort: t2.y ASC NULLS LAST, fetch=3 +05)------TableScan: t2 projection=[x, y] + +# Verify correctness +query III +SELECT t1.a, t2.x, t2.y +FROM t1 CROSS JOIN t2 +ORDER BY t2.y ASC, t1.a ASC LIMIT 3; +---- +1 1 100 +2 1 100 +3 1 100 + +# CROSS JOIN: sort spans both sides → no pushdown +query TT +EXPLAIN SELECT t1.a, t1.b, t2.y +FROM t1 CROSS JOIN t2 +ORDER BY t1.b + t2.y ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b + t2.y ASC NULLS LAST, fetch=3 +02)--Cross Join: 03)----TableScan: t1 projection=[a, b] -04)----TableScan: t2 projection=[x] +04)----TableScan: t2 projection=[y] + +# Inner join with no equi-keys but a non-empty filter +# the filter can drop rows from either side, so pushing fetch=N may select +# rows that get filtered out while discarding rows that would have matched. +# Sort stays above the join, no pushdown to t1. 
+query TT +EXPLAIN SELECT t1.a, t1.b, t2.x +FROM t1 INNER JOIN t2 ON t1.b > t2.y +ORDER BY t1.b ASC LIMIT 3; +---- +logical_plan +01)Sort: t1.b ASC NULLS LAST, fetch=3 +02)--Projection: t1.a, t1.b, t2.x +03)----Inner Join: Filter: t1.b > t2.y +04)------TableScan: t1 projection=[a, b] +05)------TableScan: t2 projection=[x, y] ### ### Multi-level outer joins From 7b4cf06c45191fc11c2267bdb314861e4745a29f Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Wed, 20 May 2026 20:55:15 +0530 Subject: [PATCH 19/23] Handle volatile expr --- .../src/push_down_topk_through_join.rs | 18 ++++++++- .../push_down_topk_through_join.slt | 37 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_topk_through_join.rs index a2c63494ecebd..9eb2889603b8b 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_topk_through_join.rs @@ -271,7 +271,23 @@ impl OptimizerRule for PushDownTopKThroughJoin { // // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10) // Different exprs, insert Sort(b, fetch=5) above preserved child. - let new_preserved_child = if let LogicalPlan::Sort(child_sort) = inner_child { + // + // If `deep_resolved_exprs` became volatile while resolving through + // projections inside the preserved child (e.g. a `random() AS col` + // projection turns the column reference into `random()` itself), + // structural equality with an existing inner Sort is unsound: two + // syntactically identical `random()` expressions evaluate to + // different values. In that case we must not match against the + // inner Sort — fall back to inserting a new Sort above the + // preserved child using `resolved_sort_exprs`, which is guaranteed + // non-volatile (verified above). 
+ let deep_exprs_volatile = + deep_resolved_exprs.iter().any(|se| se.expr.is_volatile()); + let inner_sort = match inner_child { + LogicalPlan::Sort(s) if !deep_exprs_volatile => Some(s), + _ => None, + }; + let new_preserved_child = if let Some(child_sort) = inner_sort { let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); let child_fetch_tighter = match child_sort.fetch { Some(child_fetch) => child_fetch <= fetch, diff --git a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt index 59c570d1963f7..328d6f0b26f69 100644 --- a/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt +++ b/datafusion/sqllogictest/test_files/push_down_topk_through_join.slt @@ -1170,6 +1170,43 @@ logical_plan 08)--------------TableScan: t1 projection=[a] 09)--------TableScan: t2 projection=[x] +# Volatile expression inside the preserved child *and* an existing inner +# Sort whose expr is also `random()`: must NOT tighten the inner Sort. +# +# After resolving the outer ORDER BY (a non-volatile column reference) +# through the inner `random() AS rand_col` Projection, the deep-resolved +# sort expr becomes `random()`. The existing inner Sort below the +# Projection happens to also be on `random()` — but those are independent +# random() invocations producing different orderings. Treating them as +# "same expr" and tightening fetch=10 → fetch=2 would discard rows the +# outer ordering would have ranked high. +# +# Expected: fall back to inserting a new Sort(t1.rand_col, fetch=2) above +# the preserved-side SubqueryAlias; the inner Sort(random(), fetch=10) +# stays untouched. 
+query TT +EXPLAIN SELECT * FROM ( + SELECT t1.rand_col, t2.x + FROM ( + SELECT random() AS rand_col, a + FROM (SELECT a FROM t1 ORDER BY random() LIMIT 10) + ) t1 + LEFT JOIN t2 ON t1.a = t2.x +) sub +ORDER BY rand_col ASC LIMIT 2; +---- +logical_plan +01)Sort: sub.rand_col ASC NULLS LAST, fetch=2 +02)--SubqueryAlias: sub +03)----Projection: t1.rand_col, t2.x +04)------Left Join: t1.a = t2.x +05)--------Sort: t1.rand_col ASC NULLS LAST, fetch=2 +06)----------SubqueryAlias: t1 +07)------------Projection: random() AS rand_col, t1.a +08)--------------Sort: random() ASC NULLS LAST, fetch=10 +09)----------------TableScan: t1 projection=[a] +10)--------TableScan: t2 projection=[x] + ### ### Config reset ### From ca5b416943ad2028707237f06235501f384b49e5 Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Thu, 28 May 2026 19:06:50 +0530 Subject: [PATCH 20/23] Adds pushdown topk bench --- benchmarks/queries/push_down_topk/q1.sql | 8 ++++++++ benchmarks/queries/push_down_topk/q2.sql | 7 +++++++ benchmarks/queries/push_down_topk/q3.sql | 7 +++++++ benchmarks/queries/push_down_topk/q4.sql | 7 +++++++ benchmarks/queries/push_down_topk/q5.sql | 9 +++++++++ 5 files changed, 38 insertions(+) create mode 100644 benchmarks/queries/push_down_topk/q1.sql create mode 100644 benchmarks/queries/push_down_topk/q2.sql create mode 100644 benchmarks/queries/push_down_topk/q3.sql create mode 100644 benchmarks/queries/push_down_topk/q4.sql create mode 100644 benchmarks/queries/push_down_topk/q5.sql diff --git a/benchmarks/queries/push_down_topk/q1.sql b/benchmarks/queries/push_down_topk/q1.sql new file mode 100644 index 0000000000000..3a2b796f3d70a --- /dev/null +++ b/benchmarks/queries/push_down_topk/q1.sql @@ -0,0 +1,8 @@ +-- LEFT JOIN, ORDER BY column from preserved (left) side, small LIMIT. +-- Canonical case for push_down_topk_through_join: the Sort(fetch=10) is +-- duplicated below the join over the customer scan, so only the top 10 +-- rows (by c_acctbal) are joined against orders. 
+SELECT c_custkey, c_acctbal +FROM customer LEFT JOIN orders ON c_custkey = o_custkey +ORDER BY c_acctbal +LIMIT 10 \ No newline at end of file diff --git a/benchmarks/queries/push_down_topk/q2.sql b/benchmarks/queries/push_down_topk/q2.sql new file mode 100644 index 0000000000000..1675babcb93cb --- /dev/null +++ b/benchmarks/queries/push_down_topk/q2.sql @@ -0,0 +1,7 @@ +-- RIGHT JOIN, ORDER BY column from preserved (right) side. +-- Symmetric to q1: the Sort(fetch) is pushed below the join over the +-- orders scan (the right/preserved side). +SELECT o_orderkey, o_totalprice +FROM customer RIGHT JOIN orders ON c_custkey = o_custkey +ORDER BY o_totalprice +LIMIT 10 \ No newline at end of file diff --git a/benchmarks/queries/push_down_topk/q3.sql b/benchmarks/queries/push_down_topk/q3.sql new file mode 100644 index 0000000000000..4f53d8ca91ee8 --- /dev/null +++ b/benchmarks/queries/push_down_topk/q3.sql @@ -0,0 +1,7 @@ +-- LEFT JOIN, multi-column ORDER BY (both columns from preserved side). +-- All sort exprs must come from the preserved side for the rule to fire; +-- this query checks that multi-column sorts are still pushed. +SELECT c_custkey, c_acctbal, c_nationkey +FROM customer LEFT JOIN orders ON c_custkey = o_custkey +ORDER BY c_acctbal, c_nationkey +LIMIT 100 \ No newline at end of file diff --git a/benchmarks/queries/push_down_topk/q4.sql b/benchmarks/queries/push_down_topk/q4.sql new file mode 100644 index 0000000000000..c21604b678d50 --- /dev/null +++ b/benchmarks/queries/push_down_topk/q4.sql @@ -0,0 +1,7 @@ +-- CROSS JOIN, ORDER BY column from one side. +-- Cross joins preserve every row from both sides; the rule pushes the +-- Sort(fetch) below the join over the side referenced by ORDER BY. 
+SELECT c_custkey, c_acctbal +FROM customer CROSS JOIN nation +ORDER BY c_acctbal +LIMIT 10 \ No newline at end of file diff --git a/benchmarks/queries/push_down_topk/q5.sql b/benchmarks/queries/push_down_topk/q5.sql new file mode 100644 index 0000000000000..0db3a8b36ea50 --- /dev/null +++ b/benchmarks/queries/push_down_topk/q5.sql @@ -0,0 +1,9 @@ +-- Negative case: ORDER BY references the probe (non-preserved) side. +-- The rule MUST NOT fire here — orders is the right side of a LEFT JOIN +-- so it isn't preserved (rows can be NULL when there's no match), and +-- pushing a Sort with fetch onto orders would change semantics. +-- Included so the bench harness can verify the rule's selectivity. +SELECT c_custkey, o_totalprice +FROM customer LEFT JOIN orders ON c_custkey = o_custkey +ORDER BY o_totalprice +LIMIT 10 \ No newline at end of file From 216be5edfcbd4e045a9cc4540d21d7240a087afe Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Sat, 30 May 2026 00:27:54 +0530 Subject: [PATCH 21/23] Adds df bench --- benchmarks/src/bin/dfbench.rs | 5 +- benchmarks/src/lib.rs | 1 + benchmarks/src/push_down_topk.rs | 264 +++++++++++++++++++++++++++++++ 3 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 benchmarks/src/push_down_topk.rs diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 50dd99368b7f0..e660fda268f45 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -32,7 +32,8 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, dict, h2o, hj, imdb, nlj, smj, sort_tpch, tpcds, tpch, + cancellation, clickbench, dict, h2o, hj, imdb, nlj, push_down_topk, smj, sort_tpch, + tpcds, tpch, }; #[derive(Debug, Parser)] @@ -51,6 +52,7 @@ enum Options { HJ(hj::RunOpt), Imdb(imdb::RunOpt), Nlj(nlj::RunOpt), + PushDownTopk(push_down_topk::RunOpt), Smj(smj::RunOpt), 
SortPushdown(sort_pushdown::RunOpt), SortTpch(sort_tpch::RunOpt), @@ -72,6 +74,7 @@ pub async fn main() -> Result<()> { Options::HJ(opt) => opt.run().await, Options::Imdb(opt) => Box::pin(opt.run()).await, Options::Nlj(opt) => opt.run().await, + Options::PushDownTopk(opt) => opt.run().await, Options::Smj(opt) => opt.run().await, Options::SortPushdown(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index f41fd5ebed205..0148eab6a9b04 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -24,6 +24,7 @@ pub mod h2o; pub mod hj; pub mod imdb; pub mod nlj; +pub mod push_down_topk; pub mod smj; pub mod sort_pushdown; pub mod sort_tpch; diff --git a/benchmarks/src/push_down_topk.rs b/benchmarks/src/push_down_topk.rs new file mode 100644 index 0000000000000..568792e4f8550 --- /dev/null +++ b/benchmarks/src/push_down_topk.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark for `push_down_topk_through_join`. +//! +//! Runs SQL files from `queries/push_down_topk/` against TPC-H +//! `customer`, `orders`, and `nation`. Intended to be run on a branch +//! 
with the `push_down_topk_through_join` rule registered and +//! against a baseline that does not register the rule, with results +//! compared via `compare.py`. +//! +//! # Usage +//! +//! ```text +//! # Generate TPC-H SF=1 (one-time) +//! ./bench.sh data tpch +//! +//! # Run with rule registered (this branch) and write results +//! ./bench.sh run push_down_topk -o pr.json +//! +//! # Run again on a baseline (e.g. main, or this branch with rule +//! # registration reverted) and write results +//! ./bench.sh run push_down_topk -o baseline.json +//! +//! ./compare.py baseline.json pr.json +//! ``` + +use clap::Args; +use futures::StreamExt; +use std::path::PathBuf; +use std::sync::Arc; + +use datafusion::datasource::TableProvider; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::physical_plan::{ + display::DisplayableExecutionPlan, displayable, execute_stream, +}; +use datafusion::prelude::*; +use datafusion_common::DEFAULT_PARQUET_EXTENSION; +use datafusion_common::instant::Instant; + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult, print_memory_stats}; + +const PUSH_DOWN_TOPK_QUERY_DIR: &str = "queries/push_down_topk"; + +#[derive(Debug, Args)] +pub struct RunOpt { + #[command(flatten)] + common: CommonOpt, + + /// Query number (1-N). If unset, runs every query in the directory. + #[arg(short, long)] + pub query: Option, + + /// Path to TPC-H parquet directory (must contain `customer`, `orders`, + /// `nation` subdirectories). + #[arg(required = true, short = 'p', long = "path")] + path: PathBuf, + + /// Path to JSON benchmark result, comparable via `compare.py`. + #[arg(short = 'o', long = "output")] + output_path: Option, + + /// Path to directory containing query SQL files. 
+ /// Defaults to `queries/push_down_topk/` relative to current directory. + #[arg(long = "queries-path")] + queries_path: Option, +} + +impl RunOpt { + const TABLES: [&'static str; 3] = ["customer", "orders", "nation"]; + + fn queries_dir(&self) -> PathBuf { + self.queries_path + .clone() + .unwrap_or_else(|| PathBuf::from(PUSH_DOWN_TOPK_QUERY_DIR)) + } + + fn load_query(&self, query_id: usize) -> Result { + let path = self.queries_dir().join(format!("q{query_id}.sql")); + std::fs::read_to_string(&path).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Failed to read query file {}: {e}", + path.display() + )) + }) + } + + fn available_queries(&self) -> Vec { + let dir = self.queries_dir(); + let mut ids = Vec::new(); + if let Ok(entries) = std::fs::read_dir(&dir) { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if let Some(rest) = name.strip_prefix('q') + && let Some(num_str) = rest.strip_suffix(".sql") + && let Ok(id) = num_str.parse::() + { + ids.push(id); + } + } + } + ids.sort(); + ids + } + + pub async fn run(&self) -> Result<()> { + let mut benchmark_run = BenchmarkRun::new(); + + let query_ids = match self.query { + Some(query_id) => vec![query_id], + None => self.available_queries(), + }; + + for query_id in query_ids { + benchmark_run.start_new_case(&format!("{query_id}")); + + match self.benchmark_query(query_id).await { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query(&self, query_id: usize) -> Result> { + let sql = self.load_query(query_id)?; + + let config = self.common.config()?; + let rt = self.common.build_runtime()?; + let state = 
SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(rt) + .with_default_features() + .build(); + let ctx = SessionContext::from(state); + + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + let row_count = self.execute_query(&ctx, &sql).await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + print_memory_stats(); + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::TABLES { + let provider = self.get_table(ctx, table).await?; + ctx.register_table(table, provider)?; + } + Ok(()) + } + + async fn execute_query(&self, ctx: &SessionContext, sql: &str) -> Result { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + + let mut row_count = 0; + let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + } + + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + } + + Ok(row_count) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: 
&str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + let state = ctx.state(); + let table_path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let options = ListingOptions::new(format) + .with_file_extension(DEFAULT_PARQUET_EXTENSION) + .with_collect_stat(true); + let table_path = ListingTableUrl::parse(table_path)?; + let schema = options.infer_schema(&state, &table_path).await?; + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } +} From 28c72d82eeeb20708578065399cd9e89ea7b5d1c Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Mon, 1 Jun 2026 11:32:42 +0530 Subject: [PATCH 22/23] Merge push down topk with push down limit rule --- .../core/src/optimizer_rule_reference.md | 56 +- datafusion/optimizer/src/lib.rs | 1 - datafusion/optimizer/src/optimizer.rs | 2 - datafusion/optimizer/src/push_down_limit.rs | 244 +++---- .../topk_through_join.rs} | 605 +++++++----------- 5 files changed, 386 insertions(+), 522 deletions(-) rename datafusion/optimizer/src/{push_down_topk_through_join.rs => push_down_limit/topk_through_join.rs} (64%) diff --git a/datafusion/core/src/optimizer_rule_reference.md b/datafusion/core/src/optimizer_rule_reference.md index a754d777f2dfc..94789a067358c 100644 --- a/datafusion/core/src/optimizer_rule_reference.md +++ b/datafusion/core/src/optimizer_rule_reference.md @@ -35,34 +35,34 @@ Rule order matters. The default pipeline may change between releases. 
### Logical Optimizer Rules -| order | rule | summary | -| ----- | ----------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | -| 1 | `rewrite_set_comparison` | Rewrites `ANY` and `ALL` set-comparison subqueries into `EXISTS`-based boolean expressions with correct SQL NULL semantics. | -| 2 | `optimize_unions` | Flattens nested unions and removes unions with a single input. | -| 3 | `unions_to_filter` | Merges `UNION DISTINCT` branches that share the same source into a single filtered branch with a disjunctive predicate. | -| 4 | `simplify_expressions` | Constant-folds and simplifies expressions while preserving output names. | -| 5 | `replace_distinct_aggregate` | Rewrites `DISTINCT` and `DISTINCT ON` operators into aggregate-based plans that later rules can optimize further. | -| 6 | `eliminate_join` | Replaces keyless inner joins with a literal `false` filter by an empty relation. | -| 7 | `decorrelate_predicate_subquery` | Converts eligible `IN` and `EXISTS` predicate subqueries into semi or anti joins. | -| 8 | `scalar_subquery_to_join` | Rewrites eligible scalar subqueries into joins and adds schema-preserving projections. | -| 9 | `decorrelate_lateral_join` | Rewrites eligible lateral joins into regular joins. | -| 10 | `extract_equijoin_predicate` | Splits join filters into equijoin keys and residual predicates. | -| 11 | `eliminate_duplicated_expr` | Removes duplicate expressions from projections, aggregates, and similar operators. | -| 12 | `eliminate_filter` | Drops always-true filters and replaces always-false or NULL filters with empty relations. | -| 13 | `eliminate_cross_join` | Uses filter predicates to replace cross joins with inner joins when join keys can be found. | -| 14 | `eliminate_limit` | Removes no-op limits and simplifies trivial limit shapes. 
| -| 15 | `propagate_empty_relation` | Pushes empty-relation knowledge upward so operators fed by no rows collapse early. | -| 16 | `filter_null_join_keys` | Adds `IS NOT NULL` filters to nullable equijoin keys that can never match. | -| 17 | `eliminate_outer_join` | Rewrites outer joins to inner joins when later filters reject the NULL-extended rows. | -| 18 | `push_down_limit` | Moves literal limits closer to scans and unions and merges adjacent limits. | -| 19 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | -| 20 | `push_down_topk_through_join` | Pushes Sort with LIMIT through joins when sort columns come from the preserved side. | -| 21 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | -| 22 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | -| 23 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | -| 24 | `extract_leaf_expressions` | Pulls cheap leaf expressions closer to data sources so later pruning and filter rules can act earlier. | -| 25 | `push_down_leaf_projections` | Pushes the helper projections created by leaf extraction toward leaf inputs. | -| 26 | `optimize_projections` | Prunes unused columns and removes unnecessary logical projections. | +| order | rule | summary | +| ----------------------------------------------------------------------------------- | ----------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | +| 1 | `rewrite_set_comparison` | Rewrites `ANY` and `ALL` set-comparison subqueries into `EXISTS`-based boolean expressions with correct SQL NULL semantics. | +| 2 | `optimize_unions` | Flattens nested unions and removes unions with a single input. 
| +| 3 | `unions_to_filter` | Merges `UNION DISTINCT` branches that share the same source into a single filtered branch with a disjunctive predicate. | +| 4 | `simplify_expressions` | Constant-folds and simplifies expressions while preserving output names. | +| 5 | `replace_distinct_aggregate` | Rewrites `DISTINCT` and `DISTINCT ON` operators into aggregate-based plans that later rules can optimize further. | +| 6 | `eliminate_join` | Replaces keyless inner joins with a literal `false` filter by an empty relation. | +| 7 | `decorrelate_predicate_subquery` | Converts eligible `IN` and `EXISTS` predicate subqueries into semi or anti joins. | +| 8 | `scalar_subquery_to_join` | Rewrites eligible scalar subqueries into joins and adds schema-preserving projections. | +| 9 | `decorrelate_lateral_join` | Rewrites eligible lateral joins into regular joins. | +| 10 | `extract_equijoin_predicate` | Splits join filters into equijoin keys and residual predicates. | +| 11 | `eliminate_duplicated_expr` | Removes duplicate expressions from projections, aggregates, and similar operators. | +| 12 | `eliminate_filter` | Drops always-true filters and replaces always-false or NULL filters with empty relations. | +| 13 | `eliminate_cross_join` | Uses filter predicates to replace cross joins with inner joins when join keys can be found. | +| 14 | `eliminate_limit` | Removes no-op limits and simplifies trivial limit shapes. | +| 15 | `propagate_empty_relation` | Pushes empty-relation knowledge upward so operators fed by no rows collapse early. | +| 16 | `filter_null_join_keys` | Adds `IS NOT NULL` filters to nullable equijoin keys that can never match. | +| 17 | `eliminate_outer_join` | Rewrites outer joins to inner joins when later filters reject the NULL-extended rows. | +| 18 | `push_down_limit` | Moves literal limits closer to scans and unions and merges adjacent limits, and pushes `Sort(fetch=N)` onto a join's preserved-side child for LEFT/RIGHT/CROSS/MARK joins.
| +| 19 | `push_down_filter` | Moves filters as early as possible through filter-commutative operators. | +| 20 | `single_distinct_aggregation_to_group_by` | Rewrites single-column `DISTINCT` aggregations into two-stage `GROUP BY` plans. | +| 21 | `eliminate_group_by_constant` | Removes constant or functionally redundant expressions from `GROUP BY`. | +| 22 | `common_sub_expression_eliminate` | Computes repeated subexpressions once and reuses the result. | +| 23 | `extract_leaf_expressions` | Pulls cheap leaf expressions closer to data sources so later pruning and filter rules can act earlier. | +| 24 | `push_down_leaf_projections` | Pushes the helper projections created by leaf extraction toward leaf inputs. | +| 25 | `optimize_projections` | Prunes unused columns and removes unnecessary logical projections. | ### Physical Optimizer Rules diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 019dd582d65c5..fbe7ad2f4d327 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,7 +65,6 @@ pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; -pub mod push_down_topk_through_join; pub mod replace_distinct_aggregate; pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index dfec8be093932..a765d7f27a51e 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -60,7 +60,6 @@ use crate::plan_signature::LogicalPlanSignature; use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; -use crate::push_down_topk_through_join::PushDownTopKThroughJoin; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; @@ 
-309,7 +308,6 @@ impl Optimizer { // Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), - Arc::new(PushDownTopKThroughJoin::new()), Arc::new(SingleDistinctToGroupBy::new()), // The previous optimizations added expressions and projections, // that might benefit from the following rules diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 4a26cd5884f6b..34db732ecdca9 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -29,6 +29,9 @@ use datafusion_common::utils::combine_limit; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; use datafusion_expr::{FetchType, SkipType, lit}; +mod topk_through_join; +use topk_through_join::push_topk_through_join; + /// Optimization rule that tries to push down `LIMIT`. //. It will push down through projection, limits (taking the smaller limit) #[derive(Default, Debug)] @@ -47,146 +50,159 @@ impl OptimizerRule for PushDownLimit { true } - #[expect(clippy::only_used_in_recursion)] fn rewrite( &self, plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let LogicalPlan::Limit(mut limit) = plan else { - return Ok(Transformed::no(plan)); - }; + match plan { + LogicalPlan::Limit(limit) => rewrite_limit(limit, config), + LogicalPlan::Sort(s) if s.fetch.is_some() => { + push_topk_through_join(LogicalPlan::Sort(s)) + } + other => Ok(Transformed::no(other)), + } + } - // Currently only rewrite if skip and fetch are both literals - let SkipType::Literal(skip) = limit.get_skip_type()? else { + fn name(&self) -> &str { + "push_down_limit" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +/// Limit-side dispatch (split out from `rewrite` so that the top-level +/// match in `OptimizerRule::rewrite` reads as a parallel branch alongside +/// the Sort-with-fetch handler). 
+#[expect(clippy::only_used_in_recursion)] +fn rewrite_limit( + mut limit: Limit, + config: &dyn OptimizerConfig, +) -> Result> { + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + // Merge the Parent Limit and the Child Limit. + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - // Merge the Parent Limit and the Child Limit. - if let LogicalPlan::Limit(child) = limit.input.as_ref() { - let SkipType::Literal(child_skip) = child.get_skip_type()? else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); - }; - let FetchType::Literal(child_fetch) = child.get_fetch_type()? 
else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); - }; - - let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); - let plan = LogicalPlan::Limit(Limit { - skip: Some(Box::new(lit(skip as i64))), - fetch: fetch.map(|f| Box::new(lit(f as i64))), - input: Arc::clone(&child.input), - }); + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); + let new_limit = Limit { + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), + input: Arc::clone(&child.input), + }; - // recursively reapply the rule on the new plan - return self.rewrite(plan, config); - } + // recursively reapply the rule on the new limit + return rewrite_limit(new_limit, config); + } - // no fetch to push, so return the original plan - let Some(fetch) = fetch else { - return Ok(Transformed::no(LogicalPlan::Limit(limit))); - }; + // no fetch to push, so return the original plan + let Some(fetch) = fetch else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; - match Arc::unwrap_or_clone(limit.input) { - LogicalPlan::TableScan(mut scan) => { - let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; - let new_fetch = scan + match Arc::unwrap_or_clone(limit.input) { + LogicalPlan::TableScan(mut scan) => { + let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; + let new_fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); + if new_fetch == scan.fetch { + original_limit(skip, fetch, LogicalPlan::TableScan(scan)) + } else { + // push limit into the table scan itself + scan.fetch = scan .fetch .map(|x| min(x, rows_needed)) .or(Some(rows_needed)); - if new_fetch == scan.fetch { - original_limit(skip, fetch, LogicalPlan::TableScan(scan)) - } else { - // push limit into the table scan itself - scan.fetch = scan - .fetch - .map(|x| min(x, rows_needed)) - .or(Some(rows_needed)); - transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) - } - } - LogicalPlan::Union(mut union) 
=> { - // push limits to each input of the union - union.inputs = union - .inputs - .into_iter() - .map(|input| make_arc_limit(0, fetch + skip, input)) - .collect(); - transformed_limit(skip, fetch, LogicalPlan::Union(union)) + transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) } + } + LogicalPlan::Union(mut union) => { + // push limits to each input of the union + union.inputs = union + .inputs + .into_iter() + .map(|input| make_arc_limit(0, fetch + skip, input)) + .collect(); + transformed_limit(skip, fetch, LogicalPlan::Union(union)) + } + + LogicalPlan::Join(join) => { + Ok(push_down_join(join, fetch + skip).update_data(|join| { + make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + })) + } - LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) - .update_data(|join| { - make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) - })), - - LogicalPlan::Sort(mut sort) => { - let new_fetch = { - let sort_fetch = skip + fetch; - Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) - }; - if new_fetch == sort.fetch { - if skip > 0 { - original_limit(skip, fetch, LogicalPlan::Sort(sort)) - } else { - Ok(Transformed::yes(LogicalPlan::Sort(sort))) - } + LogicalPlan::Sort(mut sort) => { + let new_fetch = { + let sort_fetch = skip + fetch; + Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) + }; + if new_fetch == sort.fetch { + if skip > 0 { + original_limit(skip, fetch, LogicalPlan::Sort(sort)) } else { - sort.fetch = new_fetch; - limit.input = Arc::new(LogicalPlan::Sort(sort)); - Ok(Transformed::yes(LogicalPlan::Limit(limit))) + Ok(Transformed::yes(LogicalPlan::Sort(sort))) } + } else { + sort.fetch = new_fetch; + limit.input = Arc::new(LogicalPlan::Sort(sort)); + Ok(Transformed::yes(LogicalPlan::Limit(limit))) } - LogicalPlan::Projection(mut proj) => { - // commute - limit.input = Arc::clone(&proj.input); - let new_limit = LogicalPlan::Limit(limit); - proj.input = Arc::new(new_limit); - 
Ok(Transformed::yes(LogicalPlan::Projection(proj))) - } - LogicalPlan::SubqueryAlias(mut subquery_alias) => { - // commute - limit.input = Arc::clone(&subquery_alias.input); - let new_limit = LogicalPlan::Limit(limit); - subquery_alias.input = Arc::new(new_limit); - Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) - } - LogicalPlan::Extension(extension_plan) - if extension_plan.node.supports_limit_pushdown() => - { - let new_children = extension_plan - .node - .inputs() - .into_iter() - .map(|child| { - LogicalPlan::Limit(Limit { - skip: None, - fetch: Some(Box::new(lit((fetch + skip) as i64))), - input: Arc::new(child.clone()), - }) + } + LogicalPlan::Projection(mut proj) => { + // commute + limit.input = Arc::clone(&proj.input); + let new_limit = LogicalPlan::Limit(limit); + proj.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::Projection(proj))) + } + LogicalPlan::SubqueryAlias(mut subquery_alias) => { + // commute + limit.input = Arc::clone(&subquery_alias.input); + let new_limit = LogicalPlan::Limit(limit); + subquery_alias.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) + } + LogicalPlan::Extension(extension_plan) + if extension_plan.node.supports_limit_pushdown() => + { + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), + input: Arc::new(child.clone()), }) - .collect::>(); + }) + .collect::>(); - // Create a new extension node with updated inputs - let child_plan = LogicalPlan::Extension(extension_plan); - let new_extension = - child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + // Create a new extension node with updated inputs + let child_plan = LogicalPlan::Extension(extension_plan); + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - transformed_limit(skip, fetch, new_extension) - } - 
input => original_limit(skip, fetch, input), + transformed_limit(skip, fetch, new_extension) } - } - - fn name(&self) -> &str { - "push_down_limit" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) + input => original_limit(skip, fetch, input), } } diff --git a/datafusion/optimizer/src/push_down_topk_through_join.rs b/datafusion/optimizer/src/push_down_limit/topk_through_join.rs similarity index 64% rename from datafusion/optimizer/src/push_down_topk_through_join.rs rename to datafusion/optimizer/src/push_down_limit/topk_through_join.rs index 9eb2889603b8b..567ca5cab6f25 100644 --- a/datafusion/optimizer/src/push_down_topk_through_join.rs +++ b/datafusion/optimizer/src/push_down_limit/topk_through_join.rs @@ -15,56 +15,23 @@ // specific language governing permissions and limitations // under the License. -//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through joins -//! whose preserved side is known. +//! Sort(fetch) → Join pushdown — a sub-module of `push_down_limit`. //! -//! When a `Sort` with a fetch limit sits above such a join and all sort -//! expressions come from the **preserved** side, this rule inserts a copy -//! of the `Sort(fetch)` on that input to reduce the number of rows -//! entering the join. +//! When a `Sort` with a fetch limit (TopK) sits above a join whose +//! preserved side is known (LEFT / RIGHT / LeftMark / RightMark / CROSS) +//! and all sort expressions come from the preserved side, we insert a +//! copy of the `Sort(fetch)` onto that input to reduce rows entering +//! the join. The outer `Sort` is kept because a 1-to-many join can +//! produce more than N output rows from N preserved-side rows. //! -//! This is correct because: -//! - A LEFT JOIN preserves every left row (each appears at least once in the -//! output). The final top-N by left-side columns must come from the top-N -//! left rows. -//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side -//! columns. -//! 
- A CROSS JOIN preserves every row from both sides (Cartesian product). -//! The top-N by one side's columns must come from the top-N rows of that -//! side, since each surviving row is duplicated by the other side's row -//! count. -//! - LEFT MARK / RIGHT MARK joins emit exactly one record per row of the -//! marked side (with an extra mark column), so that side is fully -//! preserved and pushdown applies symmetrically to LEFT / RIGHT joins. -//! -//! The top-level sort is kept for correctness since a 1-to-many join can -//! produce more than N output rows from N input rows. -//! -//! ## Example -//! -//! Before: -//! ```text -//! Sort: t1.b ASC, fetch=3 -//! Left Join: t1.a = t2.c -//! Scan: t1 ← scans ALL rows -//! Scan: t2 -//! ``` -//! -//! After: -//! ```text -//! Sort: t1.b ASC, fetch=3 -//! Left Join: t1.a = t2.c -//! Sort: t1.b ASC, fetch=3 ← pushed down -//! Scan: t1 -//! Scan: t2 -//! ``` +//! Dispatched from `PushDownLimit::rewrite` when the plan node is +//! `LogicalPlan::Sort` with `fetch.is_some()`. +use std::collections::HashMap; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; - use crate::utils::{has_all_column_refs, schema_columns}; + use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, Result, internal_err}; use datafusion_expr::logical_plan::{ @@ -79,309 +46,235 @@ enum Side { Right, } -/// Optimization rule that pushes TopK (Sort with fetch) through joins -/// that have a known preserved side (LEFT / RIGHT outer, -/// LEFT MARK / RIGHT MARK, or CROSS) when all sort expressions come -/// from a preserved side. -/// -/// See module-level documentation for details. -#[derive(Default, Debug)] -pub struct PushDownTopKThroughJoin; - -impl PushDownTopKThroughJoin { - /// Create a new `PushDownTopKThroughJoin` rule. - pub fn new() -> Self { - Self {} +/// Top-level pushdown for `Sort(fetch) → ... → Join` patterns. 
The plan +/// passed in is guaranteed by the caller to be `LogicalPlan::Sort` with +/// `fetch.is_some()`; we re-bind to a borrow inside. +pub(super) fn push_topk_through_join( + plan: LogicalPlan, +) -> Result> { + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch else { + return Ok(Transformed::no(plan)); + }; + + // Don't push if any sort expression is non-deterministic (e.g. + // `random()`). Duplicating such expressions would produce different + // values at each evaluation point, potentially changing results. + if sort.expr.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); } -} - -impl OptimizerRule for PushDownTopKThroughJoin { - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - // Match Sort with fetch (TopK) - let LogicalPlan::Sort(sort) = &plan else { - return Ok(Transformed::no(plan)); - }; - let Some(fetch) = sort.fetch else { - return Ok(Transformed::no(plan)); - }; - - // Don't push if any sort expression is non-deterministic (e.g. random()). - // Duplicating such expressions would produce different values at each - // evaluation point, potentially changing the result. - if sort.expr.iter().any(|se| se.expr.is_volatile()) { - return Ok(Transformed::no(plan)); - } - // Peel through transparent nodes (SubqueryAlias, Projection) to find - // the Join. Track intermediate nodes so we can reconstruct the tree - // and resolve sort expressions through them. 
- let mut current = sort.input.as_ref(); - let mut intermediates: Vec<&LogicalPlan> = Vec::new(); - let join = loop { - match current { - LogicalPlan::Join(join) => break join, - LogicalPlan::Projection(proj) => { - intermediates.push(current); - current = proj.input.as_ref(); - } - LogicalPlan::SubqueryAlias(sq) => { - intermediates.push(current); - current = sq.input.as_ref(); - } - _ => return Ok(Transformed::no(plan)), + // Peel through transparent nodes (SubqueryAlias, Projection) to + // find the Join. Track intermediates so we can reconstruct the tree + // and resolve sort expressions through them. + let mut current = sort.input.as_ref(); + let mut intermediates: Vec<&LogicalPlan> = Vec::new(); + let join = loop { + match current { + LogicalPlan::Join(join) => break join, + LogicalPlan::Projection(proj) => { + intermediates.push(current); + current = proj.input.as_ref(); } - }; - - // Determine which side(s) of the join are preserved. - // - // - LEFT / LeftMark: only left preserved (and only left appears in - // the output schema for LEFT, or left + mark column for LeftMark). - // - RIGHT / RightMark: symmetric. - // - CROSS JOIN (represented as Inner with no `on` keys and no filter): - // every row from both sides appears in the output (Cartesian - // product), so we can push to whichever side has all the sort cols. - // - // For LEFT/RIGHT, non-equijoin filters in the ON clause are safe: - // outer joins guarantee all preserved-side rows appear in the output - // regardless of the filter, and the non-preserved side never appears - // as a standalone unmatched row. - // - // For Inner joins (cross-join detection), the filter check is strict - // (`filter.is_none()`). When an Inner join has a filter, that filter - // can drop rows from either side, so pushing fetch=N may select rows - // that get filtered out while discarding rows that would have matched. 
- let preserved_candidates: &[Side] = match join.join_type { - JoinType::Left | JoinType::LeftMark => &[Side::Left], - JoinType::Right | JoinType::RightMark => &[Side::Right], - JoinType::Inner if join.on.is_empty() && join.filter.is_none() => { - &[Side::Left, Side::Right] + LogicalPlan::SubqueryAlias(sq) => { + intermediates.push(current); + current = sq.input.as_ref(); } _ => return Ok(Transformed::no(plan)), - }; - - // Resolve sort expressions through all intermediate nodes (Projection, - // SubqueryAlias) so that column references match the join's schema. - let mut resolved_sort_exprs = sort.expr.clone(); - for node in &intermediates { - match node { - LogicalPlan::Projection(proj) => { - resolved_sort_exprs = resolve_sort_exprs_through_projection( - &resolved_sort_exprs, - proj, - )?; - } - LogicalPlan::SubqueryAlias(sq) => { - resolved_sort_exprs = resolve_sort_exprs_through_subquery_alias( - &resolved_sort_exprs, - sq, - )?; - } - _ => { - return internal_err!( - "PushDownTopKThroughJoin: unexpected intermediate node: {}", - node.display() - ); - } - } } - - // After resolving through projections, the sort expressions may now - // contain volatile functions (e.g. `random() AS col`). Duplicating - // volatile expressions in the pushed Sort would produce different - // values, changing results. - if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { - return Ok(Transformed::no(plan)); + }; + + // Determine which side(s) of the join are preserved. + // + // - LEFT / LeftMark: only left preserved. + // - RIGHT / RightMark: symmetric. + // - CROSS JOIN (Inner with no `on` keys and no filter): + // every row from both sides appears in the output (Cartesian + // product), so we can push to whichever side has all the sort + // columns. + // + // For LEFT/RIGHT, non-equijoin filters in the ON clause are safe: + // outer joins guarantee all preserved-side rows appear in the + // output regardless of the filter. 
For Inner joins (cross-join + // detection), the filter check is strict (`filter.is_none()`) — + // any filter on Inner can drop rows from either side. + let preserved_candidates: &[Side] = match join.join_type { + JoinType::Left | JoinType::LeftMark => &[Side::Left], + JoinType::Right | JoinType::RightMark => &[Side::Right], + JoinType::Inner if join.on.is_empty() && join.filter.is_none() => { + &[Side::Left, Side::Right] } + _ => return Ok(Transformed::no(plan)), + }; + + // Resolve sort expressions through all intermediate nodes + // (Projection, SubqueryAlias) so column references match the + // join's schema. + let mut resolved_sort_exprs = sort.expr.clone(); + for node in &intermediates { + match node { + LogicalPlan::Projection(proj) => { + resolved_sort_exprs = + resolve_sort_exprs_through_projection(&resolved_sort_exprs, proj)?; + } + LogicalPlan::SubqueryAlias(sq) => { + resolved_sort_exprs = + resolve_sort_exprs_through_subquery_alias(&resolved_sort_exprs, sq)?; + } + _ => { + return internal_err!( + "push_topk_through_join: unexpected intermediate node: {}", + node.display() + ); + } + } + } - // Pick the first preserved-side candidate whose schema contains all - // referenced sort columns. For LEFT/RIGHT this is the fixed side; - // for CROSS we try both. - let Some(preserved_side) = preserved_candidates.iter().copied().find(|&side| { - let schema = match side { - Side::Left => join.left.schema(), - Side::Right => join.right.schema(), - }; - let cols = schema_columns(schema); - resolved_sort_exprs - .iter() - .all(|se| has_all_column_refs(&se.expr, &cols)) - }) else { - return Ok(Transformed::no(plan)); - }; + // After resolving through projections, sort expressions may now + // contain volatile functions (e.g. `random() AS col`). Duplicating + // them would change results. 
+ if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } - let preserved_child = match preserved_side { - Side::Left => &join.left, - Side::Right => &join.right, + // Pick the first preserved-side candidate whose schema contains all + // referenced sort columns. For LEFT/RIGHT this is the fixed side; + // for CROSS we try both. + let Some(preserved_side) = preserved_candidates.iter().copied().find(|&side| { + let schema = match side { + Side::Left => join.left.schema(), + Side::Right => join.right.schema(), }; - - // Scan deep inside the preserved child (through SubqueryAlias and - // Projection layers) to find an existing Sort. If found with same - // exprs, tighten its fetch in-place. Otherwise, insert a new Sort - // directly below the join as the preserved child's wrapper. - let mut inner_child = preserved_child.as_ref(); - let mut deep_resolved_exprs = resolved_sort_exprs.clone(); - loop { - match inner_child { - LogicalPlan::SubqueryAlias(sq) => { - deep_resolved_exprs = resolve_sort_exprs_through_subquery_alias( - &deep_resolved_exprs, - sq, - )?; - inner_child = sq.input.as_ref(); - } - LogicalPlan::Projection(proj) => { - deep_resolved_exprs = resolve_sort_exprs_through_projection( - &deep_resolved_exprs, - proj, - )?; - inner_child = proj.input.as_ref(); - } - _ => break, + let cols = schema_columns(schema); + resolved_sort_exprs + .iter() + .all(|se| has_all_column_refs(&se.expr, &cols)) + }) else { + return Ok(Transformed::no(plan)); + }; + + let preserved_child = match preserved_side { + Side::Left => &join.left, + Side::Right => &join.right, + }; + + // Scan deep inside the preserved child (through SubqueryAlias and + // Projection layers) to find an existing Sort. If found with same + // exprs, tighten its fetch in-place. Otherwise, insert a new Sort + // directly below the join as the preserved child's wrapper. 
+ let mut inner_child = preserved_child.as_ref(); + let mut deep_resolved_exprs = resolved_sort_exprs.clone(); + loop { + match inner_child { + LogicalPlan::SubqueryAlias(sq) => { + deep_resolved_exprs = + resolve_sort_exprs_through_subquery_alias(&deep_resolved_exprs, sq)?; + inner_child = sq.input.as_ref(); } + LogicalPlan::Projection(proj) => { + deep_resolved_exprs = + resolve_sort_exprs_through_projection(&deep_resolved_exprs, proj)?; + inner_child = proj.input.as_ref(); + } + _ => break, } + } - // If the inner child is a Limit (PushDownLimit hasn't merged it with - // the Sort yet), skip this iteration. PushDownLimit will merge - // Limit → Sort in the next pass, then our rule will tighten the Sort. - if matches!(inner_child, LogicalPlan::Limit(_)) { - return Ok(Transformed::no(plan)); - } + // If the inner child is a Limit (PushDownLimit's own Limit handling + // hasn't merged it with the Sort yet), skip this iteration. + if matches!(inner_child, LogicalPlan::Limit(_)) { + return Ok(Transformed::no(plan)); + } - // Determine action based on existing inner Sort: - // - Same exprs, tighter fetch → skip (already optimal) - // - Same exprs, larger/no fetch → tighten in-place - // - Different exprs or no Sort → insert new Sort below the join - // - // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) - // Child limits to 10, our tighter fetch=5 tightens it in-place. - // - // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC) - // Child has no fetch (full sort), tighten to fetch=5. - // - // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) - // Child already limits to 3 rows, pushing fetch=5 won't help. - // - // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10) - // Different exprs, insert Sort(b, fetch=5) above preserved child. - // - // If `deep_resolved_exprs` became volatile while resolving through - // projections inside the preserved child (e.g. 
a `random() AS col` - // projection turns the column reference into `random()` itself), - // structural equality with an existing inner Sort is unsound: two - // syntactically identical `random()` expressions evaluate to - // different values. In that case we must not match against the - // inner Sort — fall back to inserting a new Sort above the - // preserved child using `resolved_sort_exprs`, which is guaranteed - // non-volatile (verified above). - let deep_exprs_volatile = - deep_resolved_exprs.iter().any(|se| se.expr.is_volatile()); - let inner_sort = match inner_child { - LogicalPlan::Sort(s) if !deep_exprs_volatile => Some(s), - _ => None, + // Determine action based on existing inner Sort: + // - Same exprs, tighter fetch → skip (already optimal) + // - Same exprs, larger/no fetch → tighten in-place + // - Different exprs or no Sort → insert new Sort below the join + // + // If `deep_resolved_exprs` became volatile while resolving through + // projections inside the preserved child (e.g. `random() AS col`), + // structural equality with an existing inner Sort is unsound: two + // identical `random()` exprs evaluate to different values. Fall + // back to inserting a new Sort with `resolved_sort_exprs`. 
+ let deep_exprs_volatile = deep_resolved_exprs.iter().any(|se| se.expr.is_volatile()); + let inner_sort = match inner_child { + LogicalPlan::Sort(s) if !deep_exprs_volatile => Some(s), + _ => None, + }; + let new_preserved_child = if let Some(child_sort) = inner_sort { + let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); + let child_fetch_tighter = match child_sort.fetch { + Some(child_fetch) => child_fetch <= fetch, + None => false, }; - let new_preserved_child = if let Some(child_sort) = inner_sort { - let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); - let child_fetch_tighter = match child_sort.fetch { - Some(child_fetch) => child_fetch <= fetch, - None => false, - }; - if same_exprs && child_fetch_tighter { - return Ok(Transformed::no(plan)); - } - if same_exprs { - // Tighten existing Sort in-place by rebuilding the path - // from preserved child down to the Sort. - rebuild_with_tightened_sort( - preserved_child.as_ref(), - &deep_resolved_exprs, - fetch, - )? - } else { - // Different exprs — insert new Sort above the preserved child. - // If the inner Sort has no fetch, our pushed Sort is the only - // row reduction. If it has a fetch, re-sorting a small set is - // cheap and still reduces rows entering the join. - Arc::new(LogicalPlan::Sort(SortPlan { - expr: resolved_sort_exprs, - input: Arc::clone(preserved_child), - fetch: Some(fetch), - })) - } + if same_exprs && child_fetch_tighter { + return Ok(Transformed::no(plan)); + } + if same_exprs { + rebuild_with_tightened_sort( + preserved_child.as_ref(), + &deep_resolved_exprs, + fetch, + )? } else { - // No existing Sort — insert new Sort below the join. + // Different exprs — insert new Sort above the preserved + // child. If the inner Sort has no fetch, our pushed Sort + // is the only row reduction. If it has a fetch, re-sorting + // a small set is cheap and still reduces join input. 
Arc::new(LogicalPlan::Sort(SortPlan { expr: resolved_sort_exprs, input: Arc::clone(preserved_child), fetch: Some(fetch), })) - }; - - // Reconstruct the join with the new child - let mut new_join = join.clone(); - match preserved_side { - Side::Left => new_join.left = new_preserved_child, - Side::Right => new_join.right = new_preserved_child, } - - // Rebuild the tree: join → intermediate nodes → top-level sort - let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); - for node in intermediates.into_iter().rev() { - new_sort_input = Arc::new(match node { - LogicalPlan::Projection(proj) => { - let mut new_proj = proj.clone(); - new_proj.input = new_sort_input; - LogicalPlan::Projection(new_proj) - } - LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( - SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, - ), - _ => { - return internal_err!( - "PushDownTopKThroughJoin: unexpected intermediate node: {}", - node.display() - ); - } - }); - } - - Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { - expr: sort.expr.clone(), - input: new_sort_input, - fetch: sort.fetch, - }))) + } else { + Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })) + }; + + let mut new_join = join.clone(); + match preserved_side { + Side::Left => new_join.left = new_preserved_child, + Side::Right => new_join.right = new_preserved_child, } - fn name(&self) -> &str { - "push_down_topk_through_join" + // Rebuild the tree: join → intermediate nodes → top-level sort. 
+ let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); + for node in intermediates.into_iter().rev() { + new_sort_input = Arc::new(match node { + LogicalPlan::Projection(proj) => { + let mut new_proj = proj.clone(); + new_proj.input = new_sort_input; + LogicalPlan::Projection(new_proj) + } + LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, + ), + _ => { + return internal_err!( + "push_topk_through_join: unexpected intermediate node: {}", + node.display() + ); + } + }); } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } + Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: new_sort_input, + fetch: sort.fetch, + }))) } /// Replace column references in sort expressions using a name→expr map. -/// Uses `transform()` for deep replacement (handles nested expressions -/// like `-t1.b` where the column is inside a Negative wrapper). -/// -/// Example with `replace_map = {"sub.b" → Column(t1.b)}`: -/// -/// ```text -/// Input: [sub.b ASC] → Output: [t1.b ASC] (simple column) -/// Input: [(- sub.b) ASC] → Output: [(- t1.b) ASC] (nested column) -/// Input: [sub.a ASC, sub.b ASC] → Output: [t1.a ASC, t1.b ASC] (multiple) -/// ``` fn replace_columns_in_sort_exprs( sort_exprs: &[SortExpr], - replace_map: &std::collections::HashMap, + replace_map: &HashMap, ) -> Result> { sort_exprs .iter() @@ -403,18 +296,11 @@ fn replace_columns_in_sort_exprs( /// Resolve sort expressions through a projection by replacing column /// references with the underlying projection expressions. 
-/// -/// Example: sort expr is `neg_b ASC` and projection has `-t1.b AS neg_b`: -/// -/// ```text -/// Input sort exprs: [neg_b ASC] -/// Output sort exprs: [(- t1.b) ASC] -/// ``` fn resolve_sort_exprs_through_projection( sort_exprs: &[SortExpr], projection: &Projection, ) -> Result> { - let replace_map: std::collections::HashMap = projection + let replace_map: HashMap = projection .schema .iter() .zip(projection.expr.iter()) @@ -427,10 +313,7 @@ fn resolve_sort_exprs_through_projection( replace_columns_in_sort_exprs(sort_exprs, &replace_map) } -/// Compare two slices of `SortExpr` for equality. -/// -/// Uses structural equality on the sort expressions (direction, nulls_first, -/// and the expression tree). +/// Compare two slices of `SortExpr` for structural equality. fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { a.len() == b.len() && a.iter().zip(b.iter()).all(|(left, right)| { @@ -440,20 +323,13 @@ fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { }) } -/// Resolve sort expressions through a SubqueryAlias by replacing the alias -/// qualifier with the input schema's qualifier. -/// -/// Example: SubqueryAlias is `sub` wrapping a join whose left input is `t1`: -/// -/// ```text -/// Input sort exprs: [sub.b ASC] -/// Output sort exprs: [t1.b ASC] -/// ``` +/// Resolve sort expressions through a `SubqueryAlias` by replacing the +/// alias qualifier with the input schema's qualifier. fn resolve_sort_exprs_through_subquery_alias( sort_exprs: &[SortExpr], subquery_alias: &SubqueryAlias, ) -> Result> { - let replace_map: std::collections::HashMap = subquery_alias + let replace_map: HashMap = subquery_alias .schema .iter() .zip(subquery_alias.input.schema().iter()) @@ -468,24 +344,7 @@ fn resolve_sort_exprs_through_subquery_alias( } /// Rebuild the tree from `root` down to an existing Sort whose expressions -/// match `target_exprs`, tightening its fetch to `new_fetch`. 
The path from -/// `root` to the target Sort may contain Projections and SubqueryAliases. -/// -/// Before (new_fetch=2): -/// ```text -/// SubqueryAlias(t1) -/// Projection(a, b AS renamed_b) -/// Sort(t1.b ASC, fetch=10) ← target, fetch too large -/// TableScan: t1 -/// ``` -/// -/// After: -/// ```text -/// SubqueryAlias(t1) ← rebuilt -/// Projection(a, b AS renamed_b) ← rebuilt -/// Sort(t1.b ASC, fetch=2) ← tightened -/// TableScan: t1 -/// ``` +/// match `target_exprs`, tightening its fetch to `new_fetch`. fn rebuild_with_tightened_sort( root: &LogicalPlan, target_exprs: &[SortExpr], @@ -528,6 +387,7 @@ mod test { use super::*; use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; + use crate::push_down_limit::PushDownLimit; use crate::test::*; use datafusion_expr::col; @@ -539,7 +399,8 @@ mod test { @ $expected:literal $(,)? ) => {{ let optimizer_ctx = OptimizerContext::new().with_max_passes(1); - let rules: Vec> = vec![Arc::new(PushDownTopKThroughJoin::new())]; + let rules: Vec> = + vec![Arc::new(PushDownLimit::new())]; assert_optimized_plan_eq_snapshot!( optimizer_ctx, rules, @@ -733,9 +594,8 @@ mod test { ) } - /// Inner join with no equi-keys but a non-empty filter: the filter can - /// drop rows from either side, so pushing fetch=N can produce fewer - /// output rows than the unpushed plan. + /// Inner join with no equi-keys but a non-empty filter: filter can drop + /// rows from either side, so pushing fetch=N is unsafe. #[test] fn topk_not_pushed_for_inner_with_filter_no_on() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -761,8 +621,7 @@ mod test { ) } - /// LEFT MARK join: one record per left row (with extra mark column), - /// so left is fully preserved → pushdown to left. + /// LEFT MARK join: one record per left row → pushdown to left. 
#[test] fn topk_pushed_to_left_of_left_mark_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -845,9 +704,7 @@ mod test { ) } - /// Join with a non-equijoin filter → pushdown still happens. - /// Outer joins preserve all rows from the preserved side regardless - /// of the ON filter. + /// Join with non-equijoin filter → pushdown still happens. #[test] fn topk_pushed_with_join_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -874,7 +731,7 @@ mod test { ) } - /// Sort without fetch (unbounded) → no pushdown. + /// Sort without fetch → no pushdown. #[test] fn topk_not_pushed_without_fetch() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -901,7 +758,7 @@ mod test { ) } - /// LEFT SEMI JOIN: pushing fetch is unsafe (not all left rows appear in output). + /// LEFT SEMI JOIN: not all left rows appear in output → no pushdown. #[test] fn topk_not_pushed_for_left_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -928,7 +785,7 @@ mod test { ) } - /// LEFT ANTI JOIN: pushing fetch is unsafe (not all left rows appear in output). + /// LEFT ANTI JOIN: not all left rows appear in output → no pushdown. #[test] fn topk_not_pushed_for_left_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -955,7 +812,7 @@ mod test { ) } - /// RIGHT SEMI JOIN: pushing fetch is unsafe (not all right rows appear in output). + /// RIGHT SEMI JOIN: not all right rows appear in output → no pushdown. #[test] fn topk_not_pushed_for_right_semi_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -982,7 +839,7 @@ mod test { ) } - /// RIGHT ANTI JOIN: pushing fetch is unsafe (not all right rows appear in output). + /// RIGHT ANTI JOIN: not all right rows appear in output → no pushdown. 
#[test] fn topk_not_pushed_for_right_anti_join() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -1045,7 +902,6 @@ mod test { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; - // Child already has Sort(b ASC, fetch=10); our outer Sort has fetch=3 (tighter). let t1_with_sort = LogicalPlanBuilder::from(t1) .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(10))? .build()?; @@ -1078,7 +934,6 @@ mod test { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; - // Child already has Sort(b ASC, fetch=2); our outer Sort has fetch=5 (looser). let t1_with_sort = LogicalPlanBuilder::from(t1) .sort_with_limit(vec![col("t1.b").sort(true, false)], Some(2))? .build()?; @@ -1123,7 +978,7 @@ mod test { Ok(()) } - /// Projection alias: sort expr references an alias that maps to a negation. + /// Projection alias: sort expr references an alias mapping to a negation. #[test] fn resolve_through_projection_alias() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -1263,8 +1118,6 @@ mod test { } /// Inner Sort has different exprs WITH fetch → stacked sorts. - /// Sort(b, fetch=2) is inserted above Sort(a, fetch=5). Re-sorting - /// 5 rows is cheap and reduces join input from 5 to 2. #[test] fn topk_stacked_when_child_has_different_exprs_with_fetch() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -1298,8 +1151,6 @@ mod test { } /// Inner Sort has different exprs WITHOUT fetch → stacked sorts. - /// Full sort doesn't limit rows, so pushed Sort(fetch=2) is the - /// only row reduction before the join. 
#[test] fn topk_stacked_when_child_has_different_exprs_no_fetch() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; From 44596bcbe1bac7e982d676ba6c64e1511bddbbfb Mon Sep 17 00:00:00 2001 From: SubhamSinghal Date: Mon, 1 Jun 2026 18:36:51 +0530 Subject: [PATCH 23/23] Fix build failure --- datafusion/sqllogictest/test_files/explain.slt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index e40ebcda1741e..67d2c1e7b516e 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -195,7 +195,6 @@ logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -221,7 +220,6 @@ logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -574,7 +572,6 @@ logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join 
SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE @@ -600,7 +597,6 @@ logical_plan after filter_null_join_keys SAME TEXT AS ABOVE logical_plan after eliminate_outer_join SAME TEXT AS ABOVE logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE -logical_plan after push_down_topk_through_join SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE