From ef08a38b5791da6af9a957e7466e0d648a9770fa Mon Sep 17 00:00:00 2001 From: Puneet Dixit Date: Thu, 21 May 2026 13:32:56 +0530 Subject: [PATCH] fix: error on negative ln inputs --- datafusion/functions/src/math/ln.rs | 153 ++++++++++++++++++ datafusion/functions/src/math/mod.rs | 10 +- datafusion/sqllogictest/test_files/scalar.slt | 12 +- 3 files changed, 163 insertions(+), 12 deletions(-) create mode 100644 datafusion/functions/src/math/ln.rs diff --git a/datafusion/functions/src/math/ln.rs b/datafusion/functions/src/math/ln.rs new file mode 100644 index 0000000000000..270413704aafa --- /dev/null +++ b/datafusion/functions/src/math/ln.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `ln()`. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::error::ArrowError; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +use super::{bounds, get_ln_doc, ln_order}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct LnFunc { + signature: Signature, +} + +impl Default for LnFunc { + fn default() -> Self { + LnFunc::new() + } +} + +impl LnFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LnFunc { + fn name(&self) -> &str { + "ln" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + ln_order(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + bounds::unbounded_bounds(inputs) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .try_unary::<_, Float64Type, _>(ln_checked_f64)?, + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .try_unary::<_, Float32Type, _>(ln_checked_f32)?, + ) as ArrayRef, + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ln_doc()) + } +} + +fn ln_checked_f64(value: f64) -> std::result::Result { + if value < 0.0 { + Err(ArrowError::ComputeError( + "cannot take logarithm of a negative number".to_string(), + )) + } else { + Ok(value.ln()) + } +} + +fn ln_checked_f32(value: f32) -> std::result::Result { + if value < 0.0 { + Err(ArrowError::ComputeError( + "cannot take logarithm of a negative number".to_string(), + )) + } else { + Ok(value.ln()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ln_checked_f64_negative_errors() { + let err = ln_checked_f64(-1.0).unwrap_err(); + assert_eq!( + err.to_string(), + "Compute error: cannot take logarithm of a negative number" + ); + } + + #[test] + fn test_ln_checked_f64_zero_and_positive() { + let zero = ln_checked_f64(0.0).unwrap(); + assert!(zero.is_infinite() && zero.is_sign_negative()); + + let one = ln_checked_f64(1.0).unwrap(); + assert!(one.abs() < f64::EPSILON); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 610e773d68fd0..243d198fe60e0 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -31,6 +31,7 @@ pub mod floor; pub mod gcd; pub mod iszero; pub mod lcm; +mod ln; pub mod log; pub mod monotonicity; pub mod nans; @@ -148,14 +149,7 @@ make_udf_function!(gcd::GcdFunc, gcd); make_udf_function!(nans::IsNanFunc, isnan); make_udf_function!(iszero::IsZeroFunc, iszero); make_udf_function!(lcm::LcmFunc, lcm); -make_math_unary_udf!( - LnFunc, - ln, - ln, - super::ln_order, - super::bounds::unbounded_bounds, - super::get_ln_doc -); +make_udf_function!(ln::LnFunc, ln); make_math_unary_udf!( Log2Func, log2, diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 89ae30e3c047b..310584c07c70a 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -581,14 +581,18 @@ select ln(0); ---- -Infinity +# ln scalar negative input +query error Arrow error: Compute error: cannot take logarithm of a negative number +select ln((-1.0)::float8); + # ln with columns (round is needed to normalize the outputs of different operating systems) query RRR rowsort -select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers; +select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from unsigned_integers; ---- -0.69315 NaN 4.81218 +0 4.60517 6.34036 +0.69315 6.90776 4.81218 +1.09861 9.21034 6.88551 1.38629 NULL NULL -NaN 4.60517 NaN -NaN 9.21034 NaN ## log