From 58464e935933a453101589d1e0d513c4d9374331 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Fri, 20 Mar 2026 14:39:50 +0100 Subject: [PATCH 1/4] fix(isthmus): std_dev, variance function mappings Signed-off-by: Niels Pardon --- .../io/substrait/dsl/SubstraitBuilder.java | 184 ++++++++++++++++++ .../substrait/isthmus/AggregateFunctions.java | 64 ++++-- .../isthmus/SubstraitRelNodeConverter.java | 4 +- .../isthmus/SubstraitRelVisitor.java | 122 ++++++++++++ .../AggregateFunctionConverter.java | 54 +++-- .../isthmus/expression/FunctionConverter.java | 95 ++++++++- .../isthmus/expression/FunctionMappings.java | 14 +- .../isthmus/AggregationFunctionsTest.java | 50 ++++- .../AggregateFunctionConverterTest.java | 82 ++++++++ 9 files changed, 639 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 9b26c6a2d..254a2edf9 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1361,6 +1361,190 @@ public Aggregate.Measure sum0(Expression expr) { R.I64); } + /** + * Creates a population standard deviation aggregate measure for a specific field. + * + *

Computes the standard deviation using the population formula (n denominator), which + * considers all values in the dataset as the entire population. This is equivalent to SQL's + * STDDEV_POP function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing population standard deviation with + * distribution=POPULATION option + */ + public Aggregate.Measure stddevPopulation(Rel input, int field) { + return stddevPopulation(fieldReference(input, field)); + } + + /** + * Creates a population standard deviation aggregate measure for an expression. + * + *

Computes the standard deviation using the population formula (n denominator), which + * considers all values in the dataset as the entire population. This is equivalent to SQL's + * STDDEV_POP function. + * + *

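The measure is created with the {@code std_dev} function of the arithmetic extension, a + * {@code distribution} option set to {@code POPULATION}, the nullable form of the expression's + * type as output type, aggregation phase {@code INITIAL_TO_RESULT} and invocation {@code ALL}. + * + *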

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing population standard deviation + */ + public Aggregate.Measure stddevPopulation(Expression expr) { + return statisticalAggregate(expr, "std_dev", "POPULATION"); + } + + /** + * Creates a sample standard deviation aggregate measure for a specific field. + * + *

Computes the standard deviation using the sample formula (n-1 denominator), which applies + * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV + * function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing sample standard deviation with distribution=SAMPLE + * option + */ + public Aggregate.Measure stddevSample(Rel input, int field) { + return stddevSample(fieldReference(input, field)); + } + + /** + * Creates a sample standard deviation aggregate measure for an expression. + * + *

Computes the standard deviation using the sample formula (n-1 denominator), which applies + * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV + * function. + * + *

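The measure is created with the {@code std_dev} function of the arithmetic extension, a + * {@code distribution} option set to {@code SAMPLE}, the nullable form of the expression's + * type as output type, aggregation phase {@code INITIAL_TO_RESULT} and invocation {@code ALL}. + * + *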

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing sample standard deviation + */ + public Aggregate.Measure stddevSample(Expression expr) { + return statisticalAggregate(expr, "std_dev", "SAMPLE"); + } + + /** + * Creates a population variance aggregate measure for a specific field. + * + *

Computes the variance using the population formula (n denominator), which considers all + * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing population variance with distribution=POPULATION option + */ + public Aggregate.Measure variancePopulation(Rel input, int field) { + return variancePopulation(fieldReference(input, field)); + } + + /** + * Creates a population variance aggregate measure for an expression. + * + *

Computes the variance using the population formula (n denominator), which considers all + * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function. + * + *

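The measure is created with the {@code variance} function of the arithmetic extension, a + * {@code distribution} option set to {@code POPULATION}, the nullable form of the expression's + * type as output type, aggregation phase {@code INITIAL_TO_RESULT} and invocation {@code ALL}. + * + *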

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing population variance + */ + public Aggregate.Measure variancePopulation(Expression expr) { + return statisticalAggregate(expr, "variance", "POPULATION"); + } + + /** + * Creates a sample variance aggregate measure for a specific field. + * + *

Computes the variance using the sample formula (n-1 denominator), which applies Bessel's + * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing sample variance with distribution=SAMPLE option + */ + public Aggregate.Measure varianceSample(Rel input, int field) { + return varianceSample(fieldReference(input, field)); + } + + /** + * Creates a sample variance aggregate measure for an expression. + * + *

Computes the variance using the sample formula (n-1 denominator), which applies Bessel's + * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function. + * + *

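The measure is created with the {@code variance} function of the arithmetic extension, a + * {@code distribution} option set to {@code SAMPLE}, the nullable form of the expression's + * type as output type, aggregation phase {@code INITIAL_TO_RESULT} and invocation {@code ALL}. + * + *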

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing sample variance + */ + public Aggregate.Measure varianceSample(Expression expr) { + return statisticalAggregate(expr, "variance", "SAMPLE"); + } + + /** + * Helper method to create statistical aggregate measures (std_dev, variance) with distribution + * option. + * + * @param expr the expression to aggregate + * @param functionName the Substrait function name ("std_dev" or "variance") + * @param distribution the distribution type ("SAMPLE" or "POPULATION") + * @return an aggregate measure with the specified distribution option + */ + private Aggregate.Measure statisticalAggregate( + Expression expr, String functionName, String distribution) { + String typeString = ToTypeString.apply(expr.getType()); + SimpleExtension.AggregateFunctionVariant declaration = + extensions.getAggregateFunction( + SimpleExtension.FunctionAnchor.of( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("%s:%s", functionName, typeString))); + FunctionOption distributionOption = + FunctionOption.builder().name("distribution").addValues(distribution).build(); + return measure( + AggregateFunctionInvocation.builder() + .arguments(Arrays.asList(expr)) + .outputType(TypeCreator.asNullable(expr.getType())) + .declaration(declaration) + .addOptions(distributionOption) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()); + } + private Aggregate.Measure singleArgumentArithmeticAggregate( Expression expr, String functionName, Type outputType) { String typeString = ToTypeString.apply(expr.getType()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 0d5d5bf0e..ffa55b652 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -29,6 +29,30 @@ public class AggregateFunctions { /** Substrait-specific AVG aggregate function (nullable return type). */ public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG); + /** + * Standard deviation (population) aggregate function. Maps to Substrait's std_dev function with + * distribution=POPULATION option. + */ + public static SqlAggFunction STDDEV_POP = new SubstraitAvgAggFunction(SqlKind.STDDEV_POP); + + /** + * Standard deviation (sample) aggregate function. Maps to Substrait's std_dev function with + * distribution=SAMPLE option. + */ + public static SqlAggFunction STDDEV_SAMP = new SubstraitAvgAggFunction(SqlKind.STDDEV_SAMP); + + /** + * Variance (population) aggregate function. Maps to Substrait's variance function with + * distribution=POPULATION option. + */ + public static SqlAggFunction VAR_POP = new SubstraitAvgAggFunction(SqlKind.VAR_POP); + + /** + * Variance (sample) aggregate function. Maps to Substrait's variance function with + * distribution=SAMPLE option. + */ + public static SqlAggFunction VAR_SAMP = new SubstraitAvgAggFunction(SqlKind.VAR_SAMP); + /** Substrait-specific SUM aggregate function (nullable return type). */ public static SqlAggFunction SUM = new SubstraitSumAggFunction(); @@ -42,18 +66,34 @@ public class AggregateFunctions { * @return optional containing Substrait equivalent if conversion applies */ public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { - if (aggFunction instanceof SqlMinMaxAggFunction) { - SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; - return Optional.of( - fun.getKind() == SqlKind.MIN ? 
AggregateFunctions.MIN : AggregateFunctions.MAX); - } else if (aggFunction instanceof SqlAvgAggFunction) { - return Optional.of(AggregateFunctions.AVG); - } else if (aggFunction instanceof SqlSumAggFunction) { - return Optional.of(AggregateFunctions.SUM); - } else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { - return Optional.of(AggregateFunctions.SUM0); - } else { - return Optional.empty(); + // First check by SqlKind to handle all statistical functions + SqlKind kind = aggFunction.getKind(); + switch (kind) { + case MIN: + return Optional.of(AggregateFunctions.MIN); + case MAX: + return Optional.of(AggregateFunctions.MAX); + case AVG: + return Optional.of(AggregateFunctions.AVG); + case STDDEV_POP: + return Optional.of(AggregateFunctions.STDDEV_POP); + case STDDEV_SAMP: + return Optional.of(AggregateFunctions.STDDEV_SAMP); + case VAR_POP: + return Optional.of(AggregateFunctions.VAR_POP); + case VAR_SAMP: + return Optional.of(AggregateFunctions.VAR_SAMP); + case SUM: + case SUM0: + // Check instance type for SUM variants + if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { + return Optional.of(AggregateFunctions.SUM0); + } else if (aggFunction instanceof SqlSumAggFunction) { + return Optional.of(AggregateFunctions.SUM); + } + return Optional.empty(); + default: + return Optional.empty(); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 52849f96f..ab1807066 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -398,7 +398,9 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { .collect(java.util.stream.Collectors.toList()); Optional operator = aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc( - measure.getFunction().declaration().key(), 
measure.getFunction().outputType()); + measure.getFunction().declaration().key(), + measure.getFunction().outputType(), + measure.getFunction().options()); if (!operator.isPresent()) { throw new IllegalArgumentException( String.format( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bc4d0a9f2..979f75195 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -36,6 +36,7 @@ import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; import io.substrait.type.Type; +import io.substrait.type.TypeCreator; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -52,11 +53,14 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -84,6 +88,9 @@ public class SubstraitRelVisitor extends RelNodeVisitor { private Map fieldAccessDepthMap; + /** Rex builder for creating Rex expressions during conversion. */ + protected RexBuilder rexBuilder; + /** * Creates a new SubstraitRelVisitor with the specified type factory and extensions. 
* @@ -106,6 +113,7 @@ public SubstraitRelVisitor(ConverterProvider converterProvider) { this.typeConverter = converterProvider.getTypeConverter(); this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this); + this.rexBuilder = new RexBuilder(converterProvider.getTypeFactory()); } /** @@ -411,7 +419,35 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { return Aggregate.Grouping.builder().addAllExpressions(references).build(); } + /** + * Converts a Calcite {@link AggregateCall} to a Substrait {@link Aggregate.Measure}. + * + *

This method handles the conversion of aggregate function calls from Calcite's representation + * to Substrait's format. For statistical aggregate functions (STDDEV_POP, STDDEV_SAMP, VAR_POP, + * VAR_SAMP), it first transforms the input relation: unless all argument fields already share a + * single floating-point type (FP32 or FP64), a projection is inserted that casts the aggregate + * function's non-FP64 argument fields to DOUBLE (FP64) type, ensuring type compatibility + * with Substrait's statistical function requirements. Fields not referenced by the aggregate + * function are passed through unchanged. + * + *

The method also processes optional filter arguments (FILTER clauses) by converting them to + * Substrait's preMeasureFilter representation. + * + * @param input the input relational node providing data to the aggregate operation + * @param inputType the Substrait struct type representing the schema of the input relation + * @param call the Calcite aggregate call to convert, containing the aggregate function, + * arguments, and optional filter + * @return a Substrait {@link Aggregate.Measure} representing the aggregate function invocation + * with its configuration + * @throws UnsupportedOperationException if the aggregate function cannot be converted to a + * Substrait representation (no matching function binding found) + */ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) { + SqlKind kind = call.getAggregation().getKind(); + if (java.util.Set.of(SqlKind.STDDEV_POP, SqlKind.STDDEV_SAMP, SqlKind.VAR_POP, SqlKind.VAR_SAMP) + .contains(kind)) { + input = transformInputForStdDevVariance(input, call); + } + Optional invocation = aggregateFunctionConverter.convert( input, inputType, call, t -> t.accept(rexExpressionConverter)); @@ -427,6 +463,92 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal return builder.build(); } + /** + * Transforms the input relation by conditionally casting aggregate argument fields to DOUBLE + * (FP64) type for statistical aggregate functions like STDDEV_SAMP, STDDEV_POP, VAR_SAMP, and + * VAR_POP. + * + *

This transformation is necessary because statistical aggregate functions require numeric + * inputs to be in a consistent floating-point format for accurate computation. + * + *

Optimization: If all fields referenced by the aggregate call's argument list are + * already of a single floating-point type (FP32 or FP64), the input relation is returned + * unchanged without creating a projection. This avoids unnecessary casting when the data is + * already in an acceptable floating-point format. + * + *

Transformation logic: When casting is required, the method creates a LogicalProject + * that processes all fields in the input relation: + * + *

    + *
  • Fields referenced by the aggregate call: Fields whose indices are in {@code + * call.getArgList()} are conditionally cast: + *
      + *
    • Fields already matching FP64 (ignoring nullability) are passed through unchanged + *
    • All other fields are cast to DOUBLE (FP64) using {@code makeCast} + *
    + *
  • Fields not referenced by the aggregate call: All other fields are passed through + * unchanged using {@code makeInputRef}, regardless of their type + *
+ * + *

Implementation details: + * + *

    + *
  • The returned LogicalProject preserves the original field names from the input relation + *
  • Empty hints are used ({@code Collections.emptyList()}) + *
  • Empty variable substitutions are used ({@code Collections.emptySet()}) + *
  • Type checking uses {@code TypeCreator.NULLABLE.FP64.equalsIgnoringNullability()} on + * Substrait types converted from Calcite field types + *
+ * + * @param input the input relational node to transform + * @param call the aggregate call containing the argument list that identifies which fields need + * potential casting + * @return the original input if optimization applies, or a LogicalProject that conditionally + * casts aggregate argument fields to DOUBLE (FP64) while preserving all field names + */ + protected RelNode transformInputForStdDevVariance(RelNode input, AggregateCall call) { + List distinctTypes = + input.getRowType().getFieldList().stream() + .filter(f -> call.getArgList().contains(f.getIndex())) + .map(f -> f.getType()) + .distinct() + .collect(Collectors.toList()); + + // we do not need to cast if all referenced fields are already FP32 or FP64 + if (distinctTypes.size() == 1 + && (TypeCreator.NULLABLE.FP32.equalsIgnoringNullability( + typeConverter.toSubstrait(distinctTypes.get(0))) + || TypeCreator.NULLABLE.FP64.equalsIgnoringNullability( + typeConverter.toSubstrait(distinctTypes.get(0))))) { + return input; + } + + List castProjects = + input.getRowType().getFieldList().stream() + .map( + f -> + (call.getArgList().contains(f.getIndex()) + && !TypeCreator.NULLABLE.FP64.equalsIgnoringNullability( + typeConverter.toSubstrait(f.getType()))) + ? rexBuilder.makeCast( + typeConverter.toCalcite( + rexBuilder.getTypeFactory(), + Type.withNullability(f.getType().isNullable()).FP64), + rexBuilder.makeInputRef(input, f.getIndex())) + // passthrough all fields that do not need to be casted + : rexBuilder.makeInputRef(input, f.getIndex())) + .collect(Collectors.toList()); + RelDataType projectedRowType = + rexBuilder + .getTypeFactory() + .createStructType( + castProjects.stream().map(RexNode::getType).collect(Collectors.toList()), + input.getRowType().getFieldNames()); + + return LogicalProject.create( + input, Collections.emptyList(), castProjects, projectedRowType, Collections.emptySet()); + } + /** * Converts a Calcite {@link org.apache.calcite.rel.core.Match}. 
* diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 6f8918179..811753319 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -3,8 +3,9 @@ import com.google.common.collect.ImmutableList; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; +import io.substrait.expression.FunctionOption; +import io.substrait.expression.ImmutableAggregateFunctionInvocation; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.AggregateFunctions; import io.substrait.isthmus.SubstraitRelVisitor; @@ -22,6 +23,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; /** @@ -76,11 +78,24 @@ public AggregateFunctionConverter( /** * Builds a Substrait aggregate invocation from the matched call and arguments. * - * @param call wrapped aggregate call - * @param function matched Substrait function variant - * @param arguments converted arguments + *

This method constructs an {@link AggregateFunctionInvocation} with appropriate configuration + * including sort fields, invocation type (DISTINCT or ALL), and function-specific options. + * + *

Statistical Functions: For standard deviation and variance functions (STDDEV_POP, + * STDDEV_SAMP, VAR_POP, VAR_SAMP), this method automatically adds a "distribution" function + * option to distinguish between population and sample variants: + * + *

    + *
  • STDDEV_SAMP, VAR_SAMP → distribution=SAMPLE (uses n-1 denominator) + *
  • STDDEV_POP, VAR_POP → distribution=POPULATION (uses n denominator) + *
+ * + * @param call wrapped aggregate call containing the Calcite aggregate information + * @param function matched Substrait function variant from the extension catalog + * @param arguments converted function arguments * @param outputType result type of the invocation - * @return aggregate function invocation + * @return aggregate function invocation with all necessary configuration including distribution + * options for statistical functions */ @Override protected AggregateFunctionInvocation generateBinding( @@ -100,13 +115,28 @@ protected AggregateFunctionInvocation generateBinding( agg.isDistinct() ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; - return ExpressionCreator.aggregateFunction( - function, - outputType, - Expression.AggregationPhase.INITIAL_TO_RESULT, - sorts, - invocation, - arguments); + + ImmutableAggregateFunctionInvocation.Builder builder = + AggregateFunctionInvocation.builder() + .declaration(function) + .outputType(outputType) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .sort(sorts) + .invocation(invocation) + .addAllArguments(arguments); + + // Add distribution option for statistical functions based on SqlKind. + // For STDDEV_SAMP/VAR_SAMP, use "SAMPLE" distribution (n-1 denominator). + // For STDDEV_POP/VAR_POP, use "POPULATION" distribution (n denominator). 
+ SqlKind kind = agg.getAggregation().getKind(); + if (kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP) { + builder.addOptions(FunctionOption.builder().name("distribution").addValues("SAMPLE").build()); + } else if (kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP) { + builder.addOptions( + FunctionOption.builder().name("distribution").addValues("POPULATION").build()); + } + + return builder.build(); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 34655ccb5..4cf4ad0ad 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -181,6 +182,26 @@ public FunctionConverter( * @return matching {@link SqlOperator}, or empty if none */ public Optional getSqlOperatorFromSubstraitFunc(String key, Type outputType) { + return getSqlOperatorFromSubstraitFunc(key, outputType, java.util.Collections.emptyList()); + } + + /** + * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction). + * + *

Given a Substrait function key (e.g., "std_dev:fp64"), output type, and function options + * (e.g., distribution: SAMPLE), this method finds the corresponding Calcite {@link SqlOperator}. + * When multiple operators match, the output type and function options are used to disambiguate. + * + *

For example, both STDDEV_POP and STDDEV_SAMP map to "std_dev:fp64", but differ in the + * "distribution" option (POPULATION vs SAMPLE). + * + * @param key the Substrait function key (function name with type signature) + * @param outputType the expected output type + * @param options the function options (e.g., distribution, rounding) + * @return the matching {@link SqlOperator}, or empty if no match found + */ + public Optional getSqlOperatorFromSubstraitFunc( + String key, Type outputType, List options) { Map resolver = getTypeBasedResolver(); Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); if (operators.isEmpty()) { @@ -192,15 +213,35 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou return Optional.of(operators.iterator().next()); } - // at least 2 operators. Use output type to resolve SqlOperator. + // First, filter by output type to ensure type compatibility String outputTypeStr = outputType.accept(ToTypeString.INSTANCE); - List resolvedOperators = + List typeFilteredOperators = operators.stream() .filter( operator -> resolver.containsKey(operator) && resolver.get(operator).types().contains(outputTypeStr)) .collect(Collectors.toList()); + + // If type filtering resolved to a single operator, return it + if (typeFilteredOperators.size() == 1) { + return Optional.of(typeFilteredOperators.get(0)); + } + + // Determine which operators to use for further filtering + List resolvedOperators; + if (typeFilteredOperators.isEmpty() && !options.isEmpty()) { + // If type filtering failed but we have options, try option-based filtering on all operators + // This handles cases where type resolver doesn't have entries for certain functions + resolvedOperators = filterByFunctionOptions(List.copyOf(operators), options); + } else if (typeFilteredOperators.size() > 1 && !options.isEmpty()) { + // If multiple operators remain after type filtering, apply option-based filtering + resolvedOperators = 
filterByFunctionOptions(typeFilteredOperators, options); + } else { + // Use type-filtered results (may be empty, single, or multiple) + resolvedOperators = typeFilteredOperators; + } + // only one SqlOperator is possible if (resolvedOperators.size() == 1) { return Optional.of(resolvedOperators.get(0)); @@ -213,6 +254,56 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou return Optional.empty(); } + /** + * Filters SqlOperators based on function options. + * + *

For statistical functions like STDDEV and VAR, the "distribution" option determines whether + * to use the population or sample variant: + * + *

    + *
  • distribution=POPULATION → STDDEV_POP, VAR_POP + *
  • distribution=SAMPLE → STDDEV_SAMP, VAR_SAMP + *
+ * + * @param operators the list of candidate SqlOperators + * @param options the function options from Substrait + * @return filtered list of SqlOperators matching the options + */ + private List filterByFunctionOptions( + List operators, List options) { + if (options == null || options.isEmpty()) { + return operators; + } + + // Extract distribution option if present + Optional distribution = + options.stream() + .filter(opt -> "distribution".equals(opt.getName())) + .flatMap(opt -> opt.values().stream()) + .findFirst(); + + if (!distribution.isPresent()) { + return operators; + } + + String distributionValue = distribution.get(); + return operators.stream() + .filter( + operator -> { + SqlKind kind = operator.getKind(); + // Match distribution option to SqlKind + if ("POPULATION".equals(distributionValue)) { + return kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP; + } else if ("SAMPLE".equals(distributionValue)) { + return kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP; + } + throw new IllegalArgumentException( + String.format( + "Unknown distribution value '%s' for operator %s", distributionValue, kind)); + }) + .collect(Collectors.toList()); + } + /** * Returns the resolver used to disambiguate Calcite operators by output type. 
* diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 85f44a312..2daf75e92 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -155,7 +155,19 @@ public class FunctionMappings { s(AggregateFunctions.SUM0, "sum0"), s(SqlStdOperatorTable.COUNT, "count"), s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"), - s(AggregateFunctions.AVG, "avg")) + s(AggregateFunctions.AVG, "avg"), + /* + * Substrait std_dev and variance functions use a 'distribution' option + * (SAMPLE or POPULATION) to distinguish between population and sample variants. + * The function invocation logic must set this option based on the SqlKind. + * + * Note: Standard Calcite operators (SqlStdOperatorTable.STDDEV_SAMP, etc.) are + * automatically converted to these Substrait variants via toSubstraitAggVariant(). + */ + s(AggregateFunctions.STDDEV_POP, "std_dev"), + s(AggregateFunctions.STDDEV_SAMP, "std_dev"), + s(AggregateFunctions.VAR_POP, "variance"), + s(AggregateFunctions.VAR_SAMP, "variance")) .build(); /** Window function signatures (including supported aggregates) mapped to Substrait names. 
*/ diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index e3082f55b..04459ae7e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -45,6 +45,14 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) { return sb.sum0(input, field); case "avg": return sb.avg(input, field); + case "stddev_pop": + return sb.stddevPopulation(input, field); + case "stddev_samp": + return sb.stddevSample(input, field); + case "var_pop": + return sb.variancePopulation(input, field); + case "var_samp": + return sb.varianceSample(input, field); default: throw new UnsupportedOperationException( String.format("no function is associated with %s", fname)); @@ -54,14 +62,41 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) { // Create one function call per numeric type column private List functions(Rel input, String fname) { // first column is for grouping, skip it + // Statistical functions (stddev_*, var_*) only support floating-point types in Substrait. + // This filtering ensures we only test with fp32 and fp64 types for these functions, + // avoiding type mismatch errors during round-trip conversion. 
+ boolean isStatisticalFunction = fname.startsWith("stddev_") || fname.startsWith("var_"); return IntStream.range(1, tableTypes.size()) .boxed() + .filter( + index -> { + if (!isStatisticalFunction) { + return true; // All numeric types for non-statistical functions + } + // Only floating-point types for statistical functions + Type type = tableTypes.get(index); + return type.equals(R.FP32) + || type.equals(R.FP64) + || type.equals(N.FP32) + || type.equals(N.FP64); + }) .map(index -> functionPicker(input, index, fname)) .collect(Collectors.toList()); } @ParameterizedTest - @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) + @ValueSource( + strings = { + "max", + "min", + "sum", + "sum0", + "avg", + "stddev_pop", + "stddev_samp", + "var_pop", + "var_samp" + }) void emptyGrouping(String aggFunction) { Aggregate rel = sb.aggregate( @@ -70,7 +105,18 @@ void emptyGrouping(String aggFunction) { } @ParameterizedTest - @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) + @ValueSource( + strings = { + "max", + "min", + "sum", + "sum0", + "avg", + "stddev_pop", + "stddev_samp", + "var_pop", + "var_samp" + }) void withGrouping(String aggFunction) { Aggregate rel = sb.aggregate( diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java index d39cc2084..142eb78bf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java @@ -9,6 +9,8 @@ import io.substrait.isthmus.expression.FunctionConverter.FunctionFinder; import java.util.List; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlAvgAggFunction; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; import org.apache.calcite.sql.type.SqlTypeName; 
import org.junit.jupiter.api.Test; @@ -34,4 +36,84 @@ void testFunctionFinderMatch() { assertEquals("sum0", functionFinder.getSubstraitName()); assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator()); } + + @Test + void testStddevPopFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.STDDEV_POP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("std_dev", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.STDDEV_POP, functionFinder.getOperator()); + } + + @Test + void testStddevSampFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.STDDEV_SAMP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("std_dev", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.STDDEV_SAMP, functionFinder.getOperator()); + } + + @Test + void testVarPopFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.VAR_POP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("variance", functionFinder.getSubstraitName()); + 
assertEquals(AggregateFunctions.VAR_POP, functionFinder.getOperator()); + } + + @Test + void testVarSampFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.VAR_SAMP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("variance", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.VAR_SAMP, functionFinder.getOperator()); + } } From a134a2d101dad7051abf73baaf08f41371bdc598 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 15 Jun 2026 19:41:13 +0200 Subject: [PATCH 2/4] fix(isthmus): map std_dev/variance via enum-arg signatures Use the non-deprecated std_dev/variance signatures that carry the SAMPLE/POPULATION distinction as a leading "distribution" enum argument (std_dev:req_fp64 etc.) instead of the now-deprecated function option. During Calcite -> Substrait conversion the distribution enum operand is synthesized so the generic function matcher resolves the enum-arg variant and builds the EnumArg; the reverse direction disambiguates the Calcite operator from that argument. A shared StatisticalDistribution enum is added in :core so the DSL builder and isthmus share one source of truth. 
Resolves #803 Signed-off-by: Niels Pardon --- .../io/substrait/dsl/SubstraitBuilder.java | 51 +++++---- .../expression/StatisticalDistribution.java | 17 +++ .../substrait/isthmus/AggregateFunctions.java | 8 +- .../isthmus/PreCalciteAggregateValidator.java | 17 ++- .../isthmus/SubstraitRelNodeConverter.java | 7 +- .../AggregateFunctionConverter.java | 107 +++++++++++------- .../isthmus/expression/EnumConverter.java | 19 ++++ .../isthmus/expression/FunctionConverter.java | 86 +++++++------- .../isthmus/expression/FunctionMappings.java | 6 +- .../isthmus/StatisticalFunctionTest.java | 84 ++++++++++++++ 10 files changed, 286 insertions(+), 116 deletions(-) create mode 100644 core/src/main/java/io/substrait/expression/StatisticalDistribution.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 254a2edf9..3ca5093dd 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1,6 +1,7 @@ package io.substrait.dsl; import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; import io.substrait.expression.Expression.Cast; import io.substrait.expression.Expression.FailureBehavior; @@ -13,6 +14,7 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.StatisticalDistribution; import io.substrait.expression.WindowBound; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; @@ -1371,7 +1373,7 @@ public Aggregate.Measure sum0(Expression expr) { * @param input the input relation containing the field * @param field the zero-based index of the field to aggregate * @return an aggregate 
measure computing population standard deviation with - * distribution=POPULATION option + * distribution=POPULATION enum argument */ public Aggregate.Measure stddevPopulation(Rel input, int field) { return stddevPopulation(fieldReference(input, field)); @@ -1388,7 +1390,7 @@ public Aggregate.Measure stddevPopulation(Rel input, int field) { * *
    *
  • Function: Substrait's "std_dev" from the arithmetic extension - *
  • Option: distribution=POPULATION + *
  • Argument: distribution=POPULATION (enum argument) *
  • Output type: nullable version of the input expression type *
  • Aggregation phase: INITIAL_TO_RESULT *
  • Invocation: ALL (processes all rows) @@ -1398,7 +1400,7 @@ public Aggregate.Measure stddevPopulation(Rel input, int field) { * @return an aggregate measure computing population standard deviation */ public Aggregate.Measure stddevPopulation(Expression expr) { - return statisticalAggregate(expr, "std_dev", "POPULATION"); + return statisticalAggregate(expr, "std_dev", StatisticalDistribution.POPULATION); } /** @@ -1410,8 +1412,8 @@ public Aggregate.Measure stddevPopulation(Expression expr) { * * @param input the input relation containing the field * @param field the zero-based index of the field to aggregate - * @return an aggregate measure computing sample standard deviation with distribution=SAMPLE - * option + * @return an aggregate measure computing sample standard deviation with distribution=SAMPLE enum + * argument */ public Aggregate.Measure stddevSample(Rel input, int field) { return stddevSample(fieldReference(input, field)); @@ -1428,7 +1430,7 @@ public Aggregate.Measure stddevSample(Rel input, int field) { * *
      *
    • Function: Substrait's "std_dev" from the arithmetic extension - *
    • Option: distribution=SAMPLE + *
    • Argument: distribution=SAMPLE (enum argument) *
    • Output type: nullable version of the input expression type *
    • Aggregation phase: INITIAL_TO_RESULT *
    • Invocation: ALL (processes all rows) @@ -1438,7 +1440,7 @@ public Aggregate.Measure stddevSample(Rel input, int field) { * @return an aggregate measure computing sample standard deviation */ public Aggregate.Measure stddevSample(Expression expr) { - return statisticalAggregate(expr, "std_dev", "SAMPLE"); + return statisticalAggregate(expr, "std_dev", StatisticalDistribution.SAMPLE); } /** @@ -1449,7 +1451,8 @@ public Aggregate.Measure stddevSample(Expression expr) { * * @param input the input relation containing the field * @param field the zero-based index of the field to aggregate - * @return an aggregate measure computing population variance with distribution=POPULATION option + * @return an aggregate measure computing population variance with distribution=POPULATION enum + * argument */ public Aggregate.Measure variancePopulation(Rel input, int field) { return variancePopulation(fieldReference(input, field)); @@ -1465,7 +1468,7 @@ public Aggregate.Measure variancePopulation(Rel input, int field) { * *
        *
      • Function: Substrait's "variance" from the arithmetic extension - *
      • Option: distribution=POPULATION + *
      • Argument: distribution=POPULATION (enum argument) *
      • Output type: nullable version of the input expression type *
      • Aggregation phase: INITIAL_TO_RESULT *
      • Invocation: ALL (processes all rows) @@ -1475,7 +1478,7 @@ public Aggregate.Measure variancePopulation(Rel input, int field) { * @return an aggregate measure computing population variance */ public Aggregate.Measure variancePopulation(Expression expr) { - return statisticalAggregate(expr, "variance", "POPULATION"); + return statisticalAggregate(expr, "variance", StatisticalDistribution.POPULATION); } /** @@ -1486,7 +1489,7 @@ public Aggregate.Measure variancePopulation(Expression expr) { * * @param input the input relation containing the field * @param field the zero-based index of the field to aggregate - * @return an aggregate measure computing sample variance with distribution=SAMPLE option + * @return an aggregate measure computing sample variance with distribution=SAMPLE enum argument */ public Aggregate.Measure varianceSample(Rel input, int field) { return varianceSample(fieldReference(input, field)); @@ -1502,7 +1505,7 @@ public Aggregate.Measure varianceSample(Rel input, int field) { * *
          *
        • Function: Substrait's "variance" from the arithmetic extension - *
        • Option: distribution=SAMPLE + *
        • Argument: distribution=SAMPLE (enum argument) *
        • Output type: nullable version of the input expression type *
        • Aggregation phase: INITIAL_TO_RESULT *
        • Invocation: ALL (processes all rows) @@ -1512,34 +1515,36 @@ public Aggregate.Measure varianceSample(Rel input, int field) { * @return an aggregate measure computing sample variance */ public Aggregate.Measure varianceSample(Expression expr) { - return statisticalAggregate(expr, "variance", "SAMPLE"); + return statisticalAggregate(expr, "variance", StatisticalDistribution.SAMPLE); } /** - * Helper method to create statistical aggregate measures (std_dev, variance) with distribution - * option. + * Helper method to create statistical aggregate measures (std_dev, variance) with a {@code + * distribution} enum argument. + * + *

          Uses the non-deprecated function signatures that carry the population/sample distinction as + * a leading {@code distribution} {@link EnumArg} (e.g. {@code std_dev:req_fp64}). * * @param expr the expression to aggregate * @param functionName the Substrait function name ("std_dev" or "variance") - * @param distribution the distribution type ("SAMPLE" or "POPULATION") - * @return an aggregate measure with the specified distribution option + * @param distribution the distribution type (SAMPLE or POPULATION) + * @return an aggregate measure with the specified distribution argument */ private Aggregate.Measure statisticalAggregate( - Expression expr, String functionName, String distribution) { + Expression expr, String functionName, StatisticalDistribution distribution) { String typeString = ToTypeString.apply(expr.getType()); SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, - String.format("%s:%s", functionName, typeString))); - FunctionOption distributionOption = - FunctionOption.builder().name("distribution").addValues(distribution).build(); + String.format("%s:req_%s", functionName, typeString))); + EnumArg distributionArg = + EnumArg.of((SimpleExtension.EnumArgument) declaration.args().get(0), distribution.name()); return measure( AggregateFunctionInvocation.builder() - .arguments(Arrays.asList(expr)) + .arguments(Arrays.asList(distributionArg, expr)) .outputType(TypeCreator.asNullable(expr.getType())) .declaration(declaration) - .addOptions(distributionOption) .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) .invocation(Expression.AggregationInvocation.ALL) .build()); diff --git a/core/src/main/java/io/substrait/expression/StatisticalDistribution.java b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java new file mode 100644 index 000000000..62d223024 --- /dev/null +++ 
b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java @@ -0,0 +1,17 @@ +package io.substrait.expression; + +/** + * The {@code distribution} enum argument of the Substrait {@code std_dev} and {@code variance} + * aggregate functions. + * + *

          Distinguishes between the sample (n-1 denominator, Bessel's correction) and population (n + * denominator) variants. The enum constant names match the Substrait extension's enum option names + * ({@code SAMPLE} / {@code POPULATION}), so {@link #name()} yields the value used to build an + * {@link EnumArg}. + */ +public enum StatisticalDistribution { + /** Sample distribution (uses the n-1 denominator, Bessel's correction). */ + SAMPLE, + /** Population distribution (uses the n denominator). */ + POPULATION +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index ffa55b652..4d7ea1bb9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -31,25 +31,25 @@ public class AggregateFunctions { /** * Standard deviation (population) aggregate function. Maps to Substrait's std_dev function with - * distribution=POPULATION option. + * distribution=POPULATION enum argument. */ public static SqlAggFunction STDDEV_POP = new SubstraitAvgAggFunction(SqlKind.STDDEV_POP); /** * Standard deviation (sample) aggregate function. Maps to Substrait's std_dev function with - * distribution=SAMPLE option. + * distribution=SAMPLE enum argument. */ public static SqlAggFunction STDDEV_SAMP = new SubstraitAvgAggFunction(SqlKind.STDDEV_SAMP); /** * Variance (population) aggregate function. Maps to Substrait's variance function with - * distribution=POPULATION option. + * distribution=POPULATION enum argument. */ public static SqlAggFunction VAR_POP = new SubstraitAvgAggFunction(SqlKind.VAR_POP); /** * Variance (sample) aggregate function. Maps to Substrait's variance function with - * distribution=SAMPLE option. + * distribution=SAMPLE enum argument. 
*/ public static SqlAggFunction VAR_SAMP = new SubstraitAvgAggFunction(SqlKind.VAR_SAMP); diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java index 705e02e3c..f8695db64 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -46,8 +46,11 @@ public static boolean isValidCalciteAggregate(Aggregate aggregate) { */ private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { return - // all function arguments to measures must be field references - measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg)) + // all value (Expression) function arguments to measures must be field references; non-value + // arguments such as the std_dev/variance "distribution" enum argument are exempt + measure.getFunction().arguments().stream() + .filter(farg -> farg instanceof Expression) + .allMatch(farg -> isSimpleFieldReference(farg)) && // all sort fields must be field references measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr())) @@ -157,9 +160,9 @@ public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); - List newFunctionArgs = + List newFunctionArgs = oldAggregateFunctionInvocation.arguments().stream() - .map(this::projectOutNonFieldReference) + .map(this::projectOutNonFieldReferenceArg) .collect(Collectors.toList()); List newSortFields = @@ -194,11 +197,13 @@ private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build(); } - private Expression projectOutNonFieldReference(FunctionArg farg) { + private 
FunctionArg projectOutNonFieldReferenceArg(FunctionArg farg) { if ((farg instanceof Expression)) { return projectOutNonFieldReference((Expression) farg); } else { - throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); + // Non-value arguments (e.g. the std_dev/variance "distribution" enum argument) are not + // field references and are passed through unchanged. + return farg; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index ab1807066..2d42b9f35 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -384,8 +384,11 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { List eArgs = measure.getFunction().arguments(); + // Only value (Expression) arguments map to Calcite aggregate operands. Enum arguments such as + // the std_dev/variance "distribution" are used to disambiguate the operator, not as operands. 
List arguments = - IntStream.range(0, measure.getFunction().arguments().size()) + IntStream.range(0, eArgs.size()) + .filter(i -> eArgs.get(i) instanceof Expression) .mapToObj( i -> eArgs @@ -400,7 +403,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc( measure.getFunction().declaration().key(), measure.getFunction().outputType(), - measure.getFunction().options()); + measure.getFunction().arguments()); if (!operator.isPresent()) { throw new IllegalArgumentException( String.format( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 811753319..0b9b7e63d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -3,14 +3,15 @@ import com.google.common.collect.ImmutableList; import io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; -import io.substrait.expression.FunctionOption; -import io.substrait.expression.ImmutableAggregateFunctionInvocation; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.AggregateFunctions; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -79,23 +80,20 @@ public AggregateFunctionConverter( * Builds a Substrait aggregate invocation from the matched call and arguments. * *

          This method constructs an {@link AggregateFunctionInvocation} with appropriate configuration - * including sort fields, invocation type (DISTINCT or ALL), and function-specific options. + * including sort fields and invocation type (DISTINCT or ALL). * *

          Statistical Functions: For standard deviation and variance functions (STDDEV_POP, - * STDDEV_SAMP, VAR_POP, VAR_SAMP), this method automatically adds a "distribution" function - * option to distinguish between population and sample variants: - * - *

            - *
          • STDDEV_SAMP, VAR_SAMP → distribution=SAMPLE (uses n-1 denominator) - *
          • STDDEV_POP, VAR_POP → distribution=POPULATION (uses n denominator) - *
          + * STDDEV_SAMP, VAR_POP, VAR_SAMP), the population/sample distinction is carried by a leading + * {@code distribution} {@link io.substrait.expression.EnumArg} argument. That argument is + * synthesized as an operand in {@link #convert} so that the generic matcher resolves the enum-arg + * function variant ({@code std_dev:req_*} / {@code variance:req_*}) and constructs the {@link + * io.substrait.expression.EnumArg}; no special handling is required here. * * @param call wrapped aggregate call containing the Calcite aggregate information * @param function matched Substrait function variant from the extension catalog * @param arguments converted function arguments * @param outputType result type of the invocation - * @return aggregate function invocation with all necessary configuration including distribution - * options for statistical functions + * @return aggregate function invocation */ @Override protected AggregateFunctionInvocation generateBinding( @@ -116,27 +114,13 @@ protected AggregateFunctionInvocation generateBinding( ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; - ImmutableAggregateFunctionInvocation.Builder builder = - AggregateFunctionInvocation.builder() - .declaration(function) - .outputType(outputType) - .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) - .sort(sorts) - .invocation(invocation) - .addAllArguments(arguments); - - // Add distribution option for statistical functions based on SqlKind. - // For STDDEV_SAMP/VAR_SAMP, use "SAMPLE" distribution (n-1 denominator). - // For STDDEV_POP/VAR_POP, use "POPULATION" distribution (n denominator). 
- SqlKind kind = agg.getAggregation().getKind(); - if (kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP) { - builder.addOptions(FunctionOption.builder().name("distribution").addValues("SAMPLE").build()); - } else if (kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP) { - builder.addOptions( - FunctionOption.builder().name("distribution").addValues("POPULATION").build()); - } - - return builder.build(); + return ExpressionCreator.aggregateFunction( + function, + outputType, + Expression.AggregationPhase.INITIAL_TO_RESULT, + sorts, + invocation, + arguments); } /** @@ -158,14 +142,50 @@ public Optional convert( if (m == null) { return Optional.empty(); } - if (!m.allowedArgCount(call.getArgList().size())) { + + // For statistical aggregates (std_dev/variance) the SAMPLE/POPULATION distinction is carried + // by a leading "distribution" enum argument. Synthesize it as an operand so the generic matcher + // resolves the enum-arg function variant and builds the EnumArg. + List leadingArgs = leadingEnumArgs(call); + if (!m.allowedArgCount(call.getArgList().size() + leadingArgs.size())) { return Optional.empty(); } - WrappedAggregateCall wrapped = new WrappedAggregateCall(call, input, rexBuilder, inputType); + WrappedAggregateCall wrapped = + new WrappedAggregateCall(call, leadingArgs, input, rexBuilder, inputType); return m.attemptMatch(wrapped, topLevelConverter); } + /** + * Computes the synthetic leading operands to prepend to a Calcite aggregate call before matching. + * + *

          For standard deviation and variance functions, Substrait carries the population/sample + * distinction as a leading {@code distribution} enum argument, whereas Calcite encodes it in the + * operator's {@link SqlKind}. This returns the matching {@link StatisticalDistribution} flag so + * the generic matcher selects the {@code std_dev:req_*} / {@code variance:req_*} variant and + * constructs the corresponding {@link io.substrait.expression.EnumArg}. + * + * @param call the Calcite aggregate call + * @return the leading enum operands (a single distribution flag for statistical functions, empty + * otherwise) + */ + private List leadingEnumArgs(AggregateCall call) { + List leadingArgs = new ArrayList<>(); + switch (call.getAggregation().getKind()) { + case STDDEV_SAMP: + case VAR_SAMP: + leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.SAMPLE)); + break; + case STDDEV_POP: + case VAR_POP: + leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.POPULATION)); + break; + default: + break; + } + return leadingArgs; + } + /** * Resolves the appropriate function finder, applying Substrait-specific variants when needed. * @@ -190,6 +210,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) { /** Lightweight wrapper around {@link AggregateCall} providing operands and type access. */ static class WrappedAggregateCall implements FunctionConverter.GenericCall { private final AggregateCall call; + private final List leadingArgs; private final RelNode input; private final RexBuilder rexBuilder; private final Type.Struct inputType; @@ -198,26 +219,36 @@ static class WrappedAggregateCall implements FunctionConverter.GenericCall { * Creates a new wrapped aggregate call. * * @param call underlying Calcite aggregate call + * @param leadingArgs synthetic operands (e.g. 
a {@code distribution} enum flag) prepended ahead + * of the field arguments during matching * @param input input relational node * @param rexBuilder Rex builder for operand construction * @param inputType Substrait input struct type */ private WrappedAggregateCall( - AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) { + AggregateCall call, + List leadingArgs, + RelNode input, + RexBuilder rexBuilder, + Type.Struct inputType) { this.call = call; + this.leadingArgs = leadingArgs; this.input = input; this.rexBuilder = rexBuilder; this.inputType = inputType; } /** - * Returns operands as input references over the argument list. + * Returns operands as the synthetic leading operands followed by input references over the + * argument list. * * @return stream of RexNode operands */ @Override public Stream getOperands() { - return call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r)); + return Stream.concat( + leadingArgs.stream(), + call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r))); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index 2f95ca276..2fef9f683 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -1,6 +1,7 @@ package io.substrait.isthmus.expression; import io.substrait.expression.EnumArg; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.extension.SimpleExtension.Argument; @@ -78,6 +79,20 @@ public class EnumConverter { calciteEnumMap.put( argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1), ExtractIndexing.class); + + // std_dev and variance carry the SAMPLE/POPULATION distinction as a leading enum 
argument. + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp32", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp64", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp32", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp64", 0), + StatisticalDistribution.class); } private static Optional> constructValue( @@ -90,6 +105,10 @@ private static Optional> constructValue( return option.get().map(SqlTrimFunction.Flag::valueOf); } + if (cls.isAssignableFrom(StatisticalDistribution.class)) { + return option.get().map(StatisticalDistribution::valueOf); + } + // ExtractIndexing does not need to be converted here. Calcite // doesn't have the concept of the indexing. It's date // functions are all indexed from 1 diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 4cf4ad0ad..21f48b74f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -10,6 +10,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.SimpleExtension; import io.substrait.extension.SimpleExtension.Argument; import io.substrait.function.ParameterizedType; @@ -188,20 +189,21 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou /** * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction). * - *

          Given a Substrait function key (e.g., "std_dev:fp64"), output type, and function options - * (e.g., distribution: SAMPLE), this method finds the corresponding Calcite {@link SqlOperator}. - * When multiple operators match, the output type and function options are used to disambiguate. + *

          Given a Substrait function key (e.g., "std_dev:req_fp64"), output type, and function + * arguments (which may include a {@code distribution} {@link io.substrait.expression.EnumArg}), + * this method finds the corresponding Calcite {@link SqlOperator}. When multiple operators match, + * the output type and the {@code distribution} enum argument are used to disambiguate. * - *

          For example, both STDDEV_POP and STDDEV_SAMP map to "std_dev:fp64", but differ in the - * "distribution" option (POPULATION vs SAMPLE). + *

          For example, both STDDEV_POP and STDDEV_SAMP map to "std_dev:req_fp64", but differ in the + * {@code distribution} enum argument (POPULATION vs SAMPLE). * * @param key the Substrait function key (function name with type signature) * @param outputType the expected output type - * @param options the function options (e.g., distribution, rounding) + * @param arguments the function arguments (used to read the {@code distribution} enum argument) * @return the matching {@link SqlOperator}, or empty if no match found */ public Optional getSqlOperatorFromSubstraitFunc( - String key, Type outputType, List options) { + String key, Type outputType, List arguments) { Map resolver = getTypeBasedResolver(); Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); if (operators.isEmpty()) { @@ -229,14 +231,16 @@ public Optional getSqlOperatorFromSubstraitFunc( } // Determine which operators to use for further filtering + Optional distribution = distributionArgument(arguments); List resolvedOperators; - if (typeFilteredOperators.isEmpty() && !options.isEmpty()) { - // If type filtering failed but we have options, try option-based filtering on all operators - // This handles cases where type resolver doesn't have entries for certain functions - resolvedOperators = filterByFunctionOptions(List.copyOf(operators), options); - } else if (typeFilteredOperators.size() > 1 && !options.isEmpty()) { - // If multiple operators remain after type filtering, apply option-based filtering - resolvedOperators = filterByFunctionOptions(typeFilteredOperators, options); + if (typeFilteredOperators.isEmpty() && distribution.isPresent()) { + // If type filtering failed but we have a distribution argument, try distribution-based + // filtering on all operators. This handles functions (e.g. std_dev/variance) that the type + // resolver has no entries for. 
+ resolvedOperators = filterByDistribution(List.copyOf(operators), distribution.get()); + } else if (typeFilteredOperators.size() > 1 && distribution.isPresent()) { + // If multiple operators remain after type filtering, apply distribution-based filtering + resolvedOperators = filterByDistribution(typeFilteredOperators, distribution.get()); } else { // Use type-filtered results (may be empty, single, or multiple) resolvedOperators = typeFilteredOperators; @@ -255,10 +259,28 @@ public Optional getSqlOperatorFromSubstraitFunc( } /** - * Filters SqlOperators based on function options. + * Extracts the value of the {@code distribution} enum argument, if present. * - *

          For statistical functions like STDDEV and VAR, the "distribution" option determines whether - * to use the population or sample variant: + * @param arguments the Substrait function arguments + * @return the distribution value (e.g. {@code SAMPLE} / {@code POPULATION}) if a {@code + * distribution} {@link io.substrait.expression.EnumArg} is present + */ + private static Optional distributionArgument(List arguments) { + if (arguments == null) { + return Optional.empty(); + } + return arguments.stream() + .filter(arg -> arg instanceof io.substrait.expression.EnumArg) + .map(arg -> (io.substrait.expression.EnumArg) arg) + .flatMap(arg -> arg.value().stream()) + .findFirst(); + } + + /** + * Filters SqlOperators based on the {@code distribution} enum argument. + * + *

          For statistical functions like STDDEV and VAR, the {@code distribution} argument determines + * whether to use the population or sample variant: * *

            *
          • distribution=POPULATION → STDDEV_POP, VAR_POP @@ -266,35 +288,19 @@ public Optional getSqlOperatorFromSubstraitFunc( *
          * * @param operators the list of candidate SqlOperators - * @param options the function options from Substrait - * @return filtered list of SqlOperators matching the options + * @param distributionValue the distribution value from the Substrait enum argument + * @return filtered list of SqlOperators matching the distribution */ - private List filterByFunctionOptions( - List operators, List options) { - if (options == null || options.isEmpty()) { - return operators; - } - - // Extract distribution option if present - Optional distribution = - options.stream() - .filter(opt -> "distribution".equals(opt.getName())) - .flatMap(opt -> opt.values().stream()) - .findFirst(); - - if (!distribution.isPresent()) { - return operators; - } - - String distributionValue = distribution.get(); + private List filterByDistribution( + List operators, String distributionValue) { return operators.stream() .filter( operator -> { SqlKind kind = operator.getKind(); - // Match distribution option to SqlKind - if ("POPULATION".equals(distributionValue)) { + // Match distribution value to SqlKind + if (StatisticalDistribution.POPULATION.name().equals(distributionValue)) { return kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP; - } else if ("SAMPLE".equals(distributionValue)) { + } else if (StatisticalDistribution.SAMPLE.name().equals(distributionValue)) { return kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP; } throw new IllegalArgumentException( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 2daf75e92..4c7ca3023 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -157,9 +157,9 @@ public class FunctionMappings { s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"), s(AggregateFunctions.AVG, "avg"), /* - * 
Substrait std_dev and variance functions use a 'distribution' option - * (SAMPLE or POPULATION) to distinguish between population and sample variants. - * The function invocation logic must set this option based on the SqlKind. + * Substrait std_dev and variance functions use a leading 'distribution' enum + * argument (SAMPLE or POPULATION) to distinguish between population and sample + * variants. AggregateFunctionConverter synthesizes that argument based on the SqlKind. * * Note: Standard Calcite operators (SqlStdOperatorTable.STDDEV_SAMP, etc.) are * automatically converted to these Substrait variants via toSubstraitAggVariant(). diff --git a/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java new file mode 100644 index 000000000..190401744 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java @@ -0,0 +1,84 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.EnumArg; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.plan.Plan; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Rel; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelRoot; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Verifies that the SQL statistical aggregates (STDDEV_POP/SAMP, VAR_POP/SAMP) map to the Substrait + * {@code std_dev} / {@code variance} functions using the non-deprecated enum-argument signatures + * ({@code std_dev:req_fp64} etc.), 
carrying the population/sample distinction as a {@code + * distribution} {@link EnumArg} rather than a function option. + */ +class StatisticalFunctionTest extends PlanTestBase { + + static final String CREATES = + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"; + + @ParameterizedTest + @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"}) + void roundTrip(String fn) throws Exception { + assertFullRoundTrip(String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fn, fn), CREATES); + } + + @ParameterizedTest + @CsvSource({ + "STDDEV_POP, std_dev, POPULATION", + "STDDEV_SAMP, std_dev, SAMPLE", + "VAR_POP, variance, POPULATION", + "VAR_SAMP, variance, SAMPLE", + }) + void usesEnumArgSignature(String sqlFn, String substraitFn, String distribution) + throws Exception { + Prepare.CatalogReader catalog = + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES); + RelRoot calcite = + SubstraitSqlToCalcite.convertQuery( + String.format("SELECT %s(fp64) FROM numbers", sqlFn), + catalog, + converterProvider.getSqlOperatorTable()); + Plan.Root root = SubstraitRelVisitor.convert(calcite, converterProvider); + + AggregateFunctionInvocation function = firstMeasure(root.getInput()).getFunction(); + + // The non-deprecated enum-arg variant is used (note the "req" enum argument in the key). + assertEquals(substraitFn + ":req_fp64", function.declaration().key()); + + // The distribution is carried as a leading EnumArg, not as a function option. + List args = function.arguments(); + EnumArg distributionArg = assertInstanceOf(EnumArg.class, args.get(0)); + assertEquals(Optional.of(distribution), distributionArg.value()); + assertTrue(function.options().isEmpty(), "expected no function options"); + } + + /** Recursively finds the first {@link Aggregate} measure in the relation tree. 
*/ + private static Aggregate.Measure firstMeasure(Rel rel) { + if (rel instanceof Aggregate) { + Aggregate aggregate = (Aggregate) rel; + if (!aggregate.getMeasures().isEmpty()) { + return aggregate.getMeasures().get(0); + } + } + for (Rel input : rel.getInputs()) { + Aggregate.Measure measure = firstMeasure(input); + if (measure != null) { + return measure; + } + } + return null; + } +} From 4dab843ab3847c1f93d351faf2fd87d3e6883be8 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 15 Jun 2026 19:56:40 +0200 Subject: [PATCH 3/4] refactor(isthmus): simplify distribution disambiguation branching Collapse the three-way branch in getSqlOperatorFromSubstraitFunc into a single conditional: when a distribution enum argument is present, narrow the candidate operators by it (falling back to all operators when output type filtering yielded none). Behavior is unchanged. Signed-off-by: Niels Pardon --- .../isthmus/expression/FunctionConverter.java | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 21f48b74f..886f69e93 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -230,20 +230,15 @@ public Optional getSqlOperatorFromSubstraitFunc( return Optional.of(typeFilteredOperators.get(0)); } - // Determine which operators to use for further filtering + // If still ambiguous and a distribution enum argument is present, disambiguate by it. + // Both the population and sample operators share one key (e.g. variance:req_fp32), since the + // SAMPLE/POPULATION value lives in the argument, not the signature. 
Optional distribution = distributionArgument(arguments); - List resolvedOperators; - if (typeFilteredOperators.isEmpty() && distribution.isPresent()) { - // If type filtering failed but we have a distribution argument, try distribution-based - // filtering on all operators. This handles functions (e.g. std_dev/variance) that the type - // resolver has no entries for. - resolvedOperators = filterByDistribution(List.copyOf(operators), distribution.get()); - } else if (typeFilteredOperators.size() > 1 && distribution.isPresent()) { - // If multiple operators remain after type filtering, apply distribution-based filtering - resolvedOperators = filterByDistribution(typeFilteredOperators, distribution.get()); - } else { - // Use type-filtered results (may be empty, single, or multiple) - resolvedOperators = typeFilteredOperators; + List resolvedOperators = typeFilteredOperators; + if (distribution.isPresent()) { + List candidates = + typeFilteredOperators.isEmpty() ? List.copyOf(operators) : typeFilteredOperators; + resolvedOperators = filterByDistribution(candidates, distribution.get()); } // only one SqlOperator is possible From 7d5b7a6edf8d68855374b0154807ef12ea86161f Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Mon, 15 Jun 2026 20:56:15 +0200 Subject: [PATCH 4/4] fix(isthmus): support non-floating-point std_dev/variance arguments Substrait's std_dev/variance only define fp32/fp64 signatures, so a statistical aggregate over an integer (or other non-fp) column must be cast. Previously the cast projection built in fromAggCall was discarded (only used to type the measure operand), producing an inconsistent plan where the aggregate argument was typed fp64 over an un-cast integer input. 
visit(Aggregate) now rewrites such aggregates at the Calcite level: it appends a cast(arg AS fp64) column to the input (leaving the original column for other aggregates over it), re-points the statistical aggregate at the appended column with its return type re-derived over fp64, and casts the results back to the type Calcite inferred via a projection on top. The rewrite is idempotent (fp32/fp64 arguments are untouched), so the recursive re-conversion terminates and the plan is stable after Calcite's project normalization. Also fold in minor cleanups to FunctionConverter: import EnumArg rather than using its fully qualified name, correct the multi-operator error message (it serves aggregates too, not only scalar functions), and document the single-distribution-argument assumption in distributionArgument. Signed-off-by: Niels Pardon --- .../isthmus/SubstraitRelVisitor.java | 223 +++++++++++------- .../isthmus/expression/FunctionConverter.java | 16 +- .../isthmus/StatisticalFunctionTest.java | 30 +++ 3 files changed, 179 insertions(+), 90 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 979f75195..94b41964f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -339,6 +339,16 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) { */ @Override public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { + // Substrait's std_dev/variance functions only define fp32/fp64 signatures. If a statistical + // aggregate has a non-floating-point argument, rewrite the aggregate to cast that argument to + // fp64 and cast the result back to the type Calcite inferred, then convert the rewritten plan + // through the normal path. 
The rewrite is idempotent (fp32/fp64 arguments are left untouched), + // so it terminates when the converted plan is re-converted. + RelNode rewritten = castStatisticalAggregatesToFloatingPoint(aggregate); + if (rewritten != aggregate) { + return apply(rewritten); + } + Rel input = apply(aggregate.getInput()); Stream sets; if (aggregate.groupSets != null) { @@ -442,12 +452,6 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { * Substrait representation (no matching function binding found) */ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) { - SqlKind kind = call.getAggregation().getKind(); - if (java.util.Set.of(SqlKind.STDDEV_POP, SqlKind.STDDEV_SAMP, SqlKind.VAR_POP, SqlKind.VAR_SAMP) - .contains(kind)) { - input = transformInputForStdDevVariance(input, call); - } - Optional invocation = aggregateFunctionConverter.convert( input, inputType, call, t -> t.accept(rexExpressionConverter)); @@ -463,90 +467,139 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal return builder.build(); } - /** - * Transforms the input relation by conditionally casting aggregate argument fields to DOUBLE - * (FP64) type for statistical aggregate functions like STDDEV_SAMP, STDDEV_POP, VAR_SAMP, and - * VAR_POP. - * - *

          This transformation is necessary because statistical aggregate functions require numeric - * inputs to be in a consistent floating-point format for accurate computation. - * - *

          Optimization: If all fields referenced by the aggregate call's argument list are - * already of a single floating-point type (FP32 or FP64), the input relation is returned - * unchanged without creating a projection. This avoids unnecessary casting when the data is - * already in an acceptable floating-point format. - * - *

          Transformation logic: When casting is required, the method creates a LogicalProject - * that processes all fields in the input relation: - * - *

            - *
          • Fields referenced by the aggregate call: Fields whose indices are in {@code - * call.getArgList()} are conditionally cast: - *
              - *
            • Fields already matching FP64 (ignoring nullability) are passed through unchanged - *
            • All other fields are cast to DOUBLE (FP64) using {@code makeCast} - *
            - *
          • Fields not referenced by the aggregate call: All other fields are passed through - * unchanged using {@code makeInputRef}, regardless of their type - *
          - * - *

          Implementation details: - * - *

            - *
          • The returned LogicalProject preserves the original field names from the input relation - *
          • Empty hints are used ({@code Collections.emptyList()}) - *
          • Empty variable substitutions are used ({@code Collections.emptySet()}) - *
          • Type checking uses {@code TypeCreator.NULLABLE.FP64.equalsIgnoringNullability()} on - * Substrait types converted from Calcite field types - *
          - * - * @param input the input relational node to transform - * @param call the aggregate call containing the argument list that identifies which fields need - * potential casting - * @return the original input if optimization applies, or a LogicalProject that conditionally - * casts aggregate argument fields to DOUBLE (FP64) while preserving all field names - */ - protected RelNode transformInputForStdDevVariance(RelNode input, AggregateCall call) { - List distinctTypes = - input.getRowType().getFieldList().stream() - .filter(f -> call.getArgList().contains(f.getIndex())) - .map(f -> f.getType()) - .distinct() - .collect(Collectors.toList()); + private static boolean isStatisticalDistributionAggregate(SqlKind kind) { + return kind == SqlKind.STDDEV_POP + || kind == SqlKind.STDDEV_SAMP + || kind == SqlKind.VAR_POP + || kind == SqlKind.VAR_SAMP; + } + + private boolean isFloatingPoint(RelDataType type) { + Type substraitType = typeConverter.toSubstrait(type); + return TypeCreator.NULLABLE.FP32.equalsIgnoringNullability(substraitType) + || TypeCreator.NULLABLE.FP64.equalsIgnoringNullability(substraitType); + } + + /** + * Rewrites a Calcite aggregate so that statistical aggregate functions (STDDEV_POP, STDDEV_SAMP, + * VAR_POP, VAR_SAMP) with non-floating-point arguments operate on fp64, since Substrait's {@code + * std_dev} / {@code variance} functions only define fp32 and fp64 signatures. + * + *

          For each statistical aggregate whose single argument is neither fp32 nor fp64 (e.g. an + * integer or decimal column), the rewrite: + * + *

            + *
          1. appends a {@code cast(arg AS fp64)} column to the aggregate's input (leaving the original + * column in place, so other aggregates over the same column are unaffected), + *
          2. re-points the statistical aggregate at the appended column (its return type is re-derived + * over fp64), and + *
          3. casts the aggregate's results back to the types Calcite originally inferred, via a + * projection on top, so the aggregate's output row type is preserved. + *
          + * + *

          The rewrite is idempotent: fp32/fp64 arguments are left untouched, so converting the + * rewritten plan (whose statistical arguments are already fp64) produces no further rewrite and + * the recursion in {@link #visit(org.apache.calcite.rel.core.Aggregate)} terminates. If no + * argument needs casting, the aggregate is returned unchanged. + * + * @param aggregate the Calcite aggregate to inspect + * @return {@code aggregate} unchanged, or a {@link LogicalProject} wrapping a rewritten aggregate + */ + protected RelNode castStatisticalAggregatesToFloatingPoint( + org.apache.calcite.rel.core.Aggregate aggregate) { + RelNode input = aggregate.getInput(); + List calls = aggregate.getAggCallList(); + int inputFieldCount = input.getRowType().getFieldCount(); + + // fp64 cast expressions to append to the input, and the source field each one casts (for reuse) + List appendedCasts = new ArrayList<>(); + List appendedSourceFields = new ArrayList<>(); + // per call: the appended column its argument should be re-pointed at, or -1 if unchanged + List rewrittenArgColumns = new ArrayList<>(calls.size()); + + for (AggregateCall call : calls) { + int rewrittenArgColumn = -1; + if (isStatisticalDistributionAggregate(call.getAggregation().getKind()) + && call.getArgList().size() == 1) { + int argIndex = call.getArgList().get(0); + RelDataType argType = input.getRowType().getFieldList().get(argIndex).getType(); + if (!isFloatingPoint(argType)) { + int existing = appendedSourceFields.indexOf(argIndex); + if (existing >= 0) { + rewrittenArgColumn = inputFieldCount + existing; + } else { + RelDataType fp64 = + typeConverter.toCalcite( + rexBuilder.getTypeFactory(), Type.withNullability(argType.isNullable()).FP64); + appendedCasts.add(rexBuilder.makeCast(fp64, rexBuilder.makeInputRef(input, argIndex))); + appendedSourceFields.add(argIndex); + rewrittenArgColumn = inputFieldCount + appendedCasts.size() - 1; + } + } + } + rewrittenArgColumns.add(rewrittenArgColumn); + } - // we do 
not need to cast if all referenced fields are already FP32 or FP64 - if (distinctTypes.size() == 1 - && (TypeCreator.NULLABLE.FP32.equalsIgnoringNullability( - typeConverter.toSubstrait(distinctTypes.get(0))) - || TypeCreator.NULLABLE.FP64.equalsIgnoringNullability( - typeConverter.toSubstrait(distinctTypes.get(0))))) { - return input; + if (appendedCasts.isEmpty()) { + return aggregate; } - List castProjects = - input.getRowType().getFieldList().stream() - .map( - f -> - (call.getArgList().contains(f.getIndex()) - && !TypeCreator.NULLABLE.FP64.equalsIgnoringNullability( - typeConverter.toSubstrait(f.getType()))) - ? rexBuilder.makeCast( - typeConverter.toCalcite( - rexBuilder.getTypeFactory(), - Type.withNullability(f.getType().isNullable()).FP64), - rexBuilder.makeInputRef(input, f.getIndex())) - // passthrough all fields that do not need to be casted - : rexBuilder.makeInputRef(input, f.getIndex())) - .collect(Collectors.toList()); - RelDataType projectedRowType = - rexBuilder - .getTypeFactory() - .createStructType( - castProjects.stream().map(RexNode::getType).collect(Collectors.toList()), - input.getRowType().getFieldNames()); + // Extended input: all original columns (passthrough) followed by the appended fp64 casts. + List inputProjects = new ArrayList<>(inputFieldCount + appendedCasts.size()); + for (int i = 0; i < inputFieldCount; i++) { + inputProjects.add(rexBuilder.makeInputRef(input, i)); + } + inputProjects.addAll(appendedCasts); + RelNode extendedInput = + LogicalProject.create(input, Collections.emptyList(), inputProjects, (List) null); + + // Re-point the statistical calls at the appended fp64 columns (return type re-derived); leave + // all other calls unchanged. 
+ List rewrittenCalls = new ArrayList<>(calls.size()); + for (int i = 0; i < calls.size(); i++) { + AggregateCall call = calls.get(i); + int rewrittenArgColumn = rewrittenArgColumns.get(i); + if (rewrittenArgColumn < 0) { + rewrittenCalls.add(call); + } else { + rewrittenCalls.add( + AggregateCall.create( + call.getAggregation(), + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + Collections.singletonList(rewrittenArgColumn), + call.filterArg, + call.distinctKeys, + call.getCollation(), + aggregate.getGroupCount(), + extendedInput, + /* type, null to re-derive over fp64 */ null, + call.getName())); + } + } + org.apache.calcite.rel.core.Aggregate rewrittenAggregate = + aggregate.copy( + aggregate.getTraitSet(), + extendedInput, + aggregate.getGroupSet(), + aggregate.getGroupSets(), + rewrittenCalls); + + // Cast the (now fp64) statistical results back to the types Calcite originally inferred, + // preserving the aggregate's original output row type. Group keys and unaffected measures pass + // through unchanged. + RelDataType originalRowType = aggregate.getRowType(); + List outputProjects = new ArrayList<>(originalRowType.getFieldCount()); + for (int i = 0; i < originalRowType.getFieldCount(); i++) { + RelDataType targetType = originalRowType.getFieldList().get(i).getType(); + RexNode ref = rexBuilder.makeInputRef(rewrittenAggregate, i); + outputProjects.add( + ref.getType().equals(targetType) ? 
ref : rexBuilder.makeCast(targetType, ref)); + } return LogicalProject.create( - input, Collections.emptyList(), castProjects, projectedRowType, Collections.emptySet()); + rewrittenAggregate, Collections.emptyList(), outputProjects, originalRowType); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 886f69e93..3896fa897 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -7,6 +7,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.Streams; +import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; @@ -247,7 +248,7 @@ public Optional getSqlOperatorFromSubstraitFunc( } else if (resolvedOperators.size() > 1) { throw new IllegalStateException( String.format( - "Found %d SqlOperators: %s for ScalarFunction %s: ", + "Found %d SqlOperators: %s for function %s", resolvedOperators.size(), resolvedOperators, key)); } return Optional.empty(); @@ -256,17 +257,22 @@ public Optional getSqlOperatorFromSubstraitFunc( /** * Extracts the value of the {@code distribution} enum argument, if present. * + *

          This returns the value of the first {@link EnumArg} in the argument list. It assumes the + * only enum argument that disambiguates between operators sharing a key is the {@code + * distribution} argument of {@code std_dev}/{@code variance} — the only enum-argument aggregate + * functions currently mapped. {@link #filterByDistribution} rejects values it does not recognize. + * * @param arguments the Substrait function arguments - * @return the distribution value (e.g. {@code SAMPLE} / {@code POPULATION}) if a {@code - * distribution} {@link io.substrait.expression.EnumArg} is present + * @return the distribution value (e.g. {@code SAMPLE} / {@code POPULATION}) if an {@link EnumArg} + * is present */ private static Optional distributionArgument(List arguments) { if (arguments == null) { return Optional.empty(); } return arguments.stream() - .filter(arg -> arg instanceof io.substrait.expression.EnumArg) - .map(arg -> (io.substrait.expression.EnumArg) arg) + .filter(arg -> arg instanceof EnumArg) + .map(arg -> (EnumArg) arg) .flatMap(arg -> arg.value().stream()) .findFirst(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java index 190401744..7dd03b5f5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java @@ -15,6 +15,7 @@ import java.util.Optional; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelRoot; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -35,6 +36,35 @@ void roundTrip(String fn) throws Exception { assertFullRoundTrip(String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fn, fn), CREATES); } + // Integer arguments are cast to fp64 (and the result cast back) since std_dev/variance only have + // fp32/fp64 
signatures. This rewrite (castStatisticalAggregatesToFloatingPoint) inserts a cast + // projection that Calcite normalizes (project merge/column pruning) on the first round trip, so + // these use the identity-projection workaround, which asserts stability after normalization. + + @ParameterizedTest + @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"}) + void roundTripIntegerInput(String fn) throws Exception { + assertFullRoundTripWithIdentityProjectionWorkaround( + String.format("SELECT %s(i32) FROM numbers", fn), + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + + @Test + void roundTripIntegerInputSharedWithOtherAggregate() throws Exception { + // The integer column is shared by SUM (which must keep operating on the integer) and STDDEV_POP + // (which is cast to fp64); the cast must be appended, not applied in place. + assertFullRoundTripWithIdentityProjectionWorkaround( + "SELECT SUM(i32), STDDEV_POP(i32) FROM numbers", + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + + @Test + void roundTripIntegerInputWithGrouping() throws Exception { + assertFullRoundTripWithIdentityProjectionWorkaround( + "SELECT i8, VAR_POP(i32) FROM numbers GROUP BY i8", + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + @ParameterizedTest @CsvSource({ "STDDEV_POP, std_dev, POPULATION",