Skip to content

Commit 44f3a91

Browse files
committed
Add integration test that uses config values for scalar udf
1 parent 887e0b1 commit 44f3a91

3 files changed

Lines changed: 92 additions & 3 deletions

File tree

datafusion/ffi/src/tests/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ pub struct ForeignLibraryModule {
7878

7979
pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF,
8080

81+
pub create_timezone_udf: extern "C" fn() -> FFI_ScalarUDF,
82+
8183
pub create_table_function:
8284
extern "C" fn(FFI_LogicalExtensionCodec) -> FFI_TableFunction,
8385

@@ -142,6 +144,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef {
142144
create_table: construct_table_provider,
143145
create_scalar_udf: create_ffi_abs_func,
144146
create_nullary_udf: create_ffi_random_func,
147+
create_timezone_udf: udf_udaf_udwf::create_timezone_func,
145148
create_table_function: create_ffi_table_func,
146149
create_sum_udaf: create_ffi_sum_func,
147150
create_stddev_udaf: create_ffi_stddev_func,

datafusion/ffi/src/tests/udf_udaf_udwf.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ use std::sync::Arc;
2020

2121
use arrow_schema::DataType;
2222
use datafusion_catalog::TableFunctionImpl;
23+
use datafusion_common::ScalarValue;
2324
use datafusion_expr::{
2425
AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
25-
WindowUDF,
26+
Volatility, WindowUDF,
2627
};
2728
use datafusion_functions::math::abs::AbsFunc;
2829
use datafusion_functions::math::random::RandomFunc;
@@ -78,6 +79,47 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF {
7879
udf.into()
7980
}
8081

82+
#[derive(Debug, PartialEq, Eq, Hash)]
83+
struct TimeZoneUDF {
84+
signature: Signature,
85+
}
86+
87+
impl ScalarUDFImpl for TimeZoneUDF {
88+
fn as_any(&self) -> &dyn Any {
89+
self
90+
}
91+
fn name(&self) -> &str {
92+
"TimeZoneUDF"
93+
}
94+
95+
fn signature(&self) -> &Signature {
96+
&self.signature
97+
}
98+
99+
fn return_type(
100+
&self,
101+
_arg_types: &[DataType],
102+
) -> datafusion_common::Result<DataType> {
103+
Ok(DataType::Utf8)
104+
}
105+
106+
fn invoke_with_args(
107+
&self,
108+
args: ScalarFunctionArgs,
109+
) -> datafusion_common::Result<ColumnarValue> {
110+
let tz = args.config_options.execution.time_zone.clone();
111+
Ok(ColumnarValue::Scalar(ScalarValue::from(tz)))
112+
}
113+
}
114+
115+
pub(crate) extern "C" fn create_timezone_func() -> FFI_ScalarUDF {
116+
let udf: Arc<ScalarUDF> = Arc::new(ScalarUDF::from(TimeZoneUDF {
117+
signature: Signature::uniform(1, vec![DataType::Utf8], Volatility::Stable),
118+
}));
119+
120+
udf.into()
121+
}
122+
81123
pub(crate) extern "C" fn create_ffi_table_func(
82124
codec: FFI_LogicalExtensionCodec,
83125
) -> FFI_TableFunction {

datafusion/ffi/tests/ffi_udf.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919
/// when the feature integration-tests is built
2020
#[cfg(feature = "integration-tests")]
2121
mod tests {
22-
use std::sync::Arc;
23-
22+
use arrow::array::{Array, AsArray};
2423
use arrow::datatypes::DataType;
2524
use datafusion::common::record_batch;
2625
use datafusion::error::{DataFusionError, Result};
2726
use datafusion::logical_expr::{ScalarUDF, ScalarUDFImpl};
2827
use datafusion::prelude::{SessionContext, col};
28+
use datafusion_execution::config::SessionConfig;
29+
use datafusion_expr::lit;
2930
use datafusion_ffi::tests::create_record_batch;
3031
use datafusion_ffi::tests::utils::get_module;
32+
use std::sync::Arc;
3133

3234
/// This test validates that we can load an external module and use a scalar
3335
/// udf defined in it via the foreign function interface. In this case we are
@@ -100,4 +102,46 @@ mod tests {
100102

101103
Ok(())
102104
}
105+
106+
#[tokio::test]
107+
async fn test_config_on_scalar_udf() -> Result<()> {
108+
let module = get_module()?;
109+
110+
let ffi_udf =
111+
module
112+
.create_timezone_udf()
113+
.ok_or(DataFusionError::NotImplemented(
114+
"External module failed to implement create_timezone_udf".to_string(),
115+
))?();
116+
let foreign_udf: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
117+
118+
let udf = ScalarUDF::new_from_shared_impl(foreign_udf);
119+
120+
let ctx = SessionContext::default();
121+
122+
let df = ctx
123+
.read_empty()?
124+
.select(vec![udf.call(vec![lit("a")]).alias("a")])?;
125+
126+
let result = df.collect().await?;
127+
assert!(result[0].column(0).as_string::<i32>().is_null(0));
128+
129+
let mut config = SessionConfig::new();
130+
config.options_mut().execution.time_zone = Some("AEST".into());
131+
132+
let ctx = SessionContext::new_with_config(config);
133+
134+
let df = ctx
135+
.read_empty()?
136+
.select(vec![udf.call(vec![lit("a")]).alias("a")])?;
137+
138+
let result = df.collect().await?;
139+
140+
assert!(result.len() == 1);
141+
assert!(!result[0].column(0).as_string::<i32>().is_null(0));
142+
let result = result[0].column(0).as_string::<i32>().value(0);
143+
assert_eq!(result, "AEST");
144+
145+
Ok(())
146+
}
103147
}

0 commit comments

Comments
 (0)