diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index a9a53a3cb989f..a38d0fafcb945 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -19,14 +19,13 @@ use arrow::array::{ Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, new_null_array, + NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, }; -use arrow::datatypes::{DataType, Field}; - use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -34,7 +33,6 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use crate::utils::compare_element_to_list; -use crate::utils::make_scalar_function; use std::sync::Arc; @@ -125,7 +123,28 @@ impl ScalarUDFImpl for ArrayReplace { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg) { + (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => { + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + 1i64, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let arr_n = vec![1i64; num_rows]; + let result = + array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -200,7 +219,41 @@ impl ScalarUDFImpl for ArrayReplaceN { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_n_inner)(&args.args) + let [list_arg, from_arg, to_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg, max_arg) { + ( + ColumnarValue::Scalar(scalar_from), + ColumnarValue::Scalar(scalar_to), + ColumnarValue::Scalar(scalar_max), + ) => { + let a = scalar_max.to_array_of_size(1)?; + let n = as_int64_array(&a)?.value(0); + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + n, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg, max_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let arr_n = match max_arg { + ColumnarValue::Scalar(s) => { + let a = s.to_array_of_size(1)?; + as_int64_array(&a)?.values().to_vec() + } + ColumnarValue::Array(a) => as_int64_array(a)?.values().to_vec(), + }; + let result = + array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -273,7 +326,28 @@ impl ScalarUDFImpl for ArrayReplaceAll { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(array_replace_all_inner)(&args.args) + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg) { + (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => { + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + i64::MAX, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let arr_n = vec![i64::MAX; num_rows]; + let result = + array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -343,7 +417,11 @@ fn general_replace( let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; let mut counter = 0; // All elements are false, no need to replace, just copy original data @@ -412,63 +490,155 @@ fn general_replace( )?)) } -fn array_replace_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace", args)?; +/// Replaces up to `max_replacements` occurrences of `needle` with the single +/// element in `to_array` for each row in `list_array`. +/// +/// This is a specialized fast path for the all-scalar case that uses a single +/// bulk `not_distinct` comparison over only the visible values range, then +/// iterates match positions via `set_indices` instead of scanning every bit. +fn general_replace_with_scalar( + list_array: &GenericListArray, + needle: &Scalar, + to_array: &ArrayRef, + max_replacements: i64, +) -> Result { + let first_offset = list_array.offsets()[0].to_usize().unwrap(); + let last_offset = list_array.offsets()[list_array.len()].to_usize().unwrap(); + let visible_values = list_array + .values() + .slice(first_offset, last_offset - first_offset); - // replace at most one occurrence for each element - let arr_n = vec![1; array.len()]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let original_data = visible_values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut offsets = OffsetBufferBuilder::::new(list_array.len()); + + // Single bulk comparison over the visible values only. + let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?; + let match_bits = match_bitmap.values(); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + // Offsets relative to visible_values (subtract first_offset). + let start = offset_window[0].to_usize().unwrap() - first_offset; + let end = offset_window[1].to_usize().unwrap() - first_offset; + let row_len = end - start; + + if list_array.is_null(row_index) { + offsets.push_length(0); + continue; } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + + if max_replacements <= 0 { + mutable.extend(0, start, end); + offsets.push_length(row_len); + continue; } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type}'."), + + // Slice the match bits to this row and iterate only over true positions. + let row_bits = match_bits.slice(start, row_len); + let mut match_positions = row_bits + .set_indices() + .take(max_replacements as usize) + .peekable(); + if match_positions.peek().is_none() { + mutable.extend(0, start, end); + offsets.push_length(row_len); + continue; + } + + // Iterate only over the positions that match using set_indices, + // which is more efficient than scanning every bit because the number + // of matches is typically much smaller than the total array size. + let mut prev_end = 0usize; + for match_pos in match_positions { + // Retain elements before this match. + if match_pos > prev_end { + mutable.extend(0, start + prev_end, start + match_pos); + } + // Emit the replacement element. + mutable.extend(1, 0, 1); + prev_end = match_pos + 1; + } + + // Copy remaining elements after the last replacement. + if prev_end < row_len { + mutable.extend(0, start + prev_end, end); + } + + offsets.push_length(row_len); } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type(), true)), + offsets.finish(), + arrow::array::make_array(data), + list_array.nulls().cloned(), + )?)) } -fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - let [array, from, to, max] = take_function_args("array_replace_n", args)?; +/// Fast path for `array_replace` when all arguments are scalars. +/// +/// Uses a single bulk `not_distinct` comparison instead of per-row comparisons. +fn array_replace_with_scalar_args( + list_array: &ArrayRef, + scalar_from: &ScalarValue, + scalar_to: &ScalarValue, + max_replacements: i64, +) -> Result { + // `not_distinct` doesn't support nested types, fall back to the generic array path. + if scalar_from.data_type().is_nested() { + let num_rows = list_array.len(); + let from_array = scalar_from.to_array_of_size(num_rows)?; + let to_array = scalar_to.to_array_of_size(num_rows)?; + return array_replace_internal( + list_array, + &from_array, + &to_array, + &vec![max_replacements; num_rows], + ); + } - // replace the specified number of occurrences - let arr_n = as_int64_array(max)?.values().to_vec(); - match array.data_type() { + let needle = Scalar::new(scalar_from.to_array_of_size(1)?); + let to_array = scalar_to.to_array_of_size(1)?; + match list_array.data_type() { DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, &to_array, max_replacements) } DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) - } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_n does not support type '{array_type}'.") + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, &to_array, max_replacements) } + DataType::Null => Ok(new_null_array(list_array.data_type(), 1)), + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } -fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace_all", args)?; - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; array.len()]; +fn array_replace_internal( + array: &ArrayRef, + from: &ArrayRef, + to: &ArrayRef, + arr_n: &[i64], +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_all does not support type '{array_type}'.") - } + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } diff --git a/datafusion/sqllogictest/test_files/array/array_replace.slt b/datafusion/sqllogictest/test_files/array/array_replace.slt index 390ed4b946520..1ebcdd4142853 100644 --- a/datafusion/sqllogictest/test_files/array/array_replace.slt +++ b/datafusion/sqllogictest/test_files/array/array_replace.slt @@ -212,6 +212,33 @@ from large_nested_arrays_with_repeating_elements; [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] +# array_replace scalar arguments over multiple input rows +query ??? +select + array_replace(column1, 2, 9), + array_replace_n(column1, 2, 9, 2), + array_replace_all(column1, 2, 9) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 2, 3] [1, 9, 9, 3] [1, 9, 9, 3] +[9, 4, 2] [9, 4, 9] [9, 4, 9] + +# array_replace_n scalar max exceeding matches over multiple input rows +query ? +select array_replace_n(column1, 2, 9, 10) +from ( + values + (make_array(1, 2, 2, 3)), + (make_array(2, 4, 2)) +) as t(column1); +---- +[1, 9, 9, 3] +[9, 4, 9] + ## array_replace_n (aliases: `list_replace_n`) # array_replace_n scalar function #1 @@ -226,22 +253,35 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ???? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2), array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3), - array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0); + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'LargeList(Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] -query ??? +query ?????? select array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2), array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2), - array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3); + array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, 0), + array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, -1), + array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'FixedSizeList(4, Int64)'), 1, 0, 10); ---- -[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5] + +# array_replace_n scalar max exceeding matches for empty arrays +query ?? +select + array_replace_n(arrow_cast(make_array(), 'List(Int64)'), 2, 9, 10), + array_replace_n(arrow_cast(make_array(), 'LargeList(Int64)'), 2, 9, 10); +---- +[] [] # array_replace_n scalar function #2 (element is list) query ?? @@ -323,6 +363,13 @@ select array_replace_n(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)') ---- [1, 2, 3, 4, 5] +query ?? +select + array_replace_n(make_array(1, 2, 2), 2, 9, NULL), + array_replace_n(arrow_cast(make_array(1, 2, 2), 'LargeList(Int64)'), 2, 9, NULL); +---- +[1, 2, 2] [1, 2, 2] + # array_replace_n scalar function with columns #1 query ? select @@ -657,6 +704,14 @@ select column1, column2, column3, column4, array_replace_n(column1, column2, col NULL 3 2 1 NULL [3, 1, 3] 3 NULL 1 [NULL, 1, 3] +query ??? +select + array_replace(make_array(3, NULL, NULL), NULL, 5), + array_replace_n(make_array(3, NULL, NULL), NULL, 5, 10), + array_replace_all(make_array(3, NULL, NULL), NULL, 5); +---- +[3, 5, NULL] [3, 5, 5] [3, 5, 5] + statement ok