Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;

/**
* Static helpers used by {@code Cast.doGenCode} (and corresponding eval paths)
Expand Down Expand Up @@ -112,4 +113,37 @@ public static Decimal changePrecisionExact(
public static Decimal changePrecisionOrNull(Decimal d, int precision, int scale) {
return d.changePrecision(precision, scale) ? d : null;
}

// ----- string -> floating point (ANSI: throw on invalid input) -----
// Mirrors castToFloatCode / castToDoubleCode: parse the string, and on a
// NumberFormatException fall back to the special-literal forms handled by
// Cast.processFloatingPointSpecialLiterals (inf / +inf / -inf / infinity / nan,
// case-insensitive). If that also yields no value, throw the ANSI
// CAST_INVALID_INPUT error citing the original (untrimmed) input string.

public static float stringToFloatExact(UTF8String s, QueryContext context) {
String str = s.toString();
try {
return Float.parseFloat(str);
} catch (NumberFormatException e) {
Float f = (Float) Cast.processFloatingPointSpecialLiterals(str, true);
if (f == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(FLOAT, s, context);
}
return f;
}
}

public static double stringToDoubleExact(UTF8String s, QueryContext context) {
String str = s.toString();
try {
return Double.parseDouble(str);
} catch (NumberFormatException e) {
Double d = (Double) Cast.processFloatingPointSpecialLiterals(str, false);
if (d == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(DOUBLE, s, context);
}
return d;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1170,16 +1170,14 @@ case class Cast(
private[this] def castToDouble(from: DataType): Any => Any = from match {
case _: StringType =>
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
val d = Cast.processFloatingPointSpecialLiterals(doubleStr, false)
if (ansiEnabled && d == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(
DoubleType, s, getContextOrNull())
} else {
d
}
if (ansiEnabled) {
CastUtils.stringToDoubleExact(s, getContextOrNull())
} else {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
}
}
})
case BooleanType =>
Expand All @@ -1197,16 +1195,14 @@ case class Cast(
private[this] def castToFloat(from: DataType): Any => Any = from match {
case _: StringType =>
buildCast[UTF8String](_, s => {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
val f = Cast.processFloatingPointSpecialLiterals(floatStr, true)
if (ansiEnabled && f == null) {
throw QueryExecutionErrors.invalidInputInCastToNumberError(
FloatType, s, getContextOrNull())
} else {
f
}
if (ansiEnabled) {
CastUtils.stringToFloatExact(s, getContextOrNull())
} else {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(floatStr, true)
}
}
})
case BooleanType =>
Expand Down Expand Up @@ -2212,28 +2208,27 @@ case class Cast(
private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case _: StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
if (ansiEnabled) {
val castUtils = classOf[CastUtils].getName
val errorContext = getContextOrNullCode(ctx)
"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
code"$evPrim = $castUtils.stringToFloatExact($c, $errorContext);"
} else {
s"$evNull = true;"
}
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
$handleNull
} else {
$evPrim = f.floatValue();
val floatStr = ctx.freshVariable("floatStr", StringType)
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
$evNull = true;
} else {
$evPrim = f.floatValue();
}
}
"""
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
Expand All @@ -2250,28 +2245,27 @@ case class Cast(
private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case _: StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
val handleNull = if (ansiEnabled) {
if (ansiEnabled) {
val castUtils = classOf[CastUtils].getName
val errorContext = getContextOrNullCode(ctx)
"throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
code"$evPrim = $castUtils.stringToDoubleExact($c, $errorContext);"
} else {
s"$evNull = true;"
}
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
$handleNull
} else {
$evPrim = d.doubleValue();
val doubleStr = ctx.freshVariable("doubleStr", StringType)
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
$evNull = true;
} else {
$evPrim = d.doubleValue();
}
}
"""
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
Expand Down