Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -430,6 +430,11 @@ abstract class HashExpression[E] extends Expression {
s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);"
}

protected def genHashTimestampNanos(input: String, result: String): String = {
val epochMicrosHash = s"$hasherClassName.hashLong($input.epochMicros, $result)"
s"$result = $hasherClassName.hashInt($input.nanosWithinMicro, $epochMicrosHash);"
}

protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality) {
Expand Down Expand Up @@ -549,6 +554,8 @@ abstract class HashExpression[E] extends Expression {
case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
case LongType | _: TimeType => genHashLong(input, result)
case TimestampType | TimestampNTZType => genHashTimestamp(input, result)
case _: TimestampNTZNanosType | _: TimestampLTZNanosType =>
genHashTimestampNanos(input, result)
case FloatType => genHashFloat(input, result)
case DoubleType => genHashDouble(input, result)
case d: DecimalType => genHashDecimal(ctx, d, input, result)
Expand Down Expand Up @@ -636,6 +643,7 @@ abstract class InterpretedHashFunction {
hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
}
case c: CalendarInterval => hashInt(c.months, hashInt(c.days, hashLong(c.microseconds, seed)))
case t: TimestampNanosVal => hashInt(t.nanosWithinMicro, hashLong(t.epochMicros, seed))
case a: Array[Byte] =>
hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
case s: UTF8String =>
Expand Down Expand Up @@ -977,6 +985,12 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
$result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input);
"""

override protected def genHashTimestampNanos(input: String, result: String): String =
s"""
$result = (int)
${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestampNanos($input);
"""

override protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality || !isCollationAware) {
Expand Down Expand Up @@ -1144,6 +1158,17 @@ object HiveHashFunction extends InterpretedHashFunction {
((result >>> 32) ^ result).toInt
}

/**
* Extends [[hashTimestamp]] with the sub-microsecond nanoseconds carried by a
* [[TimestampNanosVal]], folding the extra field in with the same `* 37 + field` idiom used by
* [[hashCalendarInterval]]. Hive has no nanosecond-precision timestamp type, so this is a
* Spark-defined, self-consistent hash (equal values hash equally) rather than a Hive-compatible
* one.
*/
def hashTimestampNanos(t: TimestampNanosVal): Long = {
(hashTimestamp(t.epochMicros) * 37) + t.nanosWithinMicro
}

/**
* Hive allows input intervals to be defined using units below but the intervals
* have to be from the same category:
Expand Down Expand Up @@ -1242,6 +1267,7 @@ object HiveHashFunction extends InterpretedHashFunction {

case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode()
case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp)
case timestampNanos: TimestampNanosVal => hashTimestampNanos(timestampNanos)
case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval)
case _ => super.hash(value, dataType, 0, isCollationAware, legacyCollationAwareHashing)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CollationFactory, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{TimestampNanosVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._

class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -885,6 +885,50 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(HiveHash(Seq(time)), -1567775210)
}

test("HashExpression supports nanosecond timestamp types") {
// (epochMicros, nanosWithinMicro) pairs covering zero/mid/max nanos, negative micros, and
// the Long epoch-micro boundaries.
val values = Seq(
TimestampNanosVal.fromParts(0L, 0.toShort),
TimestampNanosVal.fromParts(1L, 1.toShort),
TimestampNanosVal.fromParts(1234567890L, 999.toShort),
TimestampNanosVal.fromParts(-1L, 500.toShort),
TimestampNanosVal.fromParts(Long.MinValue, 0.toShort),
TimestampNanosVal.fromParts(Long.MaxValue, 999.toShort))

Seq(TimestampNTZNanosType(9), TimestampLTZNanosType(9),
TimestampNTZNanosType(7), TimestampLTZNanosType(7)).foreach { dt =>
(values.map(Literal.create(_, dt)) :+ Literal.create(null, dt)).foreach { lit =>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All hash inputs here are Literals, so the generated code embeds the TimestampNanosVal as a reference object and the ordinal-read codegen path (CodeGenerator.getValue -> getTimestampNTZNanos/getTimestampLTZNanos) is never exercised, and the value never round-trips through an UnsafeRow as a hash input. checkEvaluationWithUnsafeProjection here only projects the resulting int/long hash, not the nanos input.

Since the motivation is hash-based GROUP BY / shuffle / joins (where the input is a BoundReference reading from a possibly-unsafe row), could we add a BoundReference-over-row case, e.g.:

val row = InternalRow(TimestampNanosVal.fromParts(1234567890L, 999.toShort))
val ref = BoundReference(0, dt, nullable = true)
checkEvaluation(Murmur3Hash(Seq(ref), 42), Murmur3Hash(Seq(ref), 42).eval(row), row)

This drives the row read + unsafe round-trip that the literal-based tests skip.

// checkEvaluation asserts the interpreted, codegen, and unsafe paths all agree.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: with a Literal child, the unsafe projection only stores the scalar hash result, not the TimestampNanosVal input, so the unsafe input path isn't actually covered here. Either reword this comment or add the BoundReference-over-row case suggested above so the comment becomes accurate.

checkEvaluation(Murmur3Hash(Seq(lit), 42), Murmur3Hash(Seq(lit), 42).eval())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expected value is Murmur3Hash(Seq(lit), 42).eval(), i.e. computed by the same expression under test, so this only proves the eval paths agree with each other (and, via the second test, that both fields contribute). A bug shared across all paths (e.g. a wrong constant, or a symmetric field swap) wouldn't be caught. The existing tests in this suite pin literals (e.g. checkEvaluation(HiveHash(Seq(time)), -1567775210)). Could we pin at least one golden constant per algorithm for a fixed (epochMicros, nanosWithinMicro) pair?

checkEvaluation(XxHash64(Seq(lit), 42L), XxHash64(Seq(lit), 42L).eval())
checkEvaluation(HiveHash(Seq(lit)), HiveHash(Seq(lit)).eval())
}
}
}

test("nanosecond timestamp hash is consistent with equality") {
val dt = TimestampNTZNanosType(9)
def lit(micros: Long, nanos: Short): Literal =
Literal.create(TimestampNanosVal.fromParts(micros, nanos), dt)

val a = lit(1234567890L, 123)
val aCopy = lit(1234567890L, 123)
val diffNanos = lit(1234567890L, 124) // same micros, different sub-micro nanos
val diffMicros = lit(1234567891L, 123) // different micros, same nanos

Seq[Expression => Any](
e => Murmur3Hash(Seq(e), 42).eval(),
e => XxHash64(Seq(e), 42L).eval(),
e => HiveHash(Seq(e)).eval()).foreach { hash =>
// Equal values hash equally.
assert(hash(a) === hash(aCopy))
// Both fields contribute to the hash (guards against a dropped epochMicros/nanos field).
assert(hash(a) !== hash(diffNanos))
assert(hash(a) !== hash(diffMicros))
}
}

private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val toRow = ExpressionEncoder(inputSchema).createSerializer()
Expand Down