-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix log(0.0::float8) should error, not return -inf #22564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,17 @@ fn is_valid_integer_base(base: f64) -> bool { | |
| base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 | ||
| } | ||
|
|
||
| #[inline] | ||
| fn validate_log_value(value: f64) -> Result<(), ArrowError> { | ||
| if value == 0.0 { | ||
| Err(ArrowError::ComputeError( | ||
| "cannot take logarithm of zero".to_string(), | ||
| )) | ||
| } else { | ||
| Ok(()) | ||
| } | ||
| } | ||
|
|
||
| /// Calculate logarithm for Decimal32 values. | ||
| /// For integer bases >= 2 with zero scale, return an exact integer log when the | ||
| /// value is a perfect power of the base. Otherwise falls back to f64 computation. | ||
|
|
@@ -121,7 +132,10 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> { | |
| return Ok(int_log as f64); | ||
| } | ||
| } | ||
| decimal_to_f64(value, scale).map(|v| v.log(base)) | ||
| decimal_to_f64(value, scale).and_then(|v| { | ||
| validate_log_value(v)?; | ||
| Ok(v.log(base)) | ||
| }) | ||
| } | ||
|
|
||
| /// Calculate logarithm for Decimal64 values. | ||
|
|
@@ -139,7 +153,10 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> { | |
| return Ok(int_log as f64); | ||
| } | ||
| } | ||
| decimal_to_f64(value, scale).map(|v| v.log(base)) | ||
| decimal_to_f64(value, scale).and_then(|v| { | ||
| validate_log_value(v)?; | ||
| Ok(v.log(base)) | ||
| }) | ||
| } | ||
|
|
||
| /// Calculate logarithm for Decimal128 values. | ||
|
|
@@ -157,7 +174,20 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> | |
| return Ok(int_log as f64); | ||
| } | ||
| } | ||
| decimal_to_f64(value, scale).map(|v| v.log(base)) | ||
| decimal_to_f64(value, scale).and_then(|v| { | ||
| validate_log_value(v)?; | ||
| Ok(v.log(base)) | ||
| }) | ||
| } | ||
|
|
||
| /// Compute logarithm for Float16, Float32, and Float64 values | ||
| #[inline] | ||
| fn compute_float_log<T: Float + ToPrimitive>(value: T, base: T) -> Result<T, ArrowError> { | ||
| let value_f64 = value.to_f64().ok_or_else(|| { | ||
| ArrowError::ComputeError("Cannot convert value to f64".to_string()) | ||
| })?; | ||
| validate_log_value(value_f64)?; | ||
| Ok(value.log(base)) | ||
| } | ||
|
|
||
| /// Convert a scaled decimal value to f64. | ||
|
|
@@ -180,7 +210,9 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> | |
| ArrowError::ComputeError(format!("Cannot convert {value} to f64")) | ||
| })?; | ||
| let scale_factor = 10f64.powi(scale as i32); | ||
| Ok((value_f64 / scale_factor).log(base)) | ||
| let value = value_f64 / scale_factor; | ||
| validate_log_value(value)?; | ||
| Ok(value.log(base)) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -247,27 +279,24 @@ impl ScalarUDFImpl for LogFunc { | |
| let value = value.to_array(args.number_rows)?; | ||
|
|
||
| let output: ArrayRef = match value.data_type() { | ||
| DataType::Float16 => { | ||
| calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>( | ||
| &value, | ||
| &base, | ||
| |value, base| Ok(value.log(base)), | ||
| )? | ||
| } | ||
| DataType::Float32 => { | ||
| calculate_binary_math::<Float32Type, Float32Type, Float32Type, _>( | ||
| &value, | ||
| &base, | ||
| |value, base| Ok(value.log(base)), | ||
| )? | ||
| } | ||
| DataType::Float64 => { | ||
| calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>( | ||
| &value, | ||
| &base, | ||
| |value, base| Ok(value.log(base)), | ||
| )? | ||
| } | ||
| DataType::Float16 => calculate_binary_math::< | ||
| Float16Type, | ||
| Float16Type, | ||
| Float16Type, | ||
| _, | ||
| >(&value, &base, compute_float_log)?, | ||
| DataType::Float32 => calculate_binary_math::< | ||
| Float32Type, | ||
| Float32Type, | ||
| Float32Type, | ||
| _, | ||
| >(&value, &base, compute_float_log)?, | ||
| DataType::Float64 => calculate_binary_math::< | ||
| Float64Type, | ||
| Float64Type, | ||
| Float64Type, | ||
| _, | ||
| >(&value, &base, compute_float_log)?, | ||
| DataType::Decimal32(_, scale) => { | ||
| calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>( | ||
| &value, | ||
|
|
@@ -308,10 +337,9 @@ impl ScalarUDFImpl for LogFunc { | |
| self.doc() | ||
| } | ||
|
|
||
| /// Simplify the `log` function by the relevant rules: | ||
| /// 1. Log(a, 1) ===> 0 | ||
| /// 2. Log(a, Power(a, b)) ===> b | ||
| /// 3. Log(a, a) ===> 1 | ||
| /// Simplify `log` only when the base is a known-valid literal. | ||
| /// This preserves current runtime `NaN` / domain behavior for column and | ||
| /// expression bases whose validity cannot be proven during planning. | ||
| fn simplify( | ||
| &self, | ||
| mut args: Vec<Expr>, | ||
|
|
@@ -358,43 +386,83 @@ impl ScalarUDFImpl for LogFunc { | |
| } else { | ||
| lit(ScalarValue::new_ten(&number_datatype)?) | ||
| }; | ||
| let base_datatype = info.get_data_type(&base)?; | ||
|
|
||
| if is_zero_literal(&number, &number_datatype)? | ||
| || is_zero_literal(&base, &base_datatype)? | ||
| { | ||
| return Ok(ExprSimplifyResult::Original(original_log_args( | ||
| num_args, &base, &number, | ||
| )?)); | ||
| } | ||
|
|
||
| let base_is_valid_literal = is_valid_log_base_literal(&base)?; | ||
|
|
||
| match number { | ||
| Expr::Literal(value, _) | ||
| if value == ScalarValue::new_one(&number_datatype)? => | ||
| if value == ScalarValue::new_one(&number_datatype)? | ||
| && base_is_valid_literal => | ||
| { | ||
| Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( | ||
| &info.get_data_type(&base)?, | ||
| )?))) | ||
| } | ||
| Expr::ScalarFunction(ScalarFunction { func, mut args }) | ||
| if is_pow(&func) && args.len() == 2 && base == args[0] => | ||
| if is_pow(&func) | ||
| && args.len() == 2 | ||
| && base == args[0] | ||
| && base_is_valid_literal => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for tightening the rewrite conditions. I think there is still one edge case here. The For example: select log(10.0, power(10.0, column1))
from (values (-400.0), (2.0)) as t(column1);With this rewrite, the result becomes Without the rewrite, select log(10.0, power(10.0, -400.0));Could we either avoid this rewrite unless the value can be proven non-zero, or preserve the runtime zero-value validation when the value side is still an expression? A regression test covering an exponent that underflows to zero would also help lock in the expected behavior.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I apologize; this was indeed my oversight. I originally intended to temporarily refrain from simplifying the reverse operations of log and power, but I missed it in my tightening logic. I'd like to ask if it would be better for us to temporarily abandon the simplification of these reverse operations? Because I feel that we seem unable to know the true range of values for a column (or an expression). |
||
| { | ||
| let b = args.pop().unwrap(); // length checked above | ||
| Ok(ExprSimplifyResult::Simplified(b)) | ||
| } | ||
| number => { | ||
| if number == base { | ||
| if number == base && base_is_valid_literal { | ||
| Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( | ||
| &number_datatype, | ||
| )?))) | ||
| } else { | ||
| let args = match num_args { | ||
| 1 => vec![number], | ||
| 2 => vec![base, number], | ||
| _ => { | ||
| return internal_err!( | ||
| "Unexpected number of arguments in log::simplify" | ||
| ); | ||
| } | ||
| }; | ||
| Ok(ExprSimplifyResult::Original(args)) | ||
| Ok(ExprSimplifyResult::Original(original_log_args( | ||
| num_args, &base, &number, | ||
| )?)) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn original_log_args(num_args: usize, base: &Expr, number: &Expr) -> Result<Vec<Expr>> { | ||
| match num_args { | ||
| 1 => Ok(vec![number.clone()]), | ||
| 2 => Ok(vec![base.clone(), number.clone()]), | ||
| _ => { | ||
| internal_err!("Unexpected number of arguments in log::simplify") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn is_zero_literal(expr: &Expr, data_type: &DataType) -> Result<bool> { | ||
| match expr { | ||
| Expr::Literal(value, _) => Ok(*value == ScalarValue::new_zero(data_type)?), | ||
| _ => Ok(false), | ||
| } | ||
| } | ||
|
|
||
| #[inline] | ||
| fn is_valid_log_base_literal(expr: &Expr) -> Result<bool> { | ||
| match expr { | ||
| Expr::Literal(value, _) => { | ||
| let scalar = value.cast_to(&DataType::Float64)?; | ||
| Ok( | ||
| matches!(scalar, ScalarValue::Float64(Some(base)) if base > 0.0 && base != 1.0), | ||
| ) | ||
| } | ||
| _ => Ok(false), | ||
| } | ||
| } | ||
|
|
||
| /// Returns true if the function is `PowerFunc` | ||
| fn is_pow(func: &ScalarUDF) -> bool { | ||
| func.inner().is::<PowerFunc>() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.