Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.DateTrunc;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
Expand All @@ -58,6 +59,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -69,6 +71,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -198,8 +201,7 @@ public Void visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, PartitionInc
Set<Set<Slot>> shuttledEqualSlotSet = context.getShuttledEqualSlotSet();
for (Set<Slot> equalSlotSet : shuttledEqualSlotSet) {
if (equalSlotSet.contains(consumerSlot)) {
Expression shuttledSlot = ExpressionUtils.shuttleExpressionWithLineage(
producerSlot, producerPlan);
Expression shuttledSlot = context.shuttleExpressionWithLineage(producerSlot, producerPlan);
if (shuttledSlot instanceof Slot) {
equalSlotSet.add((Slot) shuttledSlot);
}
Expand Down Expand Up @@ -239,7 +241,7 @@ public Void visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join,
continue;
}
Pair<Set<Slot>, Set<Slot>> partitionEqualSlotPair =
calEqualSet((SlotReference) partitionSlotToCheck, join);
calEqualSet((SlotReference) partitionSlotToCheck, join, context);
if (!partitionEqualSlotPair.value().isEmpty()) {
context.getShuttledEqualSlotSet().add(partitionEqualSlotPair.value());
}
Expand Down Expand Up @@ -526,31 +528,24 @@ private Set<SlotReference> getPartitionColumnsToCheck(PartitionIncrementCheckCon
*/
private static boolean checkPartition(Collection<? extends Expression> expressionsToCheck, Plan plan,
PartitionIncrementCheckContext context) {
Set<Entry<NamedExpression, RelatedTableColumnInfo>> partitionAndExprEntrySet
= new HashSet<>(context.getPartitionAndRefExpressionMap().entrySet());
List<Entry<NamedExpression, RelatedTableColumnInfo>> partitionAndExprEntryList
= new ArrayList<>(context.getPartitionAndRefExpressionMap().entrySet());
List<Expression> partitionExpressions = new ArrayList<>(partitionAndExprEntryList.size());
for (Entry<NamedExpression, RelatedTableColumnInfo> entry : partitionAndExprEntryList) {
partitionExpressions.add(entry.getValue().getPartitionExpression().orElse(entry.getKey()));
}
List<? extends Expression> partitionExpressionActualList =
context.shuttleAndNormalizeExpressionWithLineage(partitionExpressions, context.getOriginalPlan());
List<? extends Expression> expressionsShuttledToCheck =
context.shuttleAndNormalizeExpressionWithLineage(expressionsToCheck, context.getOriginalPlan());
boolean checked = false;
for (Map.Entry<NamedExpression, RelatedTableColumnInfo> partitionExpressionEntry
: partitionAndExprEntrySet) {
NamedExpression partitionNamedExpression = partitionExpressionEntry.getKey();
for (int i = 0; i < partitionAndExprEntryList.size(); i++) {
Map.Entry<NamedExpression, RelatedTableColumnInfo> partitionExpressionEntry =
partitionAndExprEntryList.get(i);
RelatedTableColumnInfo partitionTableColumnInfo = partitionExpressionEntry.getValue();
Optional<Expression> partitionExpressionOpt = partitionTableColumnInfo.getPartitionExpression();
Expression partitionExpressionActual = partitionExpressionOpt
.map(expr -> ExpressionUtils.shuttleExpressionWithLineage(expr,
context.getOriginalPlan()))
.orElseGet(() -> ExpressionUtils.shuttleExpressionWithLineage(partitionNamedExpression,
context.getOriginalPlan()));
// merge date_trunc
partitionExpressionActual = new ExpressionNormalization().rewrite(partitionExpressionActual,
new ExpressionRewriteContext(context.getCascadesContext()));
Expression partitionExpressionActual = partitionExpressionActualList.get(i);
OUTER_CHECK:
for (Expression projectSlotToCheck : expressionsToCheck) {
Expression expressionShuttledToCheck =
ExpressionUtils.shuttleExpressionWithLineage(projectSlotToCheck,
context.getOriginalPlan());
// merge date_trunc
expressionShuttledToCheck = new ExpressionNormalization().rewrite(expressionShuttledToCheck,
new ExpressionRewriteContext(context.getCascadesContext()));

for (Expression expressionShuttledToCheck : expressionsShuttledToCheck) {
Set<SlotReference> expressionToCheckSlots =
expressionShuttledToCheck.collectToSet(SlotReference.class::isInstance);
Set<SlotReference> partitionColumnSlots =
Expand Down Expand Up @@ -683,8 +678,19 @@ public static final class PartitionIncrementCheckContext {
private final Set<Set<Slot>> shuttledEqualSlotSet = new HashSet<>();
private final Map<CTEId, Plan> producerCteIdToPlanMap;
private final Plan originalPlan;
// Cache lineage-visible named expressions per plan identity to avoid repeated full plan walks.
private final Map<Plan, List<NamedExpression>> planLineageExpressionIndexes = new IdentityHashMap<>();
// Cache normalized expressions within this check context; normalization uses the same CascadesContext.
private final Map<Expression, Expression> normalizedExpressionMap = new IdentityHashMap<>();
// Reuse the normalization rewriter during one partition lineage check.
private final ExpressionNormalization expressionNormalization = new ExpressionNormalization();
// Reuse the rewrite context because all normalization in this checker shares the same CascadesContext.
private final ExpressionRewriteContext expressionRewriteContext;
private boolean failFast = false;

/**
* Construct partition increment check context.
*/
public PartitionIncrementCheckContext(NamedExpression mvPartitionColumn,
Expression mvPartitionExpression, Map<CTEId, Plan> producerCteIdToPlanMap,
Plan originalPlan,
Expand All @@ -694,6 +700,7 @@ public PartitionIncrementCheckContext(NamedExpression mvPartitionColumn,
this.cascadesContext = cascadesContext;
this.producerCteIdToPlanMap = producerCteIdToPlanMap;
this.originalPlan = originalPlan;
this.expressionRewriteContext = new ExpressionRewriteContext(cascadesContext);
}

public Set<String> getFailReasons() {
Expand Down Expand Up @@ -743,6 +750,64 @@ public Plan getOriginalPlan() {
return originalPlan;
}

private Expression shuttleExpressionWithLineage(Expression expression, Plan plan) {
return shuttleExpressionWithLineage(ImmutableList.of(expression), plan).get(0);
}

private List<? extends Expression> shuttleExpressionWithLineage(List<? extends Expression> expressions,
Plan plan) {
if (expressions.isEmpty()) {
return ImmutableList.of();
}
ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
new ExpressionLineageReplacer.ExpressionReplaceContext(expressions);
for (NamedExpression namedExpression : getLineageExpressionIndex(plan)) {
if (!replaceContext.getUsedExprIdSet().contains(namedExpression.getExprId())) {
continue;
}
namedExpression.accept(ExpressionLineageReplacer.NamedExpressionCollector.INSTANCE, replaceContext);
}
List<? extends Expression> replacedExpressions = replaceContext.getReplacedExpressions();
if (replacedExpressions == null || expressions.size() != replacedExpressions.size()) {
return ExpressionUtils.shuttleExpressionWithLineage(expressions, plan);
}
return replacedExpressions;
}

private List<? extends Expression> shuttleAndNormalizeExpressionWithLineage(
Collection<? extends Expression> expressions, Plan plan) {
if (expressions.isEmpty()) {
return ImmutableList.of();
}
List<? extends Expression> shuttledExpressions =
shuttleExpressionWithLineage(ImmutableList.copyOf(expressions), plan);
List<Expression> normalizedExpressions = new ArrayList<>(shuttledExpressions.size());
for (Expression expression : shuttledExpressions) {
normalizedExpressions.add(normalizeExpression(expression));
}
return normalizedExpressions;
}

private Expression normalizeExpression(Expression expression) {
Expression normalizedExpression = normalizedExpressionMap.get(expression);
if (normalizedExpression == null) {
normalizedExpression = expressionNormalization.rewrite(expression, expressionRewriteContext);
normalizedExpressionMap.put(expression, normalizedExpression);
}
return normalizedExpression;
}

private List<NamedExpression> getLineageExpressionIndex(Plan plan) {
List<NamedExpression> lineageExpressionIndex = planLineageExpressionIndexes.get(plan);
if (lineageExpressionIndex == null) {
List<NamedExpression> collectedIndex = new ArrayList<>();
plan.accept(LineageExpressionCollector.INSTANCE, collectedIndex);
lineageExpressionIndex = ImmutableList.copyOf(collectedIndex);
planLineageExpressionIndexes.put(plan, lineageExpressionIndex);
}
return lineageExpressionIndex;
}

/**
* collect invalid table set to check self join
*/
Expand Down Expand Up @@ -772,6 +837,25 @@ public Void visitLogicalCatalogRelation(LogicalCatalogRelation relation,
}
}

private static final class LineageExpressionCollector extends DefaultPlanVisitor<Void, List<NamedExpression>> {
private static final LineageExpressionCollector INSTANCE = new LineageExpressionCollector();

@Override
public Void visitGroupPlan(GroupPlan groupPlan, List<NamedExpression> lineageExpressionIndex) {
return null;
}

@Override
public Void visit(Plan plan, List<NamedExpression> lineageExpressionIndex) {
for (Expression expression : plan.getExpressions()) {
if (expression instanceof NamedExpression) {
lineageExpressionIndex.add((NamedExpression) expression);
}
}
return super.visit(plan, lineageExpressionIndex);
}
}

/**
* Add partitionEqualSlot to partitionAndRefExpressionToCheck if partitionExpression use the partitionSlot
*/
Expand Down Expand Up @@ -816,7 +900,8 @@ public Expression visitNamedExpression(NamedExpression namedExpression, Void con
* the value equal set contain the slot itself
*/
private static Pair<Set<Slot>, Set<Slot>> calEqualSet(Slot slot,
LogicalJoin<? extends Plan, ? extends Plan> join) {
LogicalJoin<? extends Plan, ? extends Plan> join,
PartitionIncrementCheckContext context) {
Set<Slot> partitionEqualSlotSet = new HashSet<>();
JoinType joinType = join.getJoinType();
if (joinType.isInnerJoin() || joinType.isSemiJoin()) {
Expand All @@ -829,7 +914,7 @@ private static Pair<Set<Slot>, Set<Slot>> calEqualSet(Slot slot,
}
List<Expression> extendedPartitionEqualSlotSet = new ArrayList<>(partitionEqualSlotSet);
extendedPartitionEqualSlotSet.add(slot);
List<? extends Expression> shuttledEqualExpressions = ExpressionUtils.shuttleExpressionWithLineage(
List<? extends Expression> shuttledEqualExpressions = context.shuttleExpressionWithLineage(
extendedPartitionEqualSlotSet, join);
for (Expression shuttledEqualExpression : shuttledEqualExpressions) {
Set<Slot> objects = shuttledEqualExpression.collectToSet(expr -> expr instanceof SlotReference);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,66 @@ public void test40() {
});
}

// CTE + union all + wide aggregate should keep partition lineage inside each plan boundary.
@Test
public void testCteUnionAllWideAggregatePartitionLineage() {
PlanChecker.from(connectContext)
.checkExplain("with union_src as (\n"
+ " select\n"
+ " L_SHIPDATE as part_date,\n"
+ " L_ORDERKEY as order_key,\n"
+ " L_QUANTITY as metric1,\n"
+ " L_EXTENDEDPRICE as metric2,\n"
+ " L_DISCOUNT as metric3,\n"
+ " L_TAX as metric4,\n"
+ " L_RETURNFLAG as flag\n"
+ " from lineitem\n"
+ " union all\n"
+ " select\n"
+ " O_ORDERDATE as part_date,\n"
+ " O_ORDERKEY as order_key,\n"
+ " O_TOTALPRICE as metric1,\n"
+ " O_TOTALPRICE as metric2,\n"
+ " O_TOTALPRICE as metric3,\n"
+ " O_TOTALPRICE as metric4,\n"
+ " O_ORDERSTATUS as flag\n"
+ " from orders\n"
+ "), wide_project as (\n"
+ " select\n"
+ " date_trunc(part_date, 'day') as part_day,\n"
+ " part_date,\n"
+ " order_key,\n"
+ " metric1,\n"
+ " metric2,\n"
+ " metric3,\n"
+ " metric4,\n"
+ " flag\n"
+ " from union_src\n"
+ ")\n"
+ "select\n"
+ " part_day,\n"
+ " flag,\n"
+ " count(*) as cnt,\n"
+ " sum(metric1) as sum_metric1,\n"
+ " sum(metric2) as sum_metric2,\n"
+ " sum(metric3) as sum_metric3,\n"
+ " sum(metric4) as sum_metric4,\n"
+ " max(order_key) as max_key,\n"
+ " min(order_key) as min_key\n"
+ "from wide_project\n"
+ "group by part_day, flag",
nereidsPlanner -> {
Plan rewrittenPlan = nereidsPlanner.getRewrittenPlan();
RelatedTableInfo relatedTableInfo =
MaterializedViewUtils.getRelatedTableInfos("part_day", null,
rewrittenPlan, nereidsPlanner.getCascadesContext());
successWith(relatedTableInfo, ImmutableSet.of(
ImmutableList.of("lineitem", "l_shipdate", "true", "true"),
ImmutableList.of("orders", "o_orderdate", "true", "true")),
"day");
});
}


// test with union but not union all
@Test
Expand Down
Loading
Loading