Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,35 @@ public class RelCopyOnWriteVisitor<E extends Exception>

private final ExpressionCopyOnWriteVisitor<E> expressionCopyOnWriteVisitor;

/** Creates a visitor using a default expression visitor bound to this relation visitor. */
public RelCopyOnWriteVisitor() {
this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this);
}

/**
* Creates a visitor using the given expression visitor.
*
* @param expressionCopyOnWriteVisitor the expression visitor to delegate to
*/
public RelCopyOnWriteVisitor(ExpressionCopyOnWriteVisitor<E> expressionCopyOnWriteVisitor) {
this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor;
}

/**
* Creates a visitor whose expression visitor is built from this instance by the given factory.
*
* @param fn factory producing the expression visitor from this relation visitor
*/
public RelCopyOnWriteVisitor(
Function<RelCopyOnWriteVisitor<E>, ExpressionCopyOnWriteVisitor<E>> fn) {
this.expressionCopyOnWriteVisitor = fn.apply(this);
}

/**
* Returns the expression visitor used to rewrite expressions within relations.
*
* @return the expression copy-on-write visitor
*/
protected ExpressionCopyOnWriteVisitor<E> getExpressionCopyOnWriteVisitor() {
return expressionCopyOnWriteVisitor;
}
Expand All @@ -69,12 +85,28 @@ public Optional<Rel> visit(Aggregate aggregate, EmptyVisitationContext context)
.build());
}

/**
* Rewrites an aggregate grouping, returning a new grouping if any expression changed.
*
* @param grouping the grouping to rewrite
* @param context the visitation context
* @return the rewritten grouping, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<Aggregate.Grouping> visitGrouping(
Aggregate.Grouping grouping, EmptyVisitationContext context) throws E {
return visitExprList(grouping.getExpressions(), context)
.map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build());
}

/**
* Rewrites an aggregate measure, returning a new measure if anything changed.
*
* @param measure the measure to rewrite
* @param context the visitation context
* @return the rewritten measure, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<Aggregate.Measure> visitMeasure(
Aggregate.Measure measure, EmptyVisitationContext context) throws E {
Optional<Expression> preMeasureFilter =
Expand All @@ -93,6 +125,14 @@ protected Optional<Aggregate.Measure> visitMeasure(
.build());
}

/**
* Rewrites an aggregate function invocation, returning a new one if anything changed.
*
* @param afi the aggregate function invocation to rewrite
* @param context the visitation context
* @return the rewritten invocation, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<AggregateFunctionInvocation> visitAggregateFunction(
AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E {
Optional<List<FunctionArg>> arguments = visitFunctionArguments(afi.arguments(), context);
Expand Down Expand Up @@ -232,6 +272,14 @@ public Optional<Rel> visit(ExtensionDdl ddl, EmptyVisitationContext context) thr
throw new UnsupportedOperationException();
}

/**
* Rewrites a named-update transform expression, returning a new one if it changed.
*
* @param transform the transform expression to rewrite
* @param context the visitation context
* @return the rewritten transform expression, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<NamedUpdate.TransformExpression> visitTransformExpression(
NamedUpdate.TransformExpression transform, EmptyVisitationContext context) throws E {
return transform
Expand Down Expand Up @@ -533,6 +581,14 @@ public Optional<Rel> visit(
.build());
}

/**
* Rewrites a window relation function invocation, returning a new one if anything changed.
*
* @param windowRelFunctionInvocation the window relation function invocation to rewrite
* @param context the visitation context
* @return the rewritten invocation, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> visitWindowRelFunction(
ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation,
EmptyVisitationContext context)
Expand All @@ -553,11 +609,27 @@ protected Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> visitW

// utilities

/**
* Rewrites a list of expressions, returning a new list if any expression changed.
*
* @param exprs the expressions to rewrite
* @param context the visitation context
* @return the rewritten list, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<List<Expression>> visitExprList(
List<Expression> exprs, EmptyVisitationContext context) throws E {
return transformList(exprs, context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c));
}

/**
* Rewrites a field reference, returning a new one if its input expression changed.
*
* @param fieldReference the field reference to rewrite
* @param context the visitation context
* @return the rewritten field reference, or empty if unchanged
* @throws E if the visit fails
*/
public Optional<FieldReference> visitFieldReference(
FieldReference fieldReference, EmptyVisitationContext context) throws E {
Optional<Expression> inputExpression =
Expand All @@ -569,6 +641,14 @@ public Optional<FieldReference> visitFieldReference(
return Optional.of(FieldReference.builder().inputExpression(inputExpression).build());
}

/**
* Rewrites a list of function arguments, returning a new list if any argument changed.
*
* @param funcArgs the function arguments to rewrite
* @param context the visitation context
* @return the rewritten list, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<List<FunctionArg>> visitFunctionArguments(
List<FunctionArg> funcArgs, EmptyVisitationContext context) throws E {
return CopyOnWriteUtils.<FunctionArg, EmptyVisitationContext, E>transformList(
Expand All @@ -585,6 +665,14 @@ protected Optional<List<FunctionArg>> visitFunctionArguments(
});
}

/**
* Rewrites a sort field, returning a new one if its expression changed.
*
* @param sortField the sort field to rewrite
* @param context the visitation context
* @return the rewritten sort field, or empty if unchanged
* @throws E if the visit fails
*/
protected Optional<Expression.SortField> visitSortField(
Expression.SortField sortField, EmptyVisitationContext context) throws E {
return sortField
Expand Down
Loading