From 9b3f32fdd2e492f0aeb696ecde5dfaba0e1d974d Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Wed, 20 May 2026 11:16:24 +0800 Subject: [PATCH 1/5] Refactor array remove invocation --- datafusion/functions-nested/src/remove.rs | 209 +++++++++++++++++++--- 1 file changed, 182 insertions(+), 27 deletions(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index d0f838ddad12a..4a6005d41bfd1 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -18,16 +18,17 @@ //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. use crate::utils; -use crate::utils::make_scalar_function; use arrow::array::{ Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, - cast::AsArray, make_array, + Scalar, cast::AsArray, make_array, }; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -113,7 +114,21 @@ impl ScalarUDFImpl for ArrayRemove { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_inner)(&args.args) + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = vec![1; num_rows]; + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -214,7 +229,23 @@ impl ScalarUDFImpl for ArrayRemoveN { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_n_inner)(&args.args) + let [list_arg, element_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let max_array = max_arg.to_array(num_rows)?; + let arr_n = as_int64_array(&max_array)?.values().to_vec(); + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -304,7 +335,21 @@ impl ScalarUDFImpl for ArrayRemoveAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_remove_all_inner)(&args.args) + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + let arr_n = vec![i64::MAX; num_rows]; + match element_arg { + ColumnarValue::Array(element_array) => { + let result = array_remove_internal(&list_array, element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar_element) => { + let result = + remove_with_scalar_needle(&list_array, scalar_element, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -316,27 +361,6 @@ impl ScalarUDFImpl for ArrayRemoveAll { } } -fn array_remove_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove", args)?; - - let arr_n = vec![1; array.len()]; - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_n_inner(args: &[ArrayRef]) -> Result { - let [array, element, max] = take_function_args("array_remove_n", args)?; - - let arr_n = as_int64_array(max)?.values().to_vec(); - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_all_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove_all", args)?; - - let arr_n = vec![i64::MAX; array.len()]; - array_remove_internal(array, element, &arr_n) -} - fn array_remove_internal( array: &ArrayRef, element_array: &ArrayRef, @@ -357,6 +381,45 @@ fn array_remove_internal( } } +/// Dispatches scalar-needle array removal by list offset type. +/// +/// `needle` must be a length-1 array containing the scalar element to remove. +fn array_remove_dispatch_scalar( + array: &ArrayRef, + needle: &ArrayRef, + arr_n: &[i64], +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, needle, arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, needle, arr_n) + } + array_type => exec_err!("array_remove does not support type '{array_type}'."), + } +} + +/// Removes elements matching a scalar needle from a list array. +/// +/// Uses a bulk `distinct` comparison for non-null, non-nested scalar elements, +/// falling back to the per-row `general_remove` path for null or nested types. +fn remove_with_scalar_needle( + list_array: &ArrayRef, + scalar_element: &ScalarValue, + arr_n: &[i64], +) -> Result { + if !scalar_element.is_null() && !scalar_element.data_type().is_nested() { + let needle = scalar_element.to_array_of_size(1)?; + array_remove_dispatch_scalar(list_array, &needle, arr_n) + } else { + let needle_array = scalar_element.to_array_of_size(list_array.len())?; + array_remove_internal(list_array, &needle_array, arr_n) + } +} + /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// @@ -468,6 +531,98 @@ fn general_remove( )?)) } +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences +/// of `needle[0]` (scalar element broadcasted). +/// +/// This is a specialized version of `general_remove` for scalar elements that +/// uses bulk comparison for better performance. +fn general_remove_with_scalar( + list_array: &GenericListArray, + needle: &ArrayRef, + arr_n: &[i64], +) -> Result { + let list_field = match list_array.data_type() { + DataType::List(field) | DataType::LargeList(field) => field, + _ => { + return exec_err!( + "Expected List or LargeList data type, got {:?}", + list_array.data_type() + ); + } + }; + let original_data = list_array.values().to_data(); + let mut offsets = Vec::::with_capacity(list_array.len() + 1); + offsets.push(OffsetSize::zero()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let nulls = list_array.nulls().cloned(); + let keep_mask = + arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?; + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { + offsets.push(offsets[row_index]); + continue; + } + + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + + let n = arr_n[row_index]; + + if n <= 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + continue; + } + + let eq_array = keep_mask.slice(start, end - start); + let num_to_remove = eq_array.false_count(); + + if num_to_remove == 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + continue; + } + + let max_removals = n.min(num_to_remove as i64); + let mut removed = 0i64; + let mut copied = 0usize; + let mut pending_batch_to_retain: Option = None; + for (i, keep) in eq_array.iter().enumerate() { + if keep == Some(false) && removed < max_removals { + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + i); + copied += i - bs; + pending_batch_to_retain = None; + } + removed += 1; + } else if pending_batch_to_retain.is_none() { + pending_batch_to_retain = Some(i); + } + } + + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, end); + copied += end - start - bs; + } + + offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); + } + + let new_values = make_array(mutable.freeze()); + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(list_field), + OffsetBuffer::new(offsets.into()), + new_values, + nulls, + )?)) +} + #[cfg(test)] mod tests { use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; From 9697c3a8ab5f9876f23b50c3b0c5f01e6e60dfc3 Mon Sep 17 00:00:00 2001 From: linfeng <33561138+lyne7-sc@users.noreply.github.com> Date: Wed, 20 May 2026 15:50:24 +0800 Subject: [PATCH 2/5] enhance array_remove slt --- .../test_files/array/array_remove.slt | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/datafusion/sqllogictest/test_files/array/array_remove.slt b/datafusion/sqllogictest/test_files/array/array_remove.slt index c3ce7073eca83..456ebb6482341 100644 --- a/datafusion/sqllogictest/test_files/array/array_remove.slt +++ b/datafusion/sqllogictest/test_files/array/array_remove.slt @@ -537,4 +537,76 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +# array_remove scalar arguments over multiple input rows +query ??? +select + array_remove(column1, 2), + array_remove_n(column1, 2, 2), + array_remove_all(column1, 2) +from ( + values + (make_array(1, 2, 2, 3, 2, 1, 4)), + (make_array(42, 2, 55, 63, 2)) +) as t(column1); +---- +[1, 2, 3, 2, 1, 4] [1, 3, 2, 1, 4] [1, 3, 1, 4] +[42, 55, 63, 2] [42, 55, 63] [42, 55, 63] + +# array_remove with elements containing NULLs — scalar path preserves NULLs +query ??? +select + array_remove(column1, 2), + array_remove_n(column1, 2, 2), + array_remove_all(column1, 2) +from ( + values + (make_array(1, 2, NULL, 3, 2, NULL, 4)), + (make_array(42, 2, NULL, 63, 2)) +) as t(column1); +---- +[1, NULL, 3, 2, NULL, 4] [1, NULL, 3, NULL, 4] [1, NULL, 3, NULL, 4] +[42, NULL, 63, 2] [42, NULL, 63] [42, NULL, 63] + +# array_remove_n with n exceeding match count +query ? +select array_remove_n(make_array(1, 2, 2, 3), 2, 100); +---- +[1, 3] + +# array_remove_n with n=0 and n=-1 (no removal) +query ?? +select + array_remove_n(make_array(1, 2, 2, 3), 2, 0), + array_remove_n(make_array(1, 2, 2, 3), 2, -1); +---- +[1, 2, 2, 3] [1, 2, 2, 3] + +# array_remove on empty arrays +query ?? +select + array_remove(arrow_cast(make_array(), 'List(Int64)'), 1), + array_remove_all(arrow_cast(make_array(), 'List(Int64)'), 1); +---- +[] [] + +# array_remove needle not found — array unchanged +query ? +select array_remove_all(make_array(1, 2, 3, 4, 5), 99); +---- +[1, 2, 3, 4, 5] + +# array_remove all elements match +query ? +select array_remove_all(make_array(7, 7, 7, 7), 7); +---- +[] + +# LargeList scalar path edge cases +query ?? +select + array_remove_all(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1), + array_remove_n(arrow_cast(make_array(1, 1, 1), 'LargeList(Int64)'), 1, 2); +---- +[] [1] + include ./cleanup.slt.part From b6023a094cd11457de305cdeb9329c5f86886298 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Thu, 21 May 2026 22:41:11 +0800 Subject: [PATCH 3/5] apply suggestions --- datafusion/functions-nested/src/remove.rs | 65 ++++++++++++++--------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 4a6005d41bfd1..3856a7c233089 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -398,7 +398,7 @@ fn array_remove_dispatch_scalar( let list_array = array.as_list::(); general_remove_with_scalar::(list_array, needle, arr_n) } - array_type => exec_err!("array_remove does not support type '{array_type}'."), + array_type => exec_err!("array_remove/array_remove_n/array_remove_all does not support type '{array_type}'."), } } @@ -550,7 +550,13 @@ fn general_remove_with_scalar( ); } }; - let original_data = list_array.values().to_data(); + + let list_offsets = list_array.offsets(); + let first_offset = list_offsets[0].to_usize().unwrap(); + let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap(); + let values_range_len = last_offset - first_offset; + let values_slice = list_array.values().slice(first_offset, values_range_len); + let original_data = values_slice.to_data(); let mut offsets = Vec::::with_capacity(list_array.len() + 1); offsets.push(OffsetSize::zero()); @@ -562,15 +568,19 @@ fn general_remove_with_scalar( let nulls = list_array.nulls().cloned(); let keep_mask = arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?; + let remove_bits = match keep_mask.nulls() { + Some(validity) => !(&(keep_mask.values() & validity.inner())), + None => !keep_mask.values(), + }; - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + for (row_index, offset_window) in list_offsets.windows(2).enumerate() { if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { offsets.push(offsets[row_index]); continue; } - let start = offset_window[0].to_usize().unwrap(); - let end = offset_window[1].to_usize().unwrap(); + let start = offset_window[0].to_usize().unwrap() - first_offset; + let end = offset_window[1].to_usize().unwrap() - first_offset; let n = arr_n[row_index]; @@ -580,35 +590,40 @@ fn general_remove_with_scalar( continue; } - let eq_array = keep_mask.slice(start, end - start); - let num_to_remove = eq_array.false_count(); + let row_len = end - start; + let row_remove_bits = remove_bits.slice(first_offset + start, row_len); + let num_to_remove = row_remove_bits.count_set_bits(); if num_to_remove == 0 { mutable.extend(0, start, end); - offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + offsets.push(offsets[row_index] + OffsetSize::usize_as(row_len)); continue; } - let max_removals = n.min(num_to_remove as i64); - let mut removed = 0i64; + let max_removals = n.min(num_to_remove as i64) as usize; + + // Iterate only over the positions that need removal using set_indices, + // which is more efficient than scanning every bit. + let mut removed = 0usize; let mut copied = 0usize; - let mut pending_batch_to_retain: Option = None; - for (i, keep) in eq_array.iter().enumerate() { - if keep == Some(false) && removed < max_removals { - if let Some(bs) = pending_batch_to_retain { - mutable.extend(0, start + bs, start + i); - copied += i - bs; - pending_batch_to_retain = None; - } - removed += 1; - } else if pending_batch_to_retain.is_none() { - pending_batch_to_retain = Some(i); + let mut prev_end = start; // end of last copied range (absolute index into values_slice) + for remove_pos in row_remove_bits.set_indices() { + let abs_pos = start + remove_pos; + // Copy the range before this removal position + if abs_pos > prev_end { + mutable.extend(0, prev_end, abs_pos); + copied += abs_pos - prev_end; + } + prev_end = abs_pos + 1; + removed += 1; + if removed == max_removals { + break; } } - - if let Some(bs) = pending_batch_to_retain { - mutable.extend(0, start + bs, end); - copied += end - start - bs; + // Copy the remaining tail after the last removal + if prev_end < end { + mutable.extend(0, prev_end, end); + copied += end - prev_end; } offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); From 2e7cd4046f3e664ffd8d5c7a51337e7a902a16b2 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Thu, 21 May 2026 22:41:43 +0800 Subject: [PATCH 4/5] apply suggestions --- datafusion/functions-nested/src/remove.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 3856a7c233089..ad294b94d23d1 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -398,7 +398,9 @@ fn array_remove_dispatch_scalar( let list_array = array.as_list::(); general_remove_with_scalar::(list_array, needle, arr_n) } - array_type => exec_err!("array_remove/array_remove_n/array_remove_all does not support type '{array_type}'."), + array_type => exec_err!( + "array_remove/array_remove_n/array_remove_all does not support type '{array_type}'." + ), } } From cdb6021e9dc953e6f7d167ec5cd38d7cea80b1e3 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Fri, 22 May 2026 23:17:23 +0800 Subject: [PATCH 5/5] redundant comparisons --- datafusion/functions-nested/src/remove.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index ad294b94d23d1..388c60c87d6e7 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -569,7 +569,7 @@ fn general_remove_with_scalar( ); let nulls = list_array.nulls().cloned(); let keep_mask = - arrow_ord::cmp::distinct(list_array.values(), &Scalar::new(Arc::clone(needle)))?; + arrow_ord::cmp::distinct(&values_slice, &Scalar::new(Arc::clone(needle)))?; let remove_bits = match keep_mask.nulls() { Some(validity) => !(&(keep_mask.values() & validity.inner())), None => !keep_mask.values(), @@ -593,7 +593,7 @@ fn general_remove_with_scalar( } let row_len = end - start; - let row_remove_bits = remove_bits.slice(first_offset + start, row_len); + let row_remove_bits = remove_bits.slice(start, row_len); let num_to_remove = row_remove_bits.count_set_bits(); if num_to_remove == 0 { @@ -604,8 +604,9 @@ fn general_remove_with_scalar( let max_removals = n.min(num_to_remove as i64) as usize; - // Iterate only over the positions that need removal using set_indices, - // which is more efficient than scanning every bit. + // Iterate only over the removal positions via set_indices. This is + // efficient when the number of removals is small relative to the row + // length (common case), since it skips over retained elements. let mut removed = 0usize; let mut copied = 0usize; let mut prev_end = start; // end of last copied range (absolute index into values_slice)