diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 330116e592482..f83152225933b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1150,14 +1150,22 @@ class CodegenContext extends Logging { * evaluation, we can look for generated subexpressions and do replacement. */ def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { - // Create a clear EquivalentExpressions and SubExprEliminationState mapping + // Create a clear EquivalentExpressions and compute the common subexpressions. val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + expressions.foreach(equivalentExpressions.addExprTree(_)) + subexpressionEliminationForWholeStageCodegen(equivalentExpressions) + } + + /** + * Same as above, but takes a pre-built [[EquivalentExpressions]]. A caller that has already + * analyzed the expressions (e.g. to decide whether any common subexpression exists) can reuse + * that analysis here instead of rebuilding it. + */ + def subexpressionEliminationForWholeStageCodegen( + equivalentExpressions: EquivalentExpressions): SubExprCodes = { val localSubExprEliminationExprsForNonSplit = mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. val commonExprs = equivalentExpressions.getCommonSubexpressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 92cf3f59d575b..8d183f915e8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -242,6 +242,22 @@ case class FilterExec(condition: Expression, child: SparkPlan) // The columns that will filtered out by `IsNotNull` could be considered as not nullable. private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) + // `otherPreds` bound against this operator's `output`, shared between the CSE gate in + // `doConsume` and the CSE codegen itself. Codegen-only derived state, so `@transient`: it is + // computed on the driver during code generation and never accessed on executors. + @transient private lazy val boundOtherPreds: Seq[Expression] = + otherPreds.map(BindReferences.bindReference(_, output)) + + // CSE analysis of `boundOtherPreds`, built once and reused. `doConsume` consults it to decide + // whether any common subexpression is worth eliminating; when one is, the same analysis is + // handed to `subexpressionEliminationForWholeStageCodegen` rather than rebuilt. `@transient` + // because `EquivalentExpressions` is not serializable (and this is driver-only codegen state). + @transient private lazy val otherPredsEquivalentExpressions: EquivalentExpressions = { + val equivalentExpressions = new EquivalentExpressions + boundOtherPreds.foreach(equivalentExpressions.addExprTree(_)) + equivalentExpressions + } + // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate // all the variables at the beginning to take advantage of short circuiting. override def usedInputs: AttributeSet = AttributeSet.empty @@ -291,8 +307,17 @@ case class FilterExec(condition: Expression, child: SparkPlan) // without consulting `isNull_X`. The (b) interleaving gives us that ordering // for free, since the IsNotNull check fires before the CSE precompute keyed // off the same reference. + // Only take the CSE path when there is actually a common subexpression to eliminate. That + // path emits the `inputVarsEvalCode` prologue below, which eagerly evaluates every + // `otherPreds` input column at the top of the row loop -- required so eliminated + // subexpressions can be materialized into shared variables, but it defeats the + // short-circuiting the non-CSE path gets from loading columns lazily, just before the + // predicate that needs them. With no common subexpression the prologue is pure overhead + // (e.g. decoding a decimal column for rows a cheaper earlier predicate would reject), so we + // fall back to `generatePredicateCode`. val (prologueCode, predicateCode) = - if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty) { + if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty && + otherPredsEquivalentExpressions.getCommonSubexpressions.nonEmpty) { // Pre-evaluate input variables before CSE analysis: CSE clears // ctx.currentVars[i].code as a side effect; without this pre-evaluation, Janino // fails when otherPreds reference the same input columns that CSE already @@ -301,8 +326,8 @@ case class FilterExec(condition: Expression, child: SparkPlan) val inputVarsEvalCode = evaluateRequiredVariables( child.output, input, otherPredInputAttrs) - val boundOtherPreds = otherPreds.map(BindReferences.bindReference(_, output)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundOtherPreds) + val subExprs = + ctx.subexpressionEliminationForWholeStageCodegen(otherPredsEquivalentExpressions) // Group CSE states by the index of the first otherPred that references them. // `evaluateSubExprEliminationState` recursively emits each state's children diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index e5b9e7016841e..8f0ec0ffd6f1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -1186,4 +1186,42 @@ class WholeStageCodegenSuite extends SharedSparkSession } } } + + test("SPARK-56032: FilterExec skips CSE codegen when there is no common subexpression") { + // When otherPreds share no common subexpression, the CSE codegen path provides no benefit + // but would still eagerly evaluate every referenced input column at the top of the row loop + // (the inputVarsEvalCode prologue), defeating the lazy, short-circuiting column loads of the + // non-CSE path. Verify that with CSE enabled we fall back to the exact same generated code as + // with CSE disabled, so no column is decoded for rows an earlier predicate would reject. + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true))) + val data = spark.sparkContext.parallelize(Seq( + Row(1, 5), Row(null, 3), Row(4, null), Row(5, 6), Row(7, 8), Row(2, 3))) + val expected = Seq(Row(5, 6), Row(7, 8)) + + def filterCode(cseEnabled: Boolean): String = { + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = spark.createDataFrame(data, schema) + // `a > 4` and `b > 4` reference different columns and share no subexpression. + val filtered = df.where("a IS NOT NULL AND a > 4 AND b > 4") + val plan = filtered.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]), + "Filter should be in whole-stage codegen") + checkAnswer(filtered, expected) + codegenString(plan) + } + } + + // Each `createDataFrame` mints fresh attribute exprIds (e.g. `a#16`), which appear in the + // plan-tree header of the codegen dump but not in the generated Java. Normalize them away so + // the comparison reflects the generated code, not the id counter. + def normalize(code: String): String = code.replaceAll("#\\d+", "#") + assert(normalize(filterCode(cseEnabled = true)) == normalize(filterCode(cseEnabled = false)), + "With no common subexpression, CSE-enabled FilterExec codegen should be identical to " + + "CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)") + } }