Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn set_nulls<T: ArrowNumericType + Send>(
///
/// The output is `true` for rows where the filter is `Some(true)`, and `false`
/// for rows where the filter is `Some(false)` or `None`.
pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
pub fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer {
let Some(filter_nulls) = filter.nulls() else {
return filter.values().clone();
};
Expand Down
97 changes: 94 additions & 3 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt,
GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_validity;
use datafusion_functions_aggregate_common::utils::get_sort_options;
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
Expand Down Expand Up @@ -552,10 +553,12 @@ impl<S: ValueState> FirstLastGroupsAccumulator<S> {
LexicographicalComparator::try_new(&sort_columns)?
};

for (idx_in_val, group_idx) in group_indices.iter().enumerate() {
let group_idx = *group_idx;
let filter_validity = opt_filter.map(filter_to_validity);

let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val));
for (idx_in_val, &group_idx) in group_indices.iter().enumerate() {
let passed_filter = filter_validity
.as_ref()
.is_none_or(|validity| validity.value(idx_in_val));
let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val));

if !passed_filter || !is_set {
Expand Down Expand Up @@ -1416,13 +1419,53 @@ mod tests {

use arrow::{
array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray},
buffer::NullBuffer,
compute::SortOptions,
datatypes::Schema,
};
use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};

use super::*;

/// Test helper: builds a `FirstLastGroupsAccumulator` over `Int64` values
/// whose single sort key is column `c` with default sort options.
///
/// `pick_first_in_group` is forwarded unchanged to `try_new`.
fn new_int64_first_last_group_acc(
    pick_first_in_group: bool,
) -> Result<FirstLastGroupsAccumulator<PrimitiveValueState<Int64Type>>> {
    // Two nullable Int64 columns: `a` and the ordering column `c`.
    let fields = vec![
        Field::new("a", DataType::Int64, true),
        Field::new("c", DataType::Int64, true),
    ];
    let input_schema = Arc::new(Schema::new(fields));

    let order_by_c = PhysicalSortExpr {
        expr: col("c", &input_schema).unwrap(),
        options: SortOptions::default(),
    };
    let ordering = [order_by_c];

    let value_state = PrimitiveValueState::<Int64Type>::new(DataType::Int64);
    // NOTE(review): the literal `true` is presumably the ignore-nulls flag and
    // the slice the ordering column types — confirm against `try_new`.
    FirstLastGroupsAccumulator::try_new(
        value_state,
        ordering.into(),
        true,
        &[DataType::Int64],
        pick_first_in_group,
    )
}

/// Test helper: assembles a nullable `BooleanArray` from separate value bits
/// and validity bits, so a row can be null while its value bit is still set.
///
/// In the tests that use it, a `false` entry in `validity` marks a null row.
fn nullable_bool_filter(values: Vec<bool>, validity: Vec<bool>) -> BooleanArray {
    let value_bits = BooleanBuffer::from(values);
    let null_mask = NullBuffer::from(validity);
    BooleanArray::new(value_bits, Some(null_mask))
}

/// Test helper: evaluates `group_acc` with `EmitTo::All` and asserts that the
/// emitted array is an `Int64Array` equal to `expected`.
///
/// Panics if the emitted array is not an `Int64Array` or differs from
/// `expected`; evaluation errors are returned to the caller.
fn assert_group_acc_int64_result(
    group_acc: &mut FirstLastGroupsAccumulator<PrimitiveValueState<Int64Type>>,
    expected: Int64Array,
) -> Result<()> {
    let emitted = group_acc.evaluate(EmitTo::All)?;
    let actual = emitted
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(actual, &expected);
    Ok(())
}

#[test]
fn test_first_last_value_value() -> Result<()> {
let mut first_accumulator =
Expand Down Expand Up @@ -1621,6 +1664,54 @@ mod tests {
Ok(())
}

/// A null filter row must be rejected even when its underlying value bit is
/// `true`; with the only other row filtered out, the group stays unset.
#[test]
fn test_first_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> {
    let mut acc = new_int64_first_last_group_acc(true)?;

    // Column 0 holds the values, column 1 the ordering keys.
    let value_col: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
    let order_col: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
    let batch = vec![value_col, order_col];

    // Row 0: null filter whose value bit is set. Row 1: valid `false`.
    let predicate = nullable_bool_filter(vec![true, false], vec![false, true]);

    // Both rows belong to group 0, the only group.
    acc.update_batch(&batch, &[0, 0], Some(&predicate), 1)?;

    // No row passed the filter, so the group's result is null.
    let expected = Int64Array::from(vec![None]);
    assert_group_acc_int64_result(&mut acc, expected)
}

/// Same null-filter check for the "last" variant: only the row whose filter
/// is a valid `true` may contribute to the group.
#[test]
fn test_last_group_acc_rejects_null_filter_with_true_value_bit() -> Result<()> {
    let mut acc = new_int64_first_last_group_acc(false)?;

    // Column 0 holds the values, column 1 the ordering keys.
    let value_col: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30]));
    let order_col: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let batch = vec![value_col, order_col];

    // Row 0: null filter whose value bit is set. Row 1: valid `true`.
    // Row 2: valid `false`.
    let value_bits = vec![true, true, false];
    let validity_bits = vec![false, true, true];
    let predicate = nullable_bool_filter(value_bits, validity_bits);

    // All three rows belong to group 0, the only group.
    acc.update_batch(&batch, &[0, 0, 0], Some(&predicate), 1)?;

    // Only row 1 (value 20) passed the filter.
    let expected = Int64Array::from(vec![Some(20)]);
    assert_group_acc_int64_result(&mut acc, expected)
}

/// The merge path must also reject a null filter row whose value bit is
/// `true`, leaving only the row with a valid `true` filter.
#[test]
fn test_first_group_acc_merge_rejects_null_filter_with_true_value_bit() -> Result<()>
{
    let mut acc = new_int64_first_last_group_acc(true)?;

    // NOTE(review): the three state columns are presumably value, ordering
    // key and an "is set" flag — confirm against the accumulator's state layout.
    let state_values: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
    let state_orderings: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
    let state_flags: ArrayRef = Arc::new(BooleanArray::from(vec![true, true]));
    let partial_states = vec![state_values, state_orderings, state_flags];

    // Row 0: null filter whose value bit is set. Row 1: valid `true`.
    let predicate = nullable_bool_filter(vec![true, true], vec![false, true]);

    // Both rows belong to group 0, the only group.
    acc.merge_batch(&partial_states, &[0, 0], Some(&predicate), 1)?;

    // Row 0 is rejected, so the group holds row 1's value.
    let expected = Int64Array::from(vec![Some(20)]);
    assert_group_acc_int64_result(&mut acc, expected)
}

#[test]
fn test_group_acc_size_of_ordering() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down
Loading