Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1150,14 +1150,22 @@ class CodegenContext extends Logging {
* evaluation, we can look for generated subexpressions and do replacement.
*/
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
  // Start from an empty EquivalentExpressions, register every expression tree in it, and hand
  // the resulting analysis to the overload that accepts a pre-built EquivalentExpressions.
  val analysis = new EquivalentExpressions
  for (expr <- expressions) analysis.addExprTree(expr)
  subexpressionEliminationForWholeStageCodegen(analysis)
}

/**
* Same as above, but takes a pre-built [[EquivalentExpressions]]. A caller that has already
* analyzed the expressions (e.g. to decide whether any common subexpression exists) can reuse
* that analysis here instead of rebuilding it.
*/
def subexpressionEliminationForWholeStageCodegen(
equivalentExpressions: EquivalentExpressions): SubExprCodes = {
val localSubExprEliminationExprsForNonSplit =
mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getCommonSubexpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,22 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// The columns that will be filtered out by `IsNotNull` can be considered not nullable.
// Holds the distinct `exprId`s of every attribute referenced by `notNullPreds`.
private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)

// Every predicate of `otherPreds`, bound against this operator's `output`. The bound form is
// shared by the CSE gate in `doConsume` and by the CSE codegen itself. This is derived,
// codegen-only state: it is computed on the driver during code generation and never accessed
// on executors, hence `@transient`.
@transient private lazy val boundOtherPreds: Seq[Expression] =
  otherPreds.map(pred => BindReferences.bindReference(pred, output))

// Common-subexpression analysis over `boundOtherPreds`, computed once and then reused.
// `doConsume` first asks it whether there is any common subexpression worth eliminating; if
// there is, this same instance is passed to `subexpressionEliminationForWholeStageCodegen`
// instead of building a second one. Marked `@transient` because `EquivalentExpressions` is not
// serializable (and this is driver-only codegen state).
@transient private lazy val otherPredsEquivalentExpressions: EquivalentExpressions = {
  val analysis = new EquivalentExpressions
  for (pred <- boundOtherPreds) analysis.addExprTree(pred)
  analysis
}

// Deliberately report no used inputs. The input is evaluated during doConsume() instead: we
// don't want to evaluate all the variables at the beginning, so that we can take advantage of
// short circuiting.
override def usedInputs: AttributeSet = AttributeSet.empty
Expand Down Expand Up @@ -291,8 +307,17 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// without consulting `isNull_X`. The (b) interleaving gives us that ordering
// for free, since the IsNotNull check fires before the CSE precompute keyed
// off the same reference.
// Only take the CSE path when there is actually a common subexpression to eliminate. That
// path emits the `inputVarsEvalCode` prologue below, which eagerly evaluates every
// `otherPreds` input column at the top of the row loop -- required so eliminated
// subexpressions can be materialized into shared variables, but it defeats the
// short-circuiting the non-CSE path gets from loading columns lazily, just before the
// predicate that needs them. With no common subexpression the prologue is pure overhead
// (e.g. decoding a decimal column for rows a cheaper earlier predicate would reject), so we
// fall back to `generatePredicateCode`.
val (prologueCode, predicateCode) =
if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty) {
if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty &&
otherPredsEquivalentExpressions.getCommonSubexpressions.nonEmpty) {
// Pre-evaluate input variables before CSE analysis: CSE clears
// ctx.currentVars[i].code as a side effect; without this pre-evaluation, Janino
// fails when otherPreds reference the same input columns that CSE already
Expand All @@ -301,8 +326,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
val inputVarsEvalCode = evaluateRequiredVariables(
child.output, input, otherPredInputAttrs)

val boundOtherPreds = otherPreds.map(BindReferences.bindReference(_, output))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundOtherPreds)
val subExprs =
ctx.subexpressionEliminationForWholeStageCodegen(otherPredsEquivalentExpressions)

// Group CSE states by the index of the first otherPred that references them.
// `evaluateSubExprEliminationState` recursively emits each state's children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1186,4 +1186,42 @@ class WholeStageCodegenSuite extends SharedSparkSession
}
}
}

test("SPARK-56032: FilterExec skips CSE codegen when there is no common subexpression") {
  // If the predicates in otherPreds have no subexpression in common, the CSE codegen path buys
  // nothing, yet its inputVarsEvalCode prologue would still evaluate every referenced input
  // column eagerly at the top of the row loop. That defeats the lazy, short-circuiting column
  // loads of the non-CSE path. This test checks that enabling CSE produces exactly the same
  // generated code as disabling it, i.e. no column is decoded for rows that an earlier predicate
  // would already reject.
  val twoNullableInts = StructType(Seq(
    StructField("a", IntegerType, nullable = true),
    StructField("b", IntegerType, nullable = true)))
  val inputRows = spark.sparkContext.parallelize(Seq(
    Row(1, 5), Row(null, 3), Row(4, null), Row(5, 6), Row(7, 8), Row(2, 3)))
  val survivingRows = Seq(Row(5, 6), Row(7, 8))

  // Plans and runs the filter under the given CSE setting and returns its codegen dump.
  def codegenDump(cseEnabled: Boolean): String = {
    withSQLConf(
      SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString,
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
      // `a > 4` and `b > 4` touch different columns, so they share no subexpression.
      val result = spark.createDataFrame(inputRows, twoNullableInts)
        .where("a IS NOT NULL AND a > 4 AND b > 4")
      val physicalPlan = result.queryExecution.executedPlan
      assert(physicalPlan.exists(_.isInstanceOf[WholeStageCodegenExec]),
        "Filter should be in whole-stage codegen")
      checkAnswer(result, survivingRows)
      codegenString(physicalPlan)
    }
  }

  // Every `createDataFrame` call allocates new attribute exprIds (e.g. `a#16`). They show up in
  // the plan-tree header of the codegen dump, not in the generated Java, so strip the numbers to
  // compare the generated code rather than the state of the id counter.
  def stripExprIds(dump: String): String = dump.replaceAll("#\\d+", "#")

  val dumpWithCse = stripExprIds(codegenDump(cseEnabled = true))
  val dumpWithoutCse = stripExprIds(codegenDump(cseEnabled = false))
  assert(dumpWithCse == dumpWithoutCse,
    "With no common subexpression, CSE-enabled FilterExec codegen should be identical to " +
      "CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)")
}
}