diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java index 0413278d0cb86..97fc2abb32551 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/DateTimeExpressionUtils.java @@ -19,7 +19,11 @@ import java.time.DateTimeException; import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneId; +import org.apache.spark.SparkDateTimeException; +import org.apache.spark.sql.catalyst.util.DateTimeConstants; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.catalyst.util.IntervalUtils; import org.apache.spark.sql.errors.QueryExecutionErrors; @@ -68,4 +72,67 @@ public static CalendarInterval makeIntervalExact( throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null); } } + + /** + * Builds the microsecond count for + * {@code MakeTimestamp(year, month, day, hour, min, secAndMicros[, timezone])}. + * {@code secAndMicros} carries the whole seconds plus the microsecond fraction + * (scale 6); a value of {@code 60} seconds with no fraction is accepted for + * PostgreSQL compatibility and rolls over to the next minute. When + * {@code timestampNTZ} is {@code true} the result is the local-time micros + * (no zone applied); otherwise {@code zoneId} is used to resolve the instant. + * + *
This is the shared, exception-raising core used by both the eval and + * codegen paths. It throws {@link SparkDateTimeException} for an invalid + * fraction-of-second and {@link DateTimeException} for an invalid + * year/month/day/hour/min combination; callers decide how to translate those. + */ + public static long makeTimestampMicros( + int year, int month, int day, int hour, int min, + Decimal secAndMicros, ZoneId zoneId, boolean timestampNTZ) { + assert secAndMicros.scale() == 6 : + "Seconds fraction must have 6 digits for microseconds but got " + secAndMicros.scale(); + // 8 digits cannot overflow Int. + int totalMicros = (int) secAndMicros.toUnscaledLong(); + int microsPerSecond = (int) DateTimeConstants.MICROS_PER_SECOND; + int nanosPerMicros = (int) DateTimeConstants.NANOS_PER_MICROS; + int seconds = Math.floorDiv(totalMicros, microsPerSecond); + int nanos = Math.floorMod(totalMicros, microsPerSecond) * nanosPerMicros; + LocalDateTime ldt; + if (seconds == 60) { + if (nanos == 0) { + // This case of sec = 60 and nanos = 0 is supported for compatibility with PostgreSQL. + ldt = LocalDateTime.of(year, month, day, hour, min, 0, 0).plusMinutes(1); + } else { + throw QueryExecutionErrors.invalidFractionOfSecondError(secAndMicros.toDouble()); + } + } else { + ldt = LocalDateTime.of(year, month, day, hour, min, seconds, nanos); + } + if (timestampNTZ) { + return DateTimeUtils.localDateTimeToMicros(ldt); + } else { + return DateTimeUtils.instantToMicros(ldt.atZone(zoneId).toInstant()); + } + } + + /** + * ANSI ({@code failOnError = true}) variant of {@link #makeTimestampMicros}: a + * {@link SparkDateTimeException} (e.g. an invalid fraction of second) is + * rethrown as-is to preserve its message, while any other + * {@link DateTimeException} is translated to {@code ansiDateTimeArgumentOutOfRange}. + * {@code SparkDateTimeException} is caught first because it is itself a + * {@link DateTimeException}. + */ + public static long makeTimestampExact( + int year, int month, int day, int hour, int min, + Decimal secAndMicros, ZoneId zoneId, boolean timestampNTZ) { + try { + return makeTimestampMicros(year, month, day, hour, min, secAndMicros, zoneId, timestampNTZ); + } catch (SparkDateTimeException e) { + throw e; + } catch (DateTimeException e) { + throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index a724f02cd107e..3274a268e158e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,7 @@ import java.util.Locale import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.{SparkDateTimeException, SparkException, SparkIllegalArgumentException} +import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} @@ -2937,33 +2937,17 @@ case class MakeTimestamp( min: Int, secAndMicros: Decimal, zoneId: ZoneId): Any = { - try { - assert(secAndMicros.scale == 6, - s"Seconds fraction must have 6 digits for microseconds but got ${secAndMicros.scale}") - val unscaledSecFrac = secAndMicros.toUnscaledLong - val totalMicros = unscaledSecFrac.toInt // 8 digits cannot overflow Int - val seconds = Math.floorDiv(totalMicros, MICROS_PER_SECOND.toInt) - val nanos = Math.floorMod(totalMicros, MICROS_PER_SECOND.toInt) * NANOS_PER_MICROS.toInt - val ldt = if (seconds == 60) { - if (nanos == 0) { - // This case of sec = 60 and nanos = 0 is supported for compatibility with PostgreSQL - LocalDateTime.of(year, month, day, hour, min, 0, 0).plusMinutes(1) - } else { - throw QueryExecutionErrors.invalidFractionOfSecondError(secAndMicros.toDouble) - } - } else { - LocalDateTime.of(year, month, day, hour, min, seconds, nanos) - } - if (dataType == TimestampType) { - instantToMicros(ldt.atZone(zoneId).toInstant) - } else { - localDateTimeToMicros(ldt) + val timestampNTZ = dataType != TimestampType + if (failOnError) { + DateTimeExpressionUtils.makeTimestampExact( + year, month, day, hour, min, secAndMicros, zoneId, timestampNTZ) + } else { + try { + DateTimeExpressionUtils.makeTimestampMicros( + year, month, day, hour, min, secAndMicros, zoneId, timestampNTZ) + } catch { + case _: DateTimeException => null } - } catch { - case e: SparkDateTimeException if failOnError => throw e - case e: DateTimeException if failOnError => - throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e) - case _: DateTimeException => null } } @@ -2990,47 +2974,23 @@ case class MakeTimestamp( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val utils = classOf[DateTimeExpressionUtils].getName val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val d = Decimal.getClass.getName.stripSuffix("$") - val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);" - } else { - s"${ev.isNull} = true;" - } - val failOnSparkErrorBranch = if (failOnError) "throw e;" else s"${ev.isNull} = true;" + val timestampNTZ = dataType != TimestampType nullSafeCodeGen(ctx, ev, (year, month, day, hour, min, secAndNanos, timezone) => { - val zoneId = timezone.map(tz => s"$dtu.getZoneId(${tz}.toString())").getOrElse(zid) - val toMicrosCode = if (dataType == TimestampType) { - s""" - |java.time.Instant instant = ldt.atZone($zoneId).toInstant(); - |${ev.value} = $dtu.instantToMicros(instant); - |""".stripMargin + val zoneIdExpr = timezone.map(tz => s"$dtu.getZoneId(${tz}.toString())").getOrElse(zid) + if (failOnError) { + s"${ev.value} = $utils.makeTimestampExact(" + + s"$year, $month, $day, $hour, $min, $secAndNanos, $zoneIdExpr, $timestampNTZ);" } else { - s"${ev.value} = $dtu.localDateTimeToMicros(ldt);" + s""" + try { + ${ev.value} = $utils.makeTimestampMicros( + $year, $month, $day, $hour, $min, $secAndNanos, $zoneIdExpr, $timestampNTZ); + } catch (java.time.DateTimeException e) { + ${ev.isNull} = true; + }""" } - s""" - try { - org.apache.spark.sql.types.Decimal secFloor = $secAndNanos.floor(); - org.apache.spark.sql.types.Decimal nanosPerSec = $d$$.MODULE$$.apply(1000000000L, 10, 0); - int nanos = (($secAndNanos.$$minus(secFloor)).$$times(nanosPerSec)).toInt(); - int seconds = secFloor.toInt(); - java.time.LocalDateTime ldt; - if (seconds == 60) { - if (nanos == 0) { - ldt = java.time.LocalDateTime.of( - $year, $month, $day, $hour, $min, 0, 0).plusMinutes(1); - } else { - throw QueryExecutionErrors.invalidFractionOfSecondError($secAndNanos.toDouble()); - } - } else { - ldt = java.time.LocalDateTime.of($year, $month, $day, $hour, $min, seconds, nanos); - } - $toMicrosCode - } catch (org.apache.spark.SparkDateTimeException e) { - $failOnSparkErrorBranch - } catch (java.time.DateTimeException e) { - $failOnErrorBranch - }""" }) }