Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};

use arrow::array::ArrayRef;
use arrow::datatypes::DataType::{
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32,
Int64, UInt8, UInt16, UInt32, UInt64,
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
Expand Down Expand Up @@ -120,6 +122,13 @@ fn calculate_new_precision_scale<T: DecimalType>(
}
}

fn is_integer_data_type(data_type: &DataType) -> bool {
matches!(
data_type,
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
)
}

fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
let out_of_range = |value: String| {
datafusion_common::DataFusionError::Execution(format!(
Expand Down Expand Up @@ -185,6 +194,7 @@ impl RoundFunc {
vec![TypeSignatureClass::Integer],
NativeType::Int32,
);
let integer = Coercion::new_exact(TypeSignatureClass::Integer);
let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
let float64 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_float64()),
Expand All @@ -199,6 +209,11 @@ impl RoundFunc {
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![decimal]),
TypeSignature::Coercible(vec![
integer.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![integer]),
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
Expand Down Expand Up @@ -245,6 +260,7 @@ impl ScalarUDFImpl for RoundFunc {
// extra precision to accommodate potential carry-over.
let return_type =
match input_type {
input_type if is_integer_data_type(input_type) => input_type.clone(),
Float32 => Float32,
Decimal32(precision, scale) => calculate_new_precision_scale::<
Decimal32Type,
Expand Down Expand Up @@ -308,6 +324,29 @@ impl ScalarUDFImpl for RoundFunc {
};

match (value_scalar, args.return_type()) {
(value_scalar, return_type)
if is_integer_data_type(return_type) && dp >= 0 =>
{
ColumnarValue::Scalar(value_scalar.clone()).cast_to(return_type, None)
}
(value_scalar, Float64)
if is_integer_data_type(&value_scalar.data_type()) =>
{
let value = ColumnarValue::Scalar(value_scalar.clone())
.cast_to(&Float64, None)?;
match value {
ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
let rounded = round_float(v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(rounded))))
}
ColumnarValue::Scalar(ScalarValue::Float64(None)) => {
Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
}
_ => internal_err!(
"Unexpected datatype after casting integer argument to Float64"
),
}
}
(ScalarValue::Float32(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
Expand Down Expand Up @@ -467,7 +506,36 @@ fn round_columnar(
&& matches!(decimal_places, ColumnarValue::Scalar(_));
let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));

macro_rules! round_integer_array_to_float64 {
($ARRAY_TYPE:ty) => {{
let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, Float64Type, _>(
value_array.as_ref(),
decimal_places,
|v, dp| round_float(v as f64, dp),
)?;
result as _
}};
}

let arr: ArrayRef = match (value_array.data_type(), return_type) {
(input_type, return_type)
if input_type == return_type
&& is_integer_data_type(return_type)
&& match decimal_places {
ColumnarValue::Scalar(ScalarValue::Int32(Some(dp))) => *dp >= 0,
_ => false,
} =>
{
value_array
}
(Int8, Float64) => round_integer_array_to_float64!(Int8Type),
(Int16, Float64) => round_integer_array_to_float64!(Int16Type),
(Int32, Float64) => round_integer_array_to_float64!(Int32Type),
(Int64, Float64) => round_integer_array_to_float64!(Int64Type),
(UInt8, Float64) => round_integer_array_to_float64!(UInt8Type),
(UInt16, Float64) => round_integer_array_to_float64!(UInt16Type),
(UInt32, Float64) => round_integer_array_to_float64!(UInt32Type),
(UInt64, Float64) => round_integer_array_to_float64!(UInt64Type),
(Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
value_array.as_ref(),
Expand Down
50 changes: 29 additions & 21 deletions datafusion/spark/src/function/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,22 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode)
}
DataType::UInt64 => {
let array = array.as_primitive::<UInt64Type>();
let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
let v_i64 = i64::try_from(x).map_err(|_| {
(exec_err!(
"round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
if scale >= 0 {
Ok(args[0].clone())
} else {
let array = array.as_primitive::<UInt64Type>();
let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
let v_i64 = i64::try_from(x).map_err(|_| {
(exec_err!(
"round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
round_integer(v_i64, scale, enable_ansi_mode)
.map(|v| v as u64)
})?;
round_integer(v_i64, scale, enable_ansi_mode)
.map(|v| v as u64)
})?;
Ok(ColumnarValue::Array(Arc::new(result)))
Ok(ColumnarValue::Array(Arc::new(result)))
}
}

// Float types
Expand Down Expand Up @@ -588,16 +592,20 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result))))
}
ScalarValue::UInt64(Some(v)) => {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
if scale >= 0 {
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(*v))))
} else {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
}
}

// Float scalars
Expand Down
14 changes: 14 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,20 @@ select round(a), round(b), round(c) from small_floats;
0 0 1
1 0 0

# round int64 should preserve exact values above Float64 precision range
query TI
select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))),
round(arrow_cast(9007199254740993, 'Int64'));
----
Int64 9007199254740993

# round int64 with positive decimal_places should preserve exact values above Float64 precision range
query TI
select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2)),
round(arrow_cast(9007199254740993, 'Int64'), 2);
----
Int64 9007199254740993

# round with too large
# max Int32 is 2147483647
query error round decimal_places 2147483648 is out of supported i32 range
Expand Down
11 changes: 5 additions & 6 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1591,16 +1591,15 @@ WHERE CAST(ROUND(b) as INT) = a
ORDER BY CAST(ROUND(b) as INT);
----
logical_plan
01)Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST
02)--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a
03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a]
01)Sort: CAST(round(annotated_data_finite2.b) AS Int32) ASC NULLS LAST
02)--Filter: CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a
03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a]
physical_plan
01)SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST]
02)--FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1
01)SortPreservingMergeExec: [round(b@2) ASC NULLS LAST]
02)--FilterExec: round(b@2) = a@1
03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1, maintains_sort_order=true
04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true


statement ok
drop table annotated_data_finite2;

Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/spark/math/round.slt
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ SELECT round(25::bigint, -1::int);
----
30

# round(bigint) should preserve exact values above Float64's exact integer range
query IT
SELECT round(arrow_cast(9007199254740993, 'Int64')), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64')));
----
9007199254740993 Int64

# round(bigint, positive scale) should also preserve exact values above Float64's exact integer range
query IT
SELECT round(arrow_cast(9007199254740993, 'Int64'), 2::int), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2::int));
----
9007199254740993 Int64

# round(smallint, -1)
query I
SELECT round(25::smallint, -1::int);
Expand Down Expand Up @@ -268,6 +280,18 @@ SELECT round(arrow_cast(25, 'UInt64'), -1::int);
----
30

# round(uint64) should preserve exact values above Float64's exact integer range
query IT
SELECT round(arrow_cast(18446744073709551615, 'UInt64')), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64')));
----
18446744073709551615 UInt64

# round(uint64, positive scale) should also preserve exact values above Float64's exact integer range
query IT
SELECT round(arrow_cast(18446744073709551615, 'UInt64'), 2::int), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'), 2::int));
----
18446744073709551615 UInt64

# round(uint32, positive scale) — no-op for integers
query I
SELECT round(arrow_cast(42, 'UInt32'), 2::int);
Expand Down