Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion datafusion/catalog/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use datafusion_common::{Result, internal_err};
use datafusion_expr::Expr;
use datafusion_expr::statistics::StatisticsRequest;

use datafusion_expr::dml::InsertOp;
use datafusion_expr::dml::{InsertOp, MergeIntoClause};
use datafusion_expr::{
CreateExternalTable, LogicalPlan, TableProviderFilterPushDown, TableType,
};
Expand Down Expand Up @@ -379,6 +379,23 @@ pub trait TableProvider: Any + Debug + Sync + Send {
async fn truncate(&self, _state: &dyn Session) -> Result<Arc<dyn ExecutionPlan>> {
    // Default implementation: providers that can truncate override this method;
    // every other provider reports the operation as not implemented.
    let table_type = self.table_type();
    not_impl_err!("TRUNCATE not supported for {} table", table_type)
}

/// Execute a `MERGE INTO` against this table.
///
/// # Arguments
///
/// * `_state` - the session the statement runs in.
/// * `_source` - [`ExecutionPlan`] for the USING clause, i.e. the rows to
///   merge from.
/// * `_on` - join predicate taken from the ON clause.
/// * `_clauses` - the WHEN MATCHED / WHEN NOT MATCHED actions to apply.
///
/// # Returns
///
/// An [`ExecutionPlan`] producing a single row with `count` (UInt64).
///
/// # Errors
///
/// The default implementation always returns a "not implemented" error
/// naming this provider's table type; providers that support MERGE
/// override this method.
async fn merge_into(
    &self,
    _state: &dyn Session,
    _source: Arc<dyn ExecutionPlan>,
    _on: Expr,
    _clauses: Vec<MergeIntoClause>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let table_type = self.table_type();
    not_impl_err!("MERGE INTO not supported for {} table", table_type)
}
}

impl dyn TableProvider {
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,26 @@ impl DefaultPhysicalPlanner {
);
}
}
LogicalPlan::Dml(DmlStatement {
    table_name,
    target,
    op: WriteOp::MergeInto(merge_op),
    ..
}) => {
    // MERGE INTO is delegated to the target table's provider. The single
    // child plan is the already-planned input of the DML statement, and the
    // ON predicate and WHEN clauses are handed over as owned copies.
    let provider = source_as_provider(target)?;
    let source_plan = children.one()?;
    let on = merge_op.on.clone();
    let clauses = merge_op.clauses.clone();
    let planned = provider
        .merge_into(session_state, source_plan, on, clauses)
        .await;
    // Attach the table name so a provider failure is attributable.
    planned.map_err(|e| {
        e.context(format!("MERGE INTO operation on table '{table_name}'"))
    })?
}
LogicalPlan::Window(Window { window_expr, .. }) => {
assert_or_internal_err!(
!window_expr.is_empty(),
Expand Down
151 changes: 150 additions & 1 deletion datafusion/expr/src/logical_plan/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{DFSchemaRef, TableReference};
use datafusion_common::{DFSchemaRef, Result, TableReference, internal_err};

use crate::{Expr, LogicalPlan, TableSource};

Expand Down Expand Up @@ -307,6 +307,106 @@ pub struct MergeIntoOp {
pub clauses: Vec<MergeIntoClause>,
}

impl MergeIntoOp {
    /// Number of top-level [`Expr`]s held by this operation, computed
    /// without allocating.
    ///
    /// This equals the length of the vector returned by [`Self::exprs`] and
    /// is the length [`Self::with_new_exprs`] requires of its input.
    fn expr_count(&self) -> usize {
        // Start at one for the `on` predicate, which is always present.
        let mut total = 1;
        for clause in &self.clauses {
            if clause.predicate.is_some() {
                total += 1;
            }
            total += match &clause.action {
                MergeIntoAction::Update(assignments) => assignments.len(),
                MergeIntoAction::Insert { values, .. } => values.len(),
                MergeIntoAction::Delete => 0,
            };
        }
        total
    }

    /// Borrow every top-level [`Expr`] in a stable order.
    ///
    /// The order is: `on` first, then for each clause its predicate (when
    /// present) followed by the value expressions of its action. `Delete`
    /// actions contribute no expressions.
    pub fn exprs(&self) -> Vec<&Expr> {
        let mut collected: Vec<&Expr> = Vec::with_capacity(self.expr_count());
        collected.push(&self.on);
        for clause in &self.clauses {
            // An absent predicate contributes nothing.
            collected.extend(clause.predicate.iter());
            match &clause.action {
                MergeIntoAction::Update(assignments) => {
                    for (_, value) in assignments {
                        collected.push(value);
                    }
                }
                MergeIntoAction::Insert { values, .. } => {
                    collected.extend(values);
                }
                MergeIntoAction::Delete => {}
            }
        }
        collected
    }

    /// Build a new `MergeIntoOp` whose expressions are taken from `exprs`.
    ///
    /// `exprs` must be laid out in the order produced by [`Self::exprs`].
    /// Everything that is not an expression is copied from `self`: clause
    /// kinds, action kinds, assignment and insert column names, and whether
    /// each clause has a predicate.
    ///
    /// # Errors
    ///
    /// Returns an internal error when `exprs.len()` differs from the number
    /// of expressions this operation holds.
    pub fn with_new_exprs(&self, exprs: Vec<Expr>) -> Result<Self> {
        let expected = self.expr_count();
        if exprs.len() != expected {
            return internal_err!(
                "MergeIntoOp::with_new_exprs expected {expected} expressions, got {}",
                exprs.len()
            );
        }

        // The length check above guarantees exactly one expression is
        // available for every slot consumed below.
        let mut remaining = exprs.into_iter();
        let mut take = || remaining.next().expect("non-empty by length check");

        let on = take();
        let mut clauses = Vec::with_capacity(self.clauses.len());
        for clause in &self.clauses {
            let predicate = if clause.predicate.is_some() {
                Some(take())
            } else {
                None
            };
            let action = match &clause.action {
                MergeIntoAction::Update(assignments) => {
                    let mut rebuilt = Vec::with_capacity(assignments.len());
                    for (name, _) in assignments {
                        rebuilt.push((name.clone(), take()));
                    }
                    MergeIntoAction::Update(rebuilt)
                }
                MergeIntoAction::Insert { columns, values } => {
                    let new_values = values.iter().map(|_| take()).collect();
                    MergeIntoAction::Insert {
                        columns: columns.clone(),
                        values: new_values,
                    }
                }
                MergeIntoAction::Delete => MergeIntoAction::Delete,
            };
            clauses.push(MergeIntoClause {
                kind: clause.kind,
                predicate,
                action,
            });
        }
        Ok(Self { on, clauses })
    }
}

/// A single WHEN clause within a MERGE INTO statement.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct MergeIntoClause {
Expand Down Expand Up @@ -445,4 +545,53 @@ mod tests {
MergeIntoClauseKind::NotMatchedBySource
);
}

#[test]
fn merge_into_op_exprs_round_trip() {
    // UPDATE clause: one predicate plus two assignment values.
    let update_clause = MergeIntoClause {
        kind: MergeIntoClauseKind::Matched,
        predicate: Some(col("qty").gt(lit(0_i64))),
        action: MergeIntoAction::Update(vec![
            ("qty".to_string(), col("source_qty")),
            ("price".to_string(), col("source_price")),
        ]),
    };
    // INSERT clause: no predicate, two inserted values.
    let insert_clause = MergeIntoClause {
        kind: MergeIntoClauseKind::NotMatched,
        predicate: None,
        action: MergeIntoAction::Insert {
            columns: vec!["id".to_string(), "qty".to_string()],
            values: vec![col("source_id"), col("source_qty")],
        },
    };
    // DELETE clause: one predicate and no action expressions.
    let delete_clause = MergeIntoClause {
        kind: MergeIntoClauseKind::NotMatchedBySource,
        predicate: Some(col("active").eq(lit(true))),
        action: MergeIntoAction::Delete,
    };
    let op = MergeIntoOp {
        on: col("id").eq(col("source_id")),
        clauses: vec![update_clause, insert_clause, delete_clause],
    };

    // 1 (on) + 3 (update) + 2 (insert) + 1 (delete) = 7 expressions.
    let borrowed = op.exprs();
    assert_eq!(borrowed.len(), 7);

    // Feeding the same expressions back must reproduce an equal operation.
    let owned: Vec<Expr> = borrowed.into_iter().cloned().collect();
    let rebuilt = op.with_new_exprs(owned).unwrap();
    assert_eq!(op, rebuilt);
}

#[test]
fn merge_into_op_with_new_exprs_length_mismatch() {
    // With no clauses the operation still owns its `on` expression, so an
    // empty replacement vector is one expression short.
    let op = MergeIntoOp {
        on: col("id").eq(col("source_id")),
        clauses: Vec::new(),
    };
    let err = op.with_new_exprs(Vec::new()).unwrap_err();
    let message = err.to_string();
    assert!(
        message.contains("expected 1 expressions, got 0"),
        "unexpected error: {err}"
    );
}
}
14 changes: 11 additions & 3 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::expr_rewriter::{
};
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::logical_plan::{DmlStatement, Statement};
use crate::logical_plan::{DmlStatement, Statement, WriteOp};
use crate::utils::{
enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs,
grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction,
Expand Down Expand Up @@ -810,12 +810,20 @@ impl LogicalPlan {
op,
..
}) => {
self.assert_no_expressions(expr)?;
let input = self.only_input(inputs)?;
let op = match op {
WriteOp::MergeInto(merge_op) => {
WriteOp::MergeInto(Box::new(merge_op.with_new_exprs(expr)?))
}
other => {
self.assert_no_expressions(expr)?;
other.clone()
}
};
Ok(LogicalPlan::Dml(DmlStatement::new(
table_name.clone(),
Arc::clone(target),
op.clone(),
op,
Arc::new(input),
)))
}
Expand Down
27 changes: 26 additions & 1 deletion datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{
DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, Limit,
LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort,
Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode,
Values, Window, builder::unnest_with_options, dml::CopyTo,
Values, Window, WriteOp, builder::unnest_with_options, dml::CopyTo,
};
use datafusion_common::tree_node::TreeNodeRefContainer;

Expand Down Expand Up @@ -480,6 +480,10 @@ impl LogicalPlan {
}
_ => Ok(TreeNodeRecursion::Continue),
},
LogicalPlan::Dml(DmlStatement {
    op: WriteOp::MergeInto(merge_op),
    ..
}) => {
    // Visit the ON predicate and every per-clause expression, in the
    // order produced by `MergeIntoOp::exprs`.
    let merge_exprs = merge_op.exprs();
    merge_exprs.apply_ref_elements(f)
}
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
Expand Down Expand Up @@ -719,6 +723,27 @@ impl LogicalPlan {
)
})?
}
LogicalPlan::Dml(DmlStatement {
    table_name,
    target,
    op: WriteOp::MergeInto(merge_op),
    input,
    output_schema,
}) => {
    // Flatten the MERGE expressions into an owned vector, rewrite them
    // with `f`, then rebuild the operation around the rewritten list.
    let current: Vec<Expr> = merge_op.exprs().into_iter().cloned().collect();
    let mapped = current.map_elements(f)?;
    mapped.transform_data(|new_exprs| {
        let rebuilt = merge_op.with_new_exprs(new_exprs)?;
        let statement = DmlStatement {
            table_name,
            target,
            op: WriteOp::MergeInto(Box::new(rebuilt)),
            input,
            output_schema,
        };
        // `Transformed::no` adds no change of its own; this mirrors the
        // original wrapping of the rebuilt plan.
        Ok(Transformed::no(LogicalPlan::Dml(statement)))
    })?
}
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
Expand Down
Loading