diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index d524afe43a5a3..646d95d6bb359 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -43,7 +43,7 @@ pub fn set_nulls( /// /// The output is `true` for rows where the filter is `Some(true)`, and `false` /// for rows where the filter is `Some(false)` or `None`. -pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer { +pub fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer { let Some(filter_nulls) = filter.nulls() else { return filter.values().clone(); }; diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 1935f29c4cfe8..ebb2163cefe45 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -45,6 +45,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_validity; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -552,10 +553,12 @@ impl FirstLastGroupsAccumulator { LexicographicalComparator::try_new(&sort_columns)? }; - for (idx_in_val, group_idx) in group_indices.iter().enumerate() { - let group_idx = *group_idx; + let filter_validity = opt_filter.map(filter_to_validity); - let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val)); + for (idx_in_val, &group_idx) in group_indices.iter().enumerate() { + let passed_filter = filter_validity + .as_ref() + .is_none_or(|validity| validity.value(idx_in_val)); let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val)); if !passed_filter || !is_set { @@ -1416,6 +1419,7 @@ mod tests { use arrow::{ array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray}, + buffer::NullBuffer, compute::SortOptions, datatypes::Schema, }; @@ -1423,6 +1427,45 @@ mod tests { use super::*; + fn new_int64_first_last_group_acc( + pick_first_in_group: bool, + ) -> Result>> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])); + + let sort_keys = [PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + + FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), + sort_keys.into(), + true, + &[DataType::Int64], + pick_first_in_group, + ) + } + + fn nullable_bool_filter(values: Vec, validity: Vec) -> BooleanArray { + BooleanArray::new( + BooleanBuffer::from(values), + Some(NullBuffer::from(validity)), + ) + } + + fn assert_group_acc_int64_result( + group_acc: &mut FirstLastGroupsAccumulator>, + expected: Int64Array, + ) -> Result<()> { + let result = group_acc.evaluate(EmitTo::All)?; + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result, &expected); + Ok(()) + } + #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = @@ -1621,6 +1664,54 @@ mod tests { Ok(()) } + #[test] + fn test_first_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> { + let mut group_acc = new_int64_first_last_group_acc(true)?; + + let values_and_orderings: Vec = vec![ + Arc::new(Int64Array::from(vec![10, 20])), + Arc::new(Int64Array::from(vec![1, 2])), + ]; + let filter = nullable_bool_filter(vec![true, false], vec![false, true]); + + group_acc.update_batch(&values_and_orderings, &[0, 0], Some(&filter), 1)?; + + assert_group_acc_int64_result(&mut group_acc, Int64Array::from(vec![None])) + } + + #[test] + fn test_last_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> { + let mut group_acc = new_int64_first_last_group_acc(false)?; + + let values_and_orderings: Vec = vec![ + Arc::new(Int64Array::from(vec![10, 20, 30])), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ]; + let filter = + nullable_bool_filter(vec![true, true, false], vec![false, true, true]); + + group_acc.update_batch(&values_and_orderings, &[0, 0, 0], Some(&filter), 1)?; + + assert_group_acc_int64_result(&mut group_acc, Int64Array::from(vec![Some(20)])) + } + + #[test] + fn test_first_group_acc_merge_rejects_null_filter_with_true_value_bit() -> Result<()> + { + let mut group_acc = new_int64_first_last_group_acc(true)?; + + let states: Vec = vec![ + Arc::new(Int64Array::from(vec![10, 20])), + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(BooleanArray::from(vec![true, true])), + ]; + let filter = nullable_bool_filter(vec![true, true], vec![false, true]); + + group_acc.merge_batch(&states, &[0, 0], Some(&filter), 1)?; + + assert_group_acc_int64_result(&mut group_acc, Int64Array::from(vec![Some(20)])) + } + #[test] fn test_group_acc_size_of_ordering() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 25b69d16dd035..9672c83b26da3 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -6197,6 +6197,50 @@ GROUP BY g ---- 0 0 +# Grouped first_value/last_value must apply aggregate FILTER with Some(true) +# semantics: a row passes only when the predicate is TRUE. Rows where the +# predicate evaluates to NULL or FALSE must be excluded. +# +# Rows per group (predicate is b < 1): +# g=1: (a=10, b=NULL -> NULL), (a=20, b=2 -> FALSE) => no rows pass +# g=2: (a=30, b=0 -> TRUE), (a=40, b=NULL -> NULL), +# (a=50, b=-5 -> TRUE) => a=30 and a=50 pass +# g=3: (a=60, b=NULL -> NULL) => no rows pass +statement ok +CREATE TABLE first_last_filter_null_tests(g INT, a INT, b INT) AS VALUES +(1, 10, CAST(NULL AS INT)), +(1, 20, 2), +(2, 30, 0), +(2, 40, CAST(NULL AS INT)), +(2, 50, -5), +(3, 60, CAST(NULL AS INT)); + +# Groups 1 and 3 have no rows passing the filter -> NULL. +# Group 2 has a=30 and a=50 passing -> first_value ORDER BY a = 30. +query II +SELECT g, first_value(a ORDER BY a) FILTER (WHERE b < 1) AS fv +FROM first_last_filter_null_tests +GROUP BY g +ORDER BY g; +---- +1 NULL +2 30 +3 NULL + +# Same groups via last_value: group 2 picks the largest passing a = 50. +query II +SELECT g, last_value(a ORDER BY a) FILTER (WHERE b < 1) AS lv +FROM first_last_filter_null_tests +GROUP BY g +ORDER BY g; +---- +1 NULL +2 50 +3 NULL + +statement ok +DROP TABLE first_last_filter_null_tests; + # query_with_and_without_filter query III rowsort SELECT