diff --git a/datafusion/spark/src/function/string/format_string.rs b/datafusion/spark/src/function/string/format_string.rs
index 51e4ebfa7b465..30bc2eacf7a76 100644
--- a/datafusion/spark/src/function/string/format_string.rs
+++ b/datafusion/spark/src/function/string/format_string.rs
@@ -891,23 +891,82 @@ fn unsigned_to_char(value: u64) -> Result<char> {
     codepoint_to_char(codepoint)
 }
 
-/// Convert a non-null integer scalar to a [`char`] for the `%c` conversion.
-fn integer_scalar_to_char(scalar: &ScalarValue) -> Result<char> {
-    match scalar {
-        ScalarValue::Int8(Some(value)) => signed_to_char(*value as i64),
-        ScalarValue::Int16(Some(value)) => signed_to_char(*value as i64),
-        ScalarValue::Int32(Some(value)) => signed_to_char(*value as i64),
-        ScalarValue::Int64(Some(value)) => signed_to_char(*value),
-        ScalarValue::UInt8(Some(value)) => unsigned_to_char(*value as u64),
-        ScalarValue::UInt16(Some(value)) => unsigned_to_char(*value as u64),
-        ScalarValue::UInt32(Some(value)) => unsigned_to_char(*value as u64),
-        ScalarValue::UInt64(Some(value)) => unsigned_to_char(*value),
-        _ => datafusion_common::internal_err!(
-            "integer_scalar_to_char expects a non-null integer scalar, got {scalar:?}"
-        ),
+/// Normalizes integer scalar payloads while preserving Spark formatting semantics:
+/// signed values format as decimal for `%d` / `%s` / `%c`, but use their original
+/// bit width for `%x` / `%o` via `unsigned_bits`.
+#[derive(Debug, Clone, Copy)]
+enum IntegerValue {
+    Signed { decimal: i64, unsigned_bits: u64 },
+    Unsigned(u64),
+}
+
+impl IntegerValue {
+    fn unsigned_bits(self) -> u64 {
+        match self {
+            Self::Signed { unsigned_bits, .. } => unsigned_bits,
+            Self::Unsigned(value) => value,
+        }
+    }
+
+    fn to_char(self) -> Result<char> {
+        match self {
+            Self::Signed { decimal, .. } => signed_to_char(decimal),
+            Self::Unsigned(value) => unsigned_to_char(value),
+        }
+    }
+
+    fn format_decimal(
+        self,
+        spec: &ConversionSpecifier,
+        writer: &mut String,
+    ) -> Result<()> {
+        match self {
+            Self::Signed { decimal, .. } => spec.format_signed(writer, decimal),
+            Self::Unsigned(value) => spec.format_unsigned(writer, value),
+        }
+    }
+
+    fn decimal_string(self) -> String {
+        match self {
+            Self::Signed { decimal, .. } => decimal.to_string(),
+            Self::Unsigned(value) => value.to_string(),
+        }
     }
 }
 
+macro_rules! signed_integer_value {
+    ($source:ty, $unsigned:ty) => {
+        impl From<$source> for IntegerValue {
+            fn from(value: $source) -> Self {
+                Self::Signed {
+                    decimal: value as i64,
+                    unsigned_bits: (value as $unsigned) as u64,
+                }
+            }
+        }
+    };
+}
+
+signed_integer_value!(i8, u8);
+signed_integer_value!(i16, u16);
+signed_integer_value!(i32, u32);
+signed_integer_value!(i64, u64);
+
+macro_rules! unsigned_integer_value {
+    ($source:ty) => {
+        impl From<$source> for IntegerValue {
+            fn from(value: $source) -> Self {
+                Self::Unsigned(value as u64)
+            }
+        }
+    };
+}
+
+unsigned_integer_value!(u8);
+unsigned_integer_value!(u16);
+unsigned_integer_value!(u32);
+unsigned_integer_value!(u64);
+
 impl ConversionSpecifier {
     /// Validates that the grouping separator flag is not used with scientific
     /// notation conversions, matching Java/Spark behavior which throws
@@ -940,189 +999,14 @@ impl ConversionSpecifier {
 
                 _ => self.format_boolean(string, value),
             },
-            ScalarValue::Int8(Some(_))
-            | ScalarValue::Int16(Some(_))
-            | ScalarValue::Int32(Some(_))
-            | ScalarValue::Int64(Some(_))
-            | ScalarValue::UInt8(Some(_))
-            | ScalarValue::UInt16(Some(_))
-            | ScalarValue::UInt32(Some(_))
-            | ScalarValue::UInt64(Some(_))
-                if matches!(
-                    self.conversion_type,
-                    ConversionType::CharLower | ConversionType::CharUpper
-                ) =>
-            {
-                self.format_char(string, integer_scalar_to_char(value)?)
-            }
-            ScalarValue::Int8(value) => match (self.conversion_type, value) {
-                (ConversionType::DecInt, Some(value)) => {
-                    self.format_signed(string, *value as i64)
-                }
-                (
-                    ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, (*value as u8) as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for Int8",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::Int16(value) => match (self.conversion_type, value) {
-                (ConversionType::DecInt, Some(value)) => {
-                    self.format_signed(string, *value as i64)
-                }
-                (
-                    ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, (*value as u16) as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for Int16",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::Int32(value) => match (self.conversion_type, value) {
-                (ConversionType::DecInt, Some(value)) => {
-                    self.format_signed(string, *value as i64)
-                }
-                (
-                    ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, (*value as u32) as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for Int32",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::Int64(value) => match (self.conversion_type, value) {
-                (ConversionType::DecInt, Some(value)) => {
-                    self.format_signed(string, *value)
-                }
-                (
-                    ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, *value as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for Int64",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::UInt8(value) => match (self.conversion_type, value) {
-                (
-                    ConversionType::DecInt
-                    | ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, *value as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for UInt8",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::UInt16(value) => match (self.conversion_type, value) {
-                (
-                    ConversionType::DecInt
-                    | ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, *value as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for UInt16",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::UInt32(value) => match (self.conversion_type, value) {
-                (
-                    ConversionType::DecInt
-                    | ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, *value as u64),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for UInt32",
-                        self.conversion_type
-                    )
-                }
-            },
-            ScalarValue::UInt64(value) => match (self.conversion_type, value) {
-                (
-                    ConversionType::DecInt
-                    | ConversionType::HexIntLower
-                    | ConversionType::HexIntUpper
-                    | ConversionType::OctInt,
-                    Some(value),
-                ) => self.format_unsigned(string, *value),
-                (
-                    ConversionType::StringLower | ConversionType::StringUpper,
-                    Some(value),
-                ) => self.format_string(string, &value.to_string()),
-                (t, None) if t.supports_integer() => self.format_string(string, "null"),
-                _ => {
-                    exec_err!(
-                        "Invalid conversion type: {:?} for UInt64",
-                        self.conversion_type
-                    )
-                }
-            },
+            ScalarValue::Int8(value) => self.format_integer(string, value, "Int8"),
+            ScalarValue::Int16(value) => self.format_integer(string, value, "Int16"),
+            ScalarValue::Int32(value) => self.format_integer(string, value, "Int32"),
+            ScalarValue::Int64(value) => self.format_integer(string, value, "Int64"),
+            ScalarValue::UInt8(value) => self.format_integer(string, value, "UInt8"),
+            ScalarValue::UInt16(value) => self.format_integer(string, value, "UInt16"),
+            ScalarValue::UInt32(value) => self.format_integer(string, value, "UInt32"),
+            ScalarValue::UInt64(value) => self.format_integer(string, value, "UInt64"),
             ScalarValue::Float16(value) => match (self.conversion_type, value) {
                 (
                     ConversionType::DecFloatLower
@@ -1484,6 +1368,48 @@ impl ConversionSpecifier {
         }
     }
 
+    fn format_integer<T>(
+        &self,
+        writer: &mut String,
+        value: &Option<T>,
+        type_name: &str,
+    ) -> Result<()>
+    where
+        T: Copy + Into<IntegerValue>,
+    {
+        let Some(value) = value.map(Into::into) else {
+            return if self.conversion_type.supports_integer() {
+                self.format_string(writer, "null")
+            } else {
+                self.invalid_integer_conversion(type_name)
+            };
+        };
+
+        match self.conversion_type {
+            ConversionType::DecInt => value.format_decimal(self, writer),
+            ConversionType::HexIntLower
+            | ConversionType::HexIntUpper
+            | ConversionType::OctInt => {
+                self.format_unsigned(writer, value.unsigned_bits())
+            }
+            ConversionType::CharLower | ConversionType::CharUpper => {
+                self.format_char(writer, value.to_char()?)
+            }
+            ConversionType::StringLower | ConversionType::StringUpper => {
+                self.format_string(writer, &value.decimal_string())
+            }
+            _ => self.invalid_integer_conversion(type_name),
+        }
+    }
+
+    fn invalid_integer_conversion(&self, type_name: &str) -> Result<()> {
+        exec_err!(
+            "Invalid conversion type: {:?} for {}",
+            self.conversion_type,
+            type_name
+        )
+    }
+
     fn format_hex_float(&self, writer: &mut String, value: f64) -> Result<()> {
         // Handle special cases first
         let (sign, raw_exponent, mantissa) = value.to_parts();
@@ -2588,6 +2514,74 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_integer_formatting_across_widths() -> Result<()> {
+        let cases = [
+            (
+                ScalarValue::Int8(Some(-1)),
+                "%d|%x|%o|%s",
+                4,
+                "-1|ff|377|-1",
+            ),
+            (
+                ScalarValue::Int16(Some(-1)),
+                "%d|%x|%o|%s",
+                4,
+                "-1|ffff|177777|-1",
+            ),
+            (
+                ScalarValue::Int32(Some(-1)),
+                "%d|%x|%o|%s",
+                4,
+                "-1|ffffffff|37777777777|-1",
+            ),
+            (
+                ScalarValue::Int64(Some(-1)),
+                "%d|%x|%o|%s",
+                4,
+                "-1|ffffffffffffffff|1777777777777777777777|-1",
+            ),
+            (
+                ScalarValue::UInt8(Some(255)),
+                "%d|%x|%o|%s|%c",
+                5,
+                "255|ff|377|255|ÿ",
+            ),
+            (
+                ScalarValue::UInt16(Some(65535)),
+                "%d|%x|%o|%s",
+                4,
+                "65535|ffff|177777|65535",
+            ),
+            (
+                ScalarValue::UInt32(Some(u32::MAX)),
+                "%d|%x|%o|%s",
+                4,
+                "4294967295|ffffffff|37777777777|4294967295",
+            ),
+            (
+                ScalarValue::UInt64(Some(u64::MAX)),
+                "%d|%x|%o|%s",
+                4,
+                "18446744073709551615|ffffffffffffffff|1777777777777777777777|18446744073709551615",
+            ),
+            (
+                ScalarValue::Int32(None),
+                "%d|%x|%o|%s|%c",
+                5,
+                "null|null|null|null|null",
+            ),
+        ];
+
+        for (value, fmt, arg_count, expected) in cases {
+            let data_types = vec![value.data_type(); arg_count];
+            let formatter = Formatter::parse(fmt, &data_types)?;
+            let args = vec![value; arg_count];
+            assert_eq!(formatter.format(&args)?, expected, "{fmt}");
+        }
+        Ok(())
+    }
+
     #[test]
     fn test_insert_thousands_separator() {
         assert_eq!(insert_thousands_separator("1234567.89"), "1,234,567.89");