diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 6e1271ef19aa9..6c28f3f3a24b6 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -606,24 +606,43 @@ fn test_simplify_with_cycle_count( #[test] fn test_simplify_log() { - // Log(c3, 1) ===> 0 + // Log(10, 1) ===> 0 + { + let expr = log(lit(10), lit(1)); + test_simplify(expr, lit(0)); + } + // Log(10, 10) ===> 1 + { + let expr = log(lit(10), lit(10)); + test_simplify(expr, lit(1)); + } + // Log(c3, 1) ===> Log(c3, 1) { let expr = log(col("c3_non_null"), lit(1)); - test_simplify(expr, lit(0i64)); + test_simplify(expr.clone(), expr); + } + // Log(10, Power(10, c4)) ===> Log(10, Power(10, c4)) + { + let expr = log(lit(10), power(lit(10), col("c4_non_null"))); + let expected = log(lit(10), power(lit(10), col("c4_non_null"))); + test_simplify(expr, expected); } - // Log(c3, c3) ===> 1 + // Log(c3, c3) ===> Log(c3, c3) { let expr = log(col("c3_non_null"), col("c3_non_null")); - let expected = lit(1i64); + let expected = log(col("c3_non_null"), col("c3_non_null")); test_simplify(expr, expected); } - // Log(c3, Power(c3, c4)) ===> c4 + // Log(c3, Power(c3, c4)) ===> Log(c3, Power(c3, c4)) { let expr = log( col("c3_non_null"), power(col("c3_non_null"), col("c4_non_null")), ); - let expected = col("c4_non_null"); + let expected = log( + col("c3_non_null"), + power(col("c3_non_null"), col("c4_non_null")), + ); test_simplify(expr, expected); } // Log(c3, c4) ===> Log(c3, c4) @@ -648,27 +667,6 @@ fn test_simplify_power() { let expected = col("c3_non_null"); test_simplify(expr, expected) } - // Power(c3, Log(c3, c4)) ===> cast(c4 AS Int64) - // The simplifier rewrites `power(b, log(b, x))` to `x`, but the - // rewritten expression must keep the same type as the original - // `power` call. `power`'s declared return type follows its base - // argument (c3 = Int64), so the UInt32 c4 has to be cast to Int64 - // to preserve the output schema the optimizer already committed to. - { - let expr = power( - col("c3_non_null"), - log(col("c3_non_null"), col("c4_non_null")), - ); - let expected = - Expr::Cast(Cast::new(Box::new(col("c4_non_null")), DataType::Int64)); - test_simplify(expr, expected) - } - // Power(c3, c4) ===> Power(c3, c4) - { - let expr = power(col("c3_non_null"), col("c4_non_null")); - let expected = power(col("c3_non_null"), col("c4_non_null")); - test_simplify(expr, expected) - } } #[test] diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 2ca2ed1b572be..65479bf159fc3 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -17,8 +17,6 @@ //! Math function: `log()`. -use super::power::PowerFunc; - use crate::utils::calculate_binary_math; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{ @@ -28,15 +26,12 @@ use arrow::datatypes::{ use arrow::error::ArrowError; use arrow_buffer::i256; use datafusion_common::types::NativeType; -use datafusion_common::{ - Result, ScalarValue, exec_err, internal_err, plan_datafusion_err, plan_err, -}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_common::{Result, ScalarValue, exec_err, plan_datafusion_err, plan_err}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, - TypeSignature, TypeSignatureClass, lit, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, TypeSignature, + TypeSignatureClass, lit, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -106,10 +101,37 @@ fn is_valid_integer_base(base: f64) -> bool { base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 } +#[inline] +fn validate_log_value(value: f64) -> Result<(), ArrowError> { + if value == 0.0 { + Err(ArrowError::ComputeError( + "cannot take logarithm of zero".to_string(), + )) + } else { + Ok(()) + } +} + +#[inline] +fn validate_log_base(base: f64) -> Result<(), ArrowError> { + if base < 0.0 { + Err(ArrowError::ComputeError( + "cannot take logarithm of a negative number".to_string(), + )) + } else if base == 1.0 { + Err(ArrowError::ComputeError( + "division by zero in based logarithm".to_string(), + )) + } else { + Ok(()) + } +} + /// Calculate logarithm for Decimal32 values. /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { + validate_log_base(base)?; if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u32::try_from(value) @@ -121,13 +143,17 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) } /// Calculate logarithm for Decimal64 values. /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { + validate_log_base(base)?; if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u64::try_from(value) @@ -139,13 +165,17 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) } /// Calculate logarithm for Decimal128 values. /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + validate_log_base(base)?; if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u128::try_from(value) @@ -157,7 +187,31 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) +} + +/// Compute logarithm for Float16, Float32, and Float64 values +#[inline] +fn compute_float_log(value: T, base: T) -> Result { + if base < T::zero() { + return Err(ArrowError::ComputeError( + "cannot take logarithm of a negative number".to_string(), + )); + } + if base == T::one() { + return Err(ArrowError::ComputeError( + "division by zero in based logarithm".to_string(), + )); + } + if value.is_zero() { + return Err(ArrowError::ComputeError( + "cannot take logarithm of zero".to_string(), + )); + } + Ok(value.log(base)) } /// Convert a scaled decimal value to f64. @@ -176,11 +230,14 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result Some(v) => log_decimal128(v, scale, base), None => { // For very large Decimal256 values, use f64 computation + validate_log_base(base)?; let value_f64 = value.to_f64().ok_or_else(|| { ArrowError::ComputeError(format!("Cannot convert {value} to f64")) })?; let scale_factor = 10f64.powi(scale as i32); - Ok((value_f64 / scale_factor).log(base)) + let value = value_f64 / scale_factor; + validate_log_value(value)?; + Ok(value.log(base)) } } } @@ -247,27 +304,24 @@ impl ScalarUDFImpl for LogFunc { let value = value.to_array(args.number_rows)?; let output: ArrayRef = match value.data_type() { - DataType::Float16 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float32 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float64 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } + DataType::Float16 => calculate_binary_math::< + Float16Type, + Float16Type, + Float16Type, + _, + >(&value, &base, compute_float_log)?, + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(&value, &base, compute_float_log)?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(&value, &base, compute_float_log)?, DataType::Decimal32(_, scale) => { calculate_binary_math::( &value, @@ -308,15 +362,14 @@ impl ScalarUDFImpl for LogFunc { self.doc() } - /// Simplify the `log` function by the relevant rules: - /// 1. Log(a, 1) ===> 0 - /// 2. Log(a, Power(a, b)) ===> b - /// 3. Log(a, a) ===> 1 + /// Simplify `log` only when the base and value-side behavior remain safe to + /// evaluate at planning time. fn simplify( &self, mut args: Vec, info: &SimplifyContext, ) -> Result { + let original_args = args.clone(); let mut arg_types = args .iter() .map(|arg| info.get_data_type(arg)) @@ -358,46 +411,59 @@ impl ScalarUDFImpl for LogFunc { } else { lit(ScalarValue::new_ten(&number_datatype)?) }; + let base_datatype = info.get_data_type(&base)?; + + if is_zero_literal(&number, &number_datatype)? + || is_zero_literal(&base, &base_datatype)? + { + return Ok(ExprSimplifyResult::Original(original_args)); + } + + let base_is_valid_literal = is_valid_log_base_literal(&base)?; match number { Expr::Literal(value, _) - if value == ScalarValue::new_one(&number_datatype)? => + if value == ScalarValue::new_one(&number_datatype)? + && base_is_valid_literal => { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) - if is_pow(&func) && args.len() == 2 && base == args[0] => - { - let b = args.pop().unwrap(); // length checked above - Ok(ExprSimplifyResult::Simplified(b)) - } number => { - if number == base { + if number == base && base_is_valid_literal { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( &number_datatype, )?))) } else { - let args = match num_args { - 1 => vec![number], - 2 => vec![base, number], - _ => { - return internal_err!( - "Unexpected number of arguments in log::simplify" - ); - } - }; - Ok(ExprSimplifyResult::Original(args)) + Ok(ExprSimplifyResult::Original(original_args)) } } } } } -/// Returns true if the function is `PowerFunc` -fn is_pow(func: &ScalarUDF) -> bool { - func.inner().is::() +#[inline] +fn is_zero_literal(expr: &Expr, data_type: &DataType) -> Result { + match expr { + Expr::Literal(value, _) => { + Ok(*value == ScalarValue::new_zero(&value.data_type())?) + } + _ => Ok(false), + } +} + +#[inline] +fn is_valid_log_base_literal(expr: &Expr) -> Result { + match expr { + Expr::Literal(value, _) => { + let scalar = value.cast_to(&DataType::Float64)?; + Ok( + matches!(scalar, ScalarValue::Float64(Some(base)) if base > 0.0 && base != 1.0), + ) + } + _ => Ok(false), + } } #[cfg(test)] @@ -1130,7 +1196,7 @@ mod tests { #[test] fn test_log_decimal128_invalid_base() { - // Invalid base (-2.0) should return NaN, matching f64::log behavior + // Invalid base (-2.0) should error, matching DuckDB behavior let arg_fields = vec![ Field::new("b", DataType::Float64, false).into(), Field::new("x", DataType::Decimal128(38, 0), false).into(), @@ -1145,21 +1211,11 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new() - .invoke_with_args(args) - .expect("should not error on invalid base"); - - match result { - ColumnarValue::Array(arr) => { - let floats = as_float64_array(&arr) - .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 1); - assert!(floats.value(0).is_nan()); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } + let result = LogFunc::new().invoke_with_args(args).unwrap_err(); + assert_eq!( + result.to_string().lines().next().unwrap(), + "Arrow error: Compute error: cannot take logarithm of a negative number" + ); } #[test] diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index e261bada87eda..fbc64698717bf 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -965,3 +965,59 @@ SELECT gcd(column1, 0) FROM (VALUES (7), (-3), (0)); 7 3 0 + +# Verify error handling for log with zero values +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(0, 0); + +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(0); + +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(2, 0); + +# Safe literal-base rewrites must preserve current results. +query BBBB +SELECT log(10, 1) = 0, log(10, 10) = 1, log(10) = 1, log(10, power(10, 2)) = 2; +---- +true true true true + +query B rowsort +SELECT log(10, power(10, column1)) = column1 +FROM (VALUES (2), (3)) AS t(column1); +---- +true +true + +# log(literal, power(literal, expr)) must preserve runtime validation for +# rows where power underflows to zero before log runs. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(10.0, power(10.0, column1)) +FROM (VALUES (-400.0), (2.0)) AS t(column1); + +# Base 1 must error at runtime rather than being simplified away. +query error DataFusion error: Arrow error: Compute error: division by zero in based logarithm +SELECT log(1, 1); + +query error DataFusion error: Arrow error: Compute error: division by zero in based logarithm +SELECT log(1, power(1, 2)); + +query error DataFusion error: Arrow error: Compute error: division by zero in based logarithm +SELECT log(1, power(1, column1)) +FROM (VALUES (2), (3)) AS t(column1); + +# Negative bases must also error at runtime rather than being simplified away. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of a negative number +SELECT log(-2, 1); + +# log(col, power(col, b)) must NOT be simplified away when col may be zero at runtime. +# Before the fix the planner rewrote log(a, power(a, b)) => b for any expression a, +# so the row where column1=0 silently returned 2.0 instead of raising an error. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(column1, power(column1, 2)) +FROM (VALUES (0.0), (2.0)) AS t(column1); + +# log(col, col) must also preserve runtime validation for zero rows. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(column1, column1) +FROM (VALUES (0.0), (2.0)) AS t(column1); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 9dbf8f16d85ab..0a4697ad3ecbd 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -644,10 +644,15 @@ select log(2, 2.0/3) a, log(10, 2.0/3) b; # log scalar ops with zero edgecases # please see https://github.com/apache/datafusion/pull/5245#issuecomment-1426828382 -query RR rowsort -select log(0) a, log(1, 64) b; ----- --Infinity Infinity +query error cannot take logarithm of zero +select log(0) a; + +# This behavior is consistent with duckdb +query error division by zero in based logarithm +select log(1, 64) a; + +query error cannot take logarithm of zero +select log(0, power(0, 2)) a; # log with columns #1 query RRR rowsort