diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java index 3d7c5dccc7f2b..07a7f98ea1b8b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java @@ -248,4 +248,39 @@ public static int binarySearch(Double[] data, Double value) { public static int binarySearch(Object[] data, Object value, Comparator comp) { return Arrays.binarySearch(data, value, comp); } + + // ----- slice(array, start, length) index resolution ----- + // 1-based -> 0-based index arithmetic, independent of the array element type. + // sliceStartIndex is shared by Slice's eval and codegen paths; sliceLength by codegen. + + /** + * Resolves the 0-based start index for {@code slice(array, start, length)}. + * SQL {@code slice} is 1-based; a negative {@code start} counts back from the end, and the + * result is itself negative when {@code -start > numElements}. A {@code start} of 0 throws. + */ + public static int sliceStartIndex(int start, int numElements, String functionName) { + if (start == 0) { + throw QueryExecutionErrors.unexpectedValueForStartInFunctionError(functionName); + } else if (start < 0) { + return start + numElements; + } else { + // SQL array positions are 1-based; shift to a 0-based index + return start - 1; + } + } + + /** + * Resolves the result length for {@code slice(array, start, length)} given the + * already-resolved {@code startIdx}, capping it at {@code numElements - startIdx}, the + * count of elements from {@code startIdx} to the end. A negative {@code length} throws. 
+ */ + public static int sliceLength(int length, int numElements, int startIdx, String functionName) { + if (length < 0) { + throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError(functionName, length); + } else if (length > numElements - startIdx) { + return numElements - startIdx; + } else { + return length; + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 85172f7957442..3346f23a70ad9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2052,13 +2052,7 @@ case class Slice(x: Expression, start: Expression, length: Expression) val startInt = startVal.asInstanceOf[Int] val lengthInt = lengthVal.asInstanceOf[Int] val arr = xVal.asInstanceOf[ArrayData] - val startIndex = if (startInt == 0) { - throw QueryExecutionErrors.unexpectedValueForStartInFunctionError(prettyName) - } else if (startInt < 0) { - startInt + arr.numElements() - } else { - startInt - 1 - } + val startIndex = ArrayExpressionUtils.sliceStartIndex(startInt, arr.numElements(), prettyName) if (lengthInt < 0) { throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError(prettyName, lengthInt) } @@ -2075,26 +2069,12 @@ case class Slice(x: Expression, start: Expression, length: Expression) nullSafeCodeGen(ctx, ev, (x, start, length) => { val startIdx = ctx.freshName("startIdx") val resLength = ctx.freshName("resLength") - val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + val utils = classOf[ArrayExpressionUtils].getName s""" - |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; - |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; - |if ($start == 0) { - | throw 
QueryExecutionErrors.unexpectedValueForStartInFunctionError("$prettyName"); - |} else if ($start < 0) { - | $startIdx = $start + $x.numElements(); - |} else { - | // arrays in SQL are 1-based instead of 0-based - | $startIdx = $start - 1; - |} - |if ($length < 0) { - | throw QueryExecutionErrors.unexpectedValueForLengthInFunctionError( - | "$prettyName", $length); - |} else if ($length > $x.numElements() - $startIdx) { - | $resLength = $x.numElements() - $startIdx; - |} else { - | $resLength = $length; - |} + |${CodeGenerator.JAVA_INT} $startIdx = + | $utils.sliceStartIndex($start, $x.numElements(), "$prettyName"); + |${CodeGenerator.JAVA_INT} $resLength = + | $utils.sliceLength($length, $x.numElements(), $startIdx, "$prettyName"); |${genCodeForResult(ctx, ev, x, startIdx, resLength)} """.stripMargin })