From 1506e75c7fd17ab431abe8d997971660abf44819 Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Sat, 30 May 2026 11:00:45 +0000
Subject: [PATCH] [SPARK-57171][SQL] Simplify Slice codegen by extracting index arithmetic into a static Java helper

### What changes were proposed in this pull request?

Add `ArrayExpressionUtils.sliceStartIndex(int start, int numElements, String functionName)` and `sliceLength(int length, int numElements, int startIdx, String functionName)`, and route `Slice`'s codegen through them.

`Slice.doGenCode` previously emitted ~17 lines of inline, element-type-independent index arithmetic (1-based -> 0-based start resolution, the `start == 0` / `length < 0` validations, and the result-length clamp). It now emits two helper calls. The eval path reuses `sliceStartIndex` for the shared start resolution.

Unlike the existing SPARK-56908 sub-tasks, this is neither ANSI-specific nor a try/catch wrapper -- it is a plain, type-independent block of generated logic, which is exactly the kind of boilerplate the umbrella aims to deduplicate into static Java helpers.

### Why are the changes needed?

Part of SPARK-56908 (umbrella). Moving the fixed index arithmetic out of the generated Java shrinks the per-stage source for every plan that uses `slice`.

### Does this PR introduce _any_ user-facing change?

No. The compiled behavior is identical; only the emitted Java source text changes. The codegen path keeps its existing result-length clamp (`sliceLength`); the eval path keeps its existing `data.slice(...)` length handling unchanged.

### How was this patch tested?

```
build/sbt "catalyst/testOnly *CollectionExpressionsSuite"
```

59/59 pass, including `Slice` (exercised both with and without whole-stage codegen).

### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.8)
Co-authored-by: Isaac
---
 .../expressions/ArrayExpressionUtils.java     | 35 +++++++++++++++++++
 .../expressions/collectionOperations.scala    | 32 ++++-------------
 2 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
index 3d7c5dccc7f2b..07a7f98ea1b8b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
@@ -248,4 +248,39 @@ public static int binarySearch(Double[] data, Double value) {
   public static int binarySearch(Object[] data, Object value, Comparator comp) {
     return Arrays.binarySearch(data, value, comp);
   }
+
+  // ----- slice(array, start, length) index resolution -----
+  // Index arithmetic that does not depend on the array element type. sliceStartIndex is
+  // shared by Slice's eval and codegen paths; sliceLength is called from generated code only.
+
+  /**
+   * Resolves the 0-based start index for {@code slice(array, start, length)}. SQL
+   * {@code slice} is 1-based and a negative {@code start} counts back from the end; 0 is
+   * rejected. The result is not range-checked: it may be negative or >= numElements.
+   */
+  public static int sliceStartIndex(int start, int numElements, String functionName) {
+    if (start == 0) {
+      throw QueryExecutionErrors.unexpectedValueForStartInFunctionError(functionName);
+    } else if (start < 0) {
+      return start + numElements;
+    } else {
+      // arrays in SQL are 1-based instead of 0-based
+      return start - 1;
+    }
+  }
+
+  /**
+   * Resolves the result length for {@code slice(array, start, length)} given the
+   * already-resolved 0-based {@code startIdx}: {@code length} capped at
+   * {@code numElements - startIdx}. A negative {@code length} is rejected.
+   */
+  public static int sliceLength(int length, int numElements, int startIdx, String functionName) {
+    if (length < 0) {
+      throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError(functionName, length);
+    } else if (length > numElements - startIdx) {
+      return numElements - startIdx;
+    } else {
+      return length;
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 85172f7957442..3346f23a70ad9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -2052,13 +2052,7 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     val startInt = startVal.asInstanceOf[Int]
     val lengthInt = lengthVal.asInstanceOf[Int]
     val arr = xVal.asInstanceOf[ArrayData]
-    val startIndex = if (startInt == 0) {
-      throw QueryExecutionErrors.unexpectedValueForStartInFunctionError(prettyName)
-    } else if (startInt < 0) {
-      startInt + arr.numElements()
-    } else {
-      startInt - 1
-    }
+    val startIndex = ArrayExpressionUtils.sliceStartIndex(startInt, arr.numElements(), prettyName)
     if (lengthInt < 0) {
       throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError(prettyName, lengthInt)
     }
@@ -2075,26 +2069,12 @@ case class Slice(x: Expression, start: Expression, length: Expression)
     nullSafeCodeGen(ctx, ev, (x, start, length) => {
       val startIdx = ctx.freshName("startIdx")
       val resLength = ctx.freshName("resLength")
-      val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
+      val utils = classOf[ArrayExpressionUtils].getName
       s"""
-         |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
-         |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
-         |if ($start == 0) {
-         |  throw QueryExecutionErrors.unexpectedValueForStartInFunctionError("$prettyName");
-         |} else if ($start < 0) {
-         |  $startIdx = $start + $x.numElements();
-         |} else {
-         |  // arrays in SQL are 1-based instead of 0-based
-         |  $startIdx = $start - 1;
-         |}
-         |if ($length < 0) {
-         |  throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError(
-         |    "$prettyName", $length);
-         |} else if ($length > $x.numElements() - $startIdx) {
-         |  $resLength = $x.numElements() - $startIdx;
-         |} else {
-         |  $resLength = $length;
-         |}
+         |${CodeGenerator.JAVA_INT} $startIdx =
+         |  $utils.sliceStartIndex($start, $x.numElements(), "$prettyName");
+         |${CodeGenerator.JAVA_INT} $resLength =
+         |  $utils.sliceLength($length, $x.numElements(), $startIdx, "$prettyName");
          |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
        """.stripMargin
     })