Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 215 additions & 45 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,20 @@

use arrow::array::{
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait, new_null_array,
NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array,
};
use arrow::datatypes::{DataType, Field};

use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, utils::take_function_args};
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;

use crate::utils::compare_element_to_list;
use crate::utils::make_scalar_function;

use std::sync::Arc;

Expand Down Expand Up @@ -125,7 +123,28 @@ impl ScalarUDFImpl for ArrayReplace {
}

/// Evaluates `array_replace`, allowing at most one replacement per list row.
///
/// The list argument is always materialized to an array of
/// `args.number_rows` rows. When both `from` and `to` are scalars the call
/// is routed to `array_replace_with_scalar_args`; otherwise `from` and `to`
/// are expanded to arrays and handled by `array_replace_internal`.
///
/// # Errors
///
/// Propagates any error from argument-count validation, argument
/// materialization, or the replacement kernels.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
    let num_rows = args.number_rows;
    let list_array = list_arg.to_array(num_rows)?;

    let result = match (from_arg, to_arg) {
        // All-scalar `from`/`to`: use the bulk-comparison path.
        (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
            array_replace_with_scalar_args(&list_array, scalar_from, scalar_to, 1i64)?
        }
        // At least one of `from`/`to` is an array: use the per-row path.
        (from_arg, to_arg) => {
            let from_array = from_arg.to_array(num_rows)?;
            let to_array = to_arg.to_array(num_rows)?;
            // One replacement allowed for every row.
            let arr_n = vec![1i64; num_rows];
            array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?
        }
    };

    Ok(ColumnarValue::Array(result))
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -200,7 +219,41 @@ impl ScalarUDFImpl for ArrayReplaceN {
}

/// Evaluates `array_replace_n`, replacing up to `max` occurrences per list row.
///
/// The list argument is always materialized to an array of
/// `args.number_rows` rows. When `from`, `to` and `max` are all scalars the
/// call is routed to `array_replace_with_scalar_args`; otherwise `from` and
/// `to` are expanded to arrays and handled by `array_replace_internal`.
///
/// # Errors
///
/// Propagates any error from argument-count validation, argument
/// materialization, the `Int64` downcast of `max`, or the replacement
/// kernels.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let [list_arg, from_arg, to_arg, max_arg] =
        take_function_args(self.name(), &args.args)?;
    let num_rows = args.number_rows;
    let list_array = list_arg.to_array(num_rows)?;

    let result = match (from_arg, to_arg, max_arg) {
        // Everything but the list is scalar: use the bulk-comparison path.
        (
            ColumnarValue::Scalar(scalar_from),
            ColumnarValue::Scalar(scalar_to),
            ColumnarValue::Scalar(scalar_max),
        ) => {
            let max_array = scalar_max.to_array_of_size(1)?;
            let n = as_int64_array(&max_array)?.value(0);
            array_replace_with_scalar_args(&list_array, scalar_from, scalar_to, n)?
        }
        // Mixed scalar/array arguments: use the per-row path.
        (from_arg, to_arg, max_arg) => {
            let from_array = from_arg.to_array(num_rows)?;
            let to_array = to_arg.to_array(num_rows)?;
            let arr_n = match max_arg {
                // A scalar `max` becomes a one-element slice; `general_replace`
                // applies a length-1 `arr_n` to every row.
                ColumnarValue::Scalar(scalar_max) => {
                    let max_array = scalar_max.to_array_of_size(1)?;
                    as_int64_array(&max_array)?.values().to_vec()
                }
                ColumnarValue::Array(max_array) => {
                    as_int64_array(max_array)?.values().to_vec()
                }
            };
            array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?
        }
    };

    Ok(ColumnarValue::Array(result))
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -273,7 +326,28 @@ impl ScalarUDFImpl for ArrayReplaceAll {
}

/// Evaluates `array_replace_all`, using `i64::MAX` as the per-row
/// replacement limit so that every occurrence is replaced.
///
/// The list argument is always materialized to an array of
/// `args.number_rows` rows. When both `from` and `to` are scalars the call
/// is routed to `array_replace_with_scalar_args`; otherwise `from` and `to`
/// are expanded to arrays and handled by `array_replace_internal`.
///
/// # Errors
///
/// Propagates any error from argument-count validation, argument
/// materialization, or the replacement kernels.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
    let num_rows = args.number_rows;
    let list_array = list_arg.to_array(num_rows)?;

    let result = match (from_arg, to_arg) {
        // All-scalar `from`/`to`: use the bulk-comparison path.
        (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
            array_replace_with_scalar_args(
                &list_array,
                scalar_from,
                scalar_to,
                i64::MAX,
            )?
        }
        // At least one of `from`/`to` is an array: use the per-row path.
        (from_arg, to_arg) => {
            let from_array = from_arg.to_array(num_rows)?;
            let to_array = to_arg.to_array(num_rows)?;
            // Effectively unlimited replacements for every row.
            let arr_n = vec![i64::MAX; num_rows];
            array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?
        }
    };

    Ok(ColumnarValue::Array(result))
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -343,7 +417,11 @@ fn general_replace<O: OffsetSizeTrait>(

let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let n = if arr_n.len() == 1 {
arr_n[0]
} else {
arr_n[row_index]
};
let mut counter = 0;

// All elements are false, no need to replace, just copy original data
Expand Down Expand Up @@ -412,63 +490,155 @@ fn general_replace<O: OffsetSizeTrait>(
)?))
}

fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace", args)?;
/// Replaces up to `max_replacements` occurrences of `needle` with the single
/// element in `to_array` for each row in `list_array`.
///
/// This is a specialized fast path for the all-scalar case that uses a single
/// bulk `not_distinct` comparison over only the visible values range, then
/// iterates match positions via `set_indices` instead of scanning every bit.
fn general_replace_with_scalar<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
needle: &Scalar<ArrayRef>,
to_array: &ArrayRef,
max_replacements: i64,
) -> Result<ArrayRef> {
let first_offset = list_array.offsets()[0].to_usize().unwrap();
let last_offset = list_array.offsets()[list_array.len()].to_usize().unwrap();
let visible_values = list_array
.values()
.slice(first_offset, last_offset - first_offset);

// replace at most one occurrence for each element
let arr_n = vec![1; array.len()];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let original_data = visible_values.to_data();
let to_data = to_array.to_data();
let capacity = Capacities::Array(original_data.len());

let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &to_data],
false,
capacity,
);

let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());

// Single bulk comparison over the visible values only.
let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?;
let match_bits = match_bitmap.values();

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
// Offsets relative to visible_values (subtract first_offset).
let start = offset_window[0].to_usize().unwrap() - first_offset;
let end = offset_window[1].to_usize().unwrap() - first_offset;
let row_len = end - start;

if list_array.is_null(row_index) {
offsets.push_length(0);
continue;
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)

if max_replacements <= 0 {
mutable.extend(0, start, end);
offsets.push_length(row_len);
continue;
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),

// Slice the match bits to this row and iterate only over true positions.
let row_bits = match_bits.slice(start, row_len);
let mut match_positions = row_bits
.set_indices()
.take(max_replacements as usize)
.peekable();
if match_positions.peek().is_none() {
mutable.extend(0, start, end);
offsets.push_length(row_len);
continue;
}

// Iterate only over the positions that match using set_indices,
// which is more efficient than scanning every bit because the number
// of matches is typically much smaller than the total array size.
let mut prev_end = 0usize;
for match_pos in match_positions {
// Retain elements before this match.
if match_pos > prev_end {
mutable.extend(0, start + prev_end, start + match_pos);
}
// Emit the replacement element.
mutable.extend(1, 0, 1);
prev_end = match_pos + 1;
}

// Copy remaining elements after the last replacement.
if prev_end < row_len {
mutable.extend(0, start + prev_end, end);
}

offsets.push_length(row_len);
}

let data = mutable.freeze();

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(list_array.value_type(), true)),
offsets.finish(),
arrow::array::make_array(data),
list_array.nulls().cloned(),
)?))
}

fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to, max] = take_function_args("array_replace_n", args)?;
/// Fast path for `array_replace` when all arguments are scalars.
///
/// Uses a single bulk `not_distinct` comparison instead of per-row comparisons.
fn array_replace_with_scalar_args(
list_array: &ArrayRef,
scalar_from: &ScalarValue,
scalar_to: &ScalarValue,
max_replacements: i64,
) -> Result<ArrayRef> {
// `not_distinct` doesn't support nested types, fall back to the generic array path.
if scalar_from.data_type().is_nested() {
let num_rows = list_array.len();
let from_array = scalar_from.to_array_of_size(num_rows)?;
let to_array = scalar_to.to_array_of_size(num_rows)?;
return array_replace_internal(
list_array,
&from_array,
&to_array,
&vec![max_replacements; num_rows],
);
}

// replace the specified number of occurrences
let arr_n = as_int64_array(max)?.values().to_vec();
match array.data_type() {
let needle = Scalar::new(scalar_from.to_array_of_size(1)?);
let to_array = scalar_to.to_array_of_size(1)?;
match list_array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
let list = list_array.as_list::<i32>();
general_replace_with_scalar::<i32>(list, &needle, &to_array, max_replacements)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_n does not support type '{array_type}'.")
let list = list_array.as_list::<i64>();
general_replace_with_scalar::<i64>(list, &needle, &to_array, max_replacements)
}
DataType::Null => Ok(new_null_array(list_array.data_type(), 1)),
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}

fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, from, to] = take_function_args("array_replace_all", args)?;

// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; array.len()];
fn array_replace_internal(
array: &ArrayRef,
from: &ArrayRef,
to: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, from, to, &arr_n)
general_replace::<i32>(list_array, from, to, arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, from, to, &arr_n)
general_replace::<i64>(list_array, from, to, arr_n)
}
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
array_type => {
exec_err!("array_replace_all does not support type '{array_type}'.")
}
array_type => exec_err!("array_replace does not support type '{array_type}'."),
}
}
Loading
Loading