diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java index a2e427b4a4ce2..424a52e7d6388 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; /** * Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths) @@ -112,4 +113,37 @@ public static Decimal changePrecisionExact( public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) { return d.changePrecision(precision, scale) ? d : null; } + + // ----- string -> floating point (ANSI: throw on invalid input) ----- + // Mirrors castToFloatCode / castToDoubleCode: parse the string, and on a + // NumberFormatException fall back to the special-literal forms handled by + // Cast.processFloatingPointSpecialLiterals (inf / +inf / -inf / infinity / nan, + // case-insensitive). If that also yields no value, throw the ANSI + // CAST_INVALID_INPUT error citing the original (untrimmed) input string. + + public static float stringToFloatExact(UTF8String s, QueryContext context) { + String str = s.toString(); + try { + return Float.parseFloat(str); + } catch (NumberFormatException e) { + Float f = (Float) Cast.processFloatingPointSpecialLiterals(str, true); + if (f == null) { + throw QueryExecutionErrors.invalidInputInCastToNumberError(FLOAT, s, context); + } + return f; + } + } + + public static double stringToDoubleExact(UTF8String s, QueryContext context) { + String str = s.toString(); + try { + return Double.parseDouble(str); + } catch (NumberFormatException e) { + Double d = (Double) Cast.processFloatingPointSpecialLiterals(str, false); + if (d == null) { + throw QueryExecutionErrors.invalidInputInCastToNumberError(DOUBLE, s, context); + } + return d; + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 66501ebe7d5c8..8b336fef453a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1170,16 +1170,14 @@ case class Cast( private[this] def castToDouble(from: DataType): Any => Any = from match { case _: StringType => buildCast[UTF8String](_, s => { - val doubleStr = s.toString - try doubleStr.toDouble catch { - case _: NumberFormatException => - val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false) - if (ansiEnabled && d == null) { - throw QueryExecutionErrors.invalidInputInCastToNumberError( - DoubleType, s, getContextOrNull()) - } else { - d - } + if (ansiEnabled) { + CastUtils.stringToDoubleExact(s, getContextOrNull()) + } else { + val doubleStr = s.toString + try doubleStr.toDouble catch { + case _: NumberFormatException => + Cast.processFloatingPointSpecialLiterals(doubleStr, false) + } } }) case BooleanType => @@ -1197,16 +1195,14 @@ case class Cast( private[this] def castToFloat(from: DataType): Any => Any = from match { case _: StringType => buildCast[UTF8String](_, s => { - val floatStr = s.toString - try floatStr.toFloat catch { - case _: NumberFormatException => - val f = Cast.processFloatingPointSpecialLiterals(floatStr, true) - if (ansiEnabled && f == null) { - throw QueryExecutionErrors.invalidInputInCastToNumberError( - FloatType, s, getContextOrNull()) - } else { - f - } + if (ansiEnabled) { + CastUtils.stringToFloatExact(s, getContextOrNull()) + } else { + val floatStr = s.toString + try floatStr.toFloat catch { + case _: NumberFormatException => + Cast.processFloatingPointSpecialLiterals(floatStr, true) + } } }) case BooleanType => @@ -2212,28 +2208,27 @@ case class Cast( private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case _: StringType => - val floatStr = ctx.freshVariable("floatStr", StringType) (c, evPrim, evNull) => - val handleNull = if (ansiEnabled) { + if (ansiEnabled) { + val castUtils = classOf[CastUtils].getName val errorContext = getContextOrNullCode(ctx) - "throw QueryExecutionErrors.invalidInputInCastToNumberError(" + - s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);" + code"$evPrim = $castUtils.stringToFloatExact($c, $errorContext);" } else { - s"$evNull = true;" - } - code""" - final String $floatStr = $c.toString(); - try { - $evPrim = Float.valueOf($floatStr); - } catch (java.lang.NumberFormatException e) { - final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); - if (f == null) { - $handleNull - } else { - $evPrim = f.floatValue(); + val floatStr = ctx.freshVariable("floatStr", StringType) + code""" + final String $floatStr = $c.toString(); + try { + $evPrim = Float.valueOf($floatStr); + } catch (java.lang.NumberFormatException e) { + final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true); + if (f == null) { + $evNull = true; + } else { + $evPrim = f.floatValue(); + } } + """ } - """ case BooleanType => (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;" case DateType => @@ -2250,28 +2245,27 @@ case class Cast( private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case _: StringType => - val doubleStr = ctx.freshVariable("doubleStr", StringType) (c, evPrim, evNull) => - val handleNull = if (ansiEnabled) { + if (ansiEnabled) { + val castUtils = classOf[CastUtils].getName val errorContext = getContextOrNullCode(ctx) - "throw QueryExecutionErrors.invalidInputInCastToNumberError(" + - s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);" + code"$evPrim = $castUtils.stringToDoubleExact($c, $errorContext);" } else { - s"$evNull = true;" - } - code""" - final String $doubleStr = $c.toString(); - try { - $evPrim = Double.valueOf($doubleStr); - } catch (java.lang.NumberFormatException e) { - final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); - if (d == null) { - $handleNull - } else { - $evPrim = d.doubleValue(); + val doubleStr = ctx.freshVariable("doubleStr", StringType) + code""" + final String $doubleStr = $c.toString(); + try { + $evPrim = Double.valueOf($doubleStr); + } catch (java.lang.NumberFormatException e) { + final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false); + if (d == null) { + $evNull = true; + } else { + $evPrim = d.doubleValue(); + } } + """ } - """ case BooleanType => (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;" case DateType =>