diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
index 9b26c6a2d..3ca5093dd 100644
--- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
+++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
@@ -1,6 +1,7 @@
package io.substrait.dsl;
import io.substrait.expression.AggregateFunctionInvocation;
+import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.Cast;
import io.substrait.expression.Expression.FailureBehavior;
@@ -13,6 +14,7 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
+import io.substrait.expression.StatisticalDistribution;
import io.substrait.expression.WindowBound;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
@@ -1361,6 +1363,193 @@ public Aggregate.Measure sum0(Expression expr) {
R.I64);
}
+ /**
+ * Creates a population standard deviation aggregate measure for a specific field.
+ *
+ *
Computes the standard deviation using the population formula (n denominator), which
+ * considers all values in the dataset as the entire population. This is equivalent to SQL's
+ * STDDEV_POP function.
+ *
+ * @param input the input relation containing the field
+ * @param field the zero-based index of the field to aggregate
+ * @return an aggregate measure computing population standard deviation with
+ * distribution=POPULATION enum argument
+ */
+ public Aggregate.Measure stddevPopulation(Rel input, int field) {
+ return stddevPopulation(fieldReference(input, field));
+ }
+
+ /**
+ * Creates a population standard deviation aggregate measure for an expression.
+ *
+ *
Computes the standard deviation using the population formula (n denominator), which
+ * considers all values in the dataset as the entire population. This is equivalent to SQL's
+ * STDDEV_POP function.
+ *
+ *
The measure is created with:
+ *
+ *
+ * - Function: Substrait's "std_dev" from the arithmetic extension
+ *
- Argument: distribution=POPULATION (enum argument)
+ *
- Output type: nullable version of the input expression type
+ *
- Aggregation phase: INITIAL_TO_RESULT
+ *
- Invocation: ALL (processes all rows)
+ *
+ *
+ * @param expr the expression to aggregate (typically a numeric field reference)
+ * @return an aggregate measure computing population standard deviation
+ */
+ public Aggregate.Measure stddevPopulation(Expression expr) {
+ return statisticalAggregate(expr, "std_dev", StatisticalDistribution.POPULATION);
+ }
+
+ /**
+ * Creates a sample standard deviation aggregate measure for a specific field.
+ *
+ * Computes the standard deviation using the sample formula (n-1 denominator), which applies
+ * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV
+ * function.
+ *
+ * @param input the input relation containing the field
+ * @param field the zero-based index of the field to aggregate
+ * @return an aggregate measure computing sample standard deviation with distribution=SAMPLE enum
+ * argument
+ */
+ public Aggregate.Measure stddevSample(Rel input, int field) {
+ return stddevSample(fieldReference(input, field));
+ }
+
+ /**
+ * Creates a sample standard deviation aggregate measure for an expression.
+ *
+ *
Computes the standard deviation using the sample formula (n-1 denominator), which applies
+ * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV
+ * function.
+ *
+ *
The measure is created with:
+ *
+ *
+ * - Function: Substrait's "std_dev" from the arithmetic extension
+ *
- Argument: distribution=SAMPLE (enum argument)
+ *
- Output type: nullable version of the input expression type
+ *
- Aggregation phase: INITIAL_TO_RESULT
+ *
- Invocation: ALL (processes all rows)
+ *
+ *
+ * @param expr the expression to aggregate (typically a numeric field reference)
+ * @return an aggregate measure computing sample standard deviation
+ */
+ public Aggregate.Measure stddevSample(Expression expr) {
+ return statisticalAggregate(expr, "std_dev", StatisticalDistribution.SAMPLE);
+ }
+
+ /**
+ * Creates a population variance aggregate measure for a specific field.
+ *
+ * Computes the variance using the population formula (n denominator), which considers all
+ * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function.
+ *
+ * @param input the input relation containing the field
+ * @param field the zero-based index of the field to aggregate
+ * @return an aggregate measure computing population variance with distribution=POPULATION enum
+ * argument
+ */
+ public Aggregate.Measure variancePopulation(Rel input, int field) {
+ return variancePopulation(fieldReference(input, field));
+ }
+
+ /**
+ * Creates a population variance aggregate measure for an expression.
+ *
+ *
Computes the variance using the population formula (n denominator), which considers all
+ * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function.
+ *
+ *
The measure is created with:
+ *
+ *
+ * - Function: Substrait's "variance" from the arithmetic extension
+ *
- Argument: distribution=POPULATION (enum argument)
+ *
- Output type: nullable version of the input expression type
+ *
- Aggregation phase: INITIAL_TO_RESULT
+ *
- Invocation: ALL (processes all rows)
+ *
+ *
+ * @param expr the expression to aggregate (typically a numeric field reference)
+ * @return an aggregate measure computing population variance
+ */
+ public Aggregate.Measure variancePopulation(Expression expr) {
+ return statisticalAggregate(expr, "variance", StatisticalDistribution.POPULATION);
+ }
+
+ /**
+ * Creates a sample variance aggregate measure for a specific field.
+ *
+ * Computes the variance using the sample formula (n-1 denominator), which applies Bessel's
+ * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function.
+ *
+ * @param input the input relation containing the field
+ * @param field the zero-based index of the field to aggregate
+ * @return an aggregate measure computing sample variance with distribution=SAMPLE enum argument
+ */
+ public Aggregate.Measure varianceSample(Rel input, int field) {
+ return varianceSample(fieldReference(input, field));
+ }
+
+ /**
+ * Creates a sample variance aggregate measure for an expression.
+ *
+ *
Computes the variance using the sample formula (n-1 denominator), which applies Bessel's
+ * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function.
+ *
+ *
The measure is created with:
+ *
+ *
+ * - Function: Substrait's "variance" from the arithmetic extension
+ *
- Argument: distribution=SAMPLE (enum argument)
+ *
- Output type: nullable version of the input expression type
+ *
- Aggregation phase: INITIAL_TO_RESULT
+ *
- Invocation: ALL (processes all rows)
+ *
+ *
+ * @param expr the expression to aggregate (typically a numeric field reference)
+ * @return an aggregate measure computing sample variance
+ */
+ public Aggregate.Measure varianceSample(Expression expr) {
+ return statisticalAggregate(expr, "variance", StatisticalDistribution.SAMPLE);
+ }
+
+ /**
+ * Helper method to create statistical aggregate measures (std_dev, variance) with a {@code
+ * distribution} enum argument.
+ *
+ * Uses the non-deprecated function signatures that carry the population/sample distinction as
+ * a leading {@code distribution} {@link EnumArg} (e.g. {@code std_dev:req_fp64}).
+ *
+ * @param expr the expression to aggregate
+ * @param functionName the Substrait function name ("std_dev" or "variance")
+ * @param distribution the distribution type (SAMPLE or POPULATION)
+ * @return an aggregate measure with the specified distribution argument
+ */
+ private Aggregate.Measure statisticalAggregate(
+ Expression expr, String functionName, StatisticalDistribution distribution) {
+ String typeString = ToTypeString.apply(expr.getType());
+ SimpleExtension.AggregateFunctionVariant declaration =
+ extensions.getAggregateFunction(
+ SimpleExtension.FunctionAnchor.of(
+ DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
+ String.format("%s:req_%s", functionName, typeString)));
+ EnumArg distributionArg =
+ EnumArg.of((SimpleExtension.EnumArgument) declaration.args().get(0), distribution.name());
+ return measure(
+ AggregateFunctionInvocation.builder()
+ .arguments(Arrays.asList(distributionArg, expr))
+ .outputType(TypeCreator.asNullable(expr.getType()))
+ .declaration(declaration)
+ .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
+ .invocation(Expression.AggregationInvocation.ALL)
+ .build());
+ }
+
private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
diff --git a/core/src/main/java/io/substrait/expression/StatisticalDistribution.java b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java
new file mode 100644
index 000000000..62d223024
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java
@@ -0,0 +1,17 @@
+package io.substrait.expression;
+
+/**
+ * The {@code distribution} enum argument of the Substrait {@code std_dev} and {@code variance}
+ * aggregate functions.
+ *
+ *
Distinguishes between the sample (n-1 denominator, Bessel's correction) and population (n
+ * denominator) variants. The enum constant names match the Substrait extension's enum option names
+ * ({@code SAMPLE} / {@code POPULATION}), so {@link #name()} yields the value used to build an
+ * {@link EnumArg}.
+ */
+public enum StatisticalDistribution {
+ /** Sample distribution (uses the n-1 denominator, Bessel's correction). */
+ SAMPLE,
+ /** Population distribution (uses the n denominator). */
+ POPULATION
+}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
index 0d5d5bf0e..4d7ea1bb9 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
@@ -29,6 +29,30 @@ public class AggregateFunctions {
/** Substrait-specific AVG aggregate function (nullable return type). */
public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG);
+ /**
+ * Standard deviation (population) aggregate function. Maps to Substrait's std_dev function with
+ * distribution=POPULATION enum argument.
+ */
+ public static SqlAggFunction STDDEV_POP = new SubstraitAvgAggFunction(SqlKind.STDDEV_POP);
+
+ /**
+ * Standard deviation (sample) aggregate function. Maps to Substrait's std_dev function with
+ * distribution=SAMPLE enum argument.
+ */
+ public static SqlAggFunction STDDEV_SAMP = new SubstraitAvgAggFunction(SqlKind.STDDEV_SAMP);
+
+ /**
+ * Variance (population) aggregate function. Maps to Substrait's variance function with
+ * distribution=POPULATION enum argument.
+ */
+ public static SqlAggFunction VAR_POP = new SubstraitAvgAggFunction(SqlKind.VAR_POP);
+
+ /**
+ * Variance (sample) aggregate function. Maps to Substrait's variance function with
+ * distribution=SAMPLE enum argument.
+ */
+ public static SqlAggFunction VAR_SAMP = new SubstraitAvgAggFunction(SqlKind.VAR_SAMP);
+
/** Substrait-specific SUM aggregate function (nullable return type). */
public static SqlAggFunction SUM = new SubstraitSumAggFunction();
@@ -42,18 +66,34 @@ public class AggregateFunctions {
* @return optional containing Substrait equivalent if conversion applies
*/
public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) {
- if (aggFunction instanceof SqlMinMaxAggFunction) {
- SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction;
- return Optional.of(
- fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX);
- } else if (aggFunction instanceof SqlAvgAggFunction) {
- return Optional.of(AggregateFunctions.AVG);
- } else if (aggFunction instanceof SqlSumAggFunction) {
- return Optional.of(AggregateFunctions.SUM);
- } else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
- return Optional.of(AggregateFunctions.SUM0);
- } else {
- return Optional.empty();
+ // First check by SqlKind to handle all statistical functions
+ SqlKind kind = aggFunction.getKind();
+ switch (kind) {
+ case MIN:
+ return Optional.of(AggregateFunctions.MIN);
+ case MAX:
+ return Optional.of(AggregateFunctions.MAX);
+ case AVG:
+ return Optional.of(AggregateFunctions.AVG);
+ case STDDEV_POP:
+ return Optional.of(AggregateFunctions.STDDEV_POP);
+ case STDDEV_SAMP:
+ return Optional.of(AggregateFunctions.STDDEV_SAMP);
+ case VAR_POP:
+ return Optional.of(AggregateFunctions.VAR_POP);
+ case VAR_SAMP:
+ return Optional.of(AggregateFunctions.VAR_SAMP);
+ case SUM:
+ case SUM0:
+ // Check instance type for SUM variants
+ if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
+ return Optional.of(AggregateFunctions.SUM0);
+ } else if (aggFunction instanceof SqlSumAggFunction) {
+ return Optional.of(AggregateFunctions.SUM);
+ }
+ return Optional.empty();
+ default:
+ return Optional.empty();
}
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
index 705e02e3c..f8695db64 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
@@ -46,8 +46,11 @@ public static boolean isValidCalciteAggregate(Aggregate aggregate) {
*/
private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
return
- // all function arguments to measures must be field references
- measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg))
+ // all value (Expression) function arguments to measures must be field references; non-value
+ // arguments such as the std_dev/variance "distribution" enum argument are exempt
+ measure.getFunction().arguments().stream()
+ .filter(farg -> farg instanceof Expression)
+ .allMatch(farg -> isSimpleFieldReference(farg))
&&
// all sort fields must be field references
measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr()))
@@ -157,9 +160,9 @@ public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) {
private Aggregate.Measure updateMeasure(Aggregate.Measure measure) {
AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction();
- List newFunctionArgs =
+ List newFunctionArgs =
oldAggregateFunctionInvocation.arguments().stream()
- .map(this::projectOutNonFieldReference)
+ .map(this::projectOutNonFieldReferenceArg)
.collect(Collectors.toList());
List newSortFields =
@@ -194,11 +197,13 @@ private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) {
return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build();
}
- private Expression projectOutNonFieldReference(FunctionArg farg) {
+ private FunctionArg projectOutNonFieldReferenceArg(FunctionArg farg) {
if ((farg instanceof Expression)) {
return projectOutNonFieldReference((Expression) farg);
} else {
- throw new IllegalArgumentException("cannot handle non-expression argument for aggregate");
+ // Non-value arguments (e.g. the std_dev/variance "distribution" enum argument) are not
+ // field references and are passed through unchanged.
+ return farg;
}
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
index 52849f96f..2d42b9f35 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
@@ -384,8 +384,11 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti
private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) {
List eArgs = measure.getFunction().arguments();
+ // Only value (Expression) arguments map to Calcite aggregate operands. Enum arguments such as
+ // the std_dev/variance "distribution" are used to disambiguate the operator, not as operands.
List arguments =
- IntStream.range(0, measure.getFunction().arguments().size())
+ IntStream.range(0, eArgs.size())
+ .filter(i -> eArgs.get(i) instanceof Expression)
.mapToObj(
i ->
eArgs
@@ -398,7 +401,9 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) {
.collect(java.util.stream.Collectors.toList());
Optional operator =
aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc(
- measure.getFunction().declaration().key(), measure.getFunction().outputType());
+ measure.getFunction().declaration().key(),
+ measure.getFunction().outputType(),
+ measure.getFunction().arguments());
if (!operator.isPresent()) {
throw new IllegalArgumentException(
String.format(
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
index bc4d0a9f2..94b41964f 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
@@ -36,6 +36,7 @@
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
+import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -52,11 +53,14 @@
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.TableModify;
+import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
@@ -84,6 +88,9 @@ public class SubstraitRelVisitor extends RelNodeVisitor {
private Map fieldAccessDepthMap;
+ /** Rex builder for creating Rex expressions during conversion. */
+ protected RexBuilder rexBuilder;
+
/**
* Creates a new SubstraitRelVisitor with the specified type factory and extensions.
*
@@ -106,6 +113,7 @@ public SubstraitRelVisitor(ConverterProvider converterProvider) {
this.typeConverter = converterProvider.getTypeConverter();
this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter();
this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this);
+ this.rexBuilder = new RexBuilder(converterProvider.getTypeFactory());
}
/**
@@ -331,6 +339,16 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
*/
@Override
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
+ // Substrait's std_dev/variance functions only define fp32/fp64 signatures. If a statistical
+ // aggregate has a non-floating-point argument, rewrite the aggregate to cast that argument to
+ // fp64 and cast the result back to the type Calcite inferred, then convert the rewritten plan
+ // through the normal path. The rewrite is idempotent (fp32/fp64 arguments are left untouched),
+ // so it terminates when the converted plan is re-converted.
+ RelNode rewritten = castStatisticalAggregatesToFloatingPoint(aggregate);
+ if (rewritten != aggregate) {
+ return apply(rewritten);
+ }
+
Rel input = apply(aggregate.getInput());
Stream sets;
if (aggregate.groupSets != null) {
@@ -411,6 +429,28 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {
return Aggregate.Grouping.builder().addAllExpressions(references).build();
}
+ /**
+ * Converts a Calcite {@link AggregateCall} to a Substrait {@link Aggregate.Measure}.
+ *
+ * This method handles the conversion of aggregate function calls from Calcite's representation
+ * to Substrait's format. For statistical aggregate functions (STDDEV_POP, STDDEV_SAMP, VAR_POP,
+ * VAR_SAMP), it automatically transforms the input relation by inserting a projection that casts
+ * the aggregate function's argument fields to DOUBLE (FP64) type, ensuring type compatibility
+ * with Substrait's statistical function requirements. Fields not referenced by the aggregate
+ * function are passed through unchanged.
+ *
+ *
The method also processes optional filter arguments (FILTER clauses) by converting them to
+ * Substrait's preMeasureFilter representation.
+ *
+ * @param input the input relational node providing data to the aggregate operation
+ * @param inputType the Substrait struct type representing the schema of the input relation
+ * @param call the Calcite aggregate call to convert, containing the aggregate function,
+ * arguments, and optional filter
+ * @return a Substrait {@link Aggregate.Measure} representing the aggregate function invocation
+ * with its configuration
+ * @throws UnsupportedOperationException if the aggregate function cannot be converted to a
+ * Substrait representation (no matching function binding found)
+ */
Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) {
Optional invocation =
aggregateFunctionConverter.convert(
@@ -427,6 +467,141 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal
return builder.build();
}
+ private static boolean isStatisticalDistributionAggregate(SqlKind kind) {
+ return kind == SqlKind.STDDEV_POP
+ || kind == SqlKind.STDDEV_SAMP
+ || kind == SqlKind.VAR_POP
+ || kind == SqlKind.VAR_SAMP;
+ }
+
+ private boolean isFloatingPoint(RelDataType type) {
+ Type substraitType = typeConverter.toSubstrait(type);
+ return TypeCreator.NULLABLE.FP32.equalsIgnoringNullability(substraitType)
+ || TypeCreator.NULLABLE.FP64.equalsIgnoringNullability(substraitType);
+ }
+
+ /**
+ * Rewrites a Calcite aggregate so that statistical aggregate functions (STDDEV_POP, STDDEV_SAMP,
+ * VAR_POP, VAR_SAMP) with non-floating-point arguments operate on fp64, since Substrait's {@code
+ * std_dev} / {@code variance} functions only define fp32 and fp64 signatures.
+ *
+ * For each statistical aggregate whose single argument is neither fp32 nor fp64 (e.g. an
+ * integer or decimal column), the rewrite:
+ *
+ *
+ * - appends a {@code cast(arg AS fp64)} column to the aggregate's input (leaving the original
+ * column in place, so other aggregates over the same column are unaffected),
+ *
- re-points the statistical aggregate at the appended column (its return type is re-derived
+ * over fp64), and
+ *
- casts the aggregate's results back to the types Calcite originally inferred, via a
+ * projection on top, so the aggregate's output row type is preserved.
+ *
+ *
+ * The rewrite is idempotent: fp32/fp64 arguments are left untouched, so converting the
+ * rewritten plan (whose statistical arguments are already fp64) produces no further rewrite and
+ * the recursion in {@link #visit(org.apache.calcite.rel.core.Aggregate)} terminates. If no
+ * argument needs casting, the aggregate is returned unchanged.
+ *
+ * @param aggregate the Calcite aggregate to inspect
+ * @return {@code aggregate} unchanged, or a {@link LogicalProject} wrapping a rewritten aggregate
+ */
+ protected RelNode castStatisticalAggregatesToFloatingPoint(
+ org.apache.calcite.rel.core.Aggregate aggregate) {
+ RelNode input = aggregate.getInput();
+ List calls = aggregate.getAggCallList();
+ int inputFieldCount = input.getRowType().getFieldCount();
+
+ // fp64 cast expressions to append to the input, and the source field each one casts (for reuse)
+ List appendedCasts = new ArrayList<>();
+ List appendedSourceFields = new ArrayList<>();
+ // per call: the appended column its argument should be re-pointed at, or -1 if unchanged
+ List rewrittenArgColumns = new ArrayList<>(calls.size());
+
+ for (AggregateCall call : calls) {
+ int rewrittenArgColumn = -1;
+ if (isStatisticalDistributionAggregate(call.getAggregation().getKind())
+ && call.getArgList().size() == 1) {
+ int argIndex = call.getArgList().get(0);
+ RelDataType argType = input.getRowType().getFieldList().get(argIndex).getType();
+ if (!isFloatingPoint(argType)) {
+ int existing = appendedSourceFields.indexOf(argIndex);
+ if (existing >= 0) {
+ rewrittenArgColumn = inputFieldCount + existing;
+ } else {
+ RelDataType fp64 =
+ typeConverter.toCalcite(
+ rexBuilder.getTypeFactory(), Type.withNullability(argType.isNullable()).FP64);
+ appendedCasts.add(rexBuilder.makeCast(fp64, rexBuilder.makeInputRef(input, argIndex)));
+ appendedSourceFields.add(argIndex);
+ rewrittenArgColumn = inputFieldCount + appendedCasts.size() - 1;
+ }
+ }
+ }
+ rewrittenArgColumns.add(rewrittenArgColumn);
+ }
+
+ if (appendedCasts.isEmpty()) {
+ return aggregate;
+ }
+
+ // Extended input: all original columns (passthrough) followed by the appended fp64 casts.
+ List inputProjects = new ArrayList<>(inputFieldCount + appendedCasts.size());
+ for (int i = 0; i < inputFieldCount; i++) {
+ inputProjects.add(rexBuilder.makeInputRef(input, i));
+ }
+ inputProjects.addAll(appendedCasts);
+ RelNode extendedInput =
+ LogicalProject.create(input, Collections.emptyList(), inputProjects, (List) null);
+
+ // Re-point the statistical calls at the appended fp64 columns (return type re-derived); leave
+ // all other calls unchanged.
+ List rewrittenCalls = new ArrayList<>(calls.size());
+ for (int i = 0; i < calls.size(); i++) {
+ AggregateCall call = calls.get(i);
+ int rewrittenArgColumn = rewrittenArgColumns.get(i);
+ if (rewrittenArgColumn < 0) {
+ rewrittenCalls.add(call);
+ } else {
+ rewrittenCalls.add(
+ AggregateCall.create(
+ call.getAggregation(),
+ call.isDistinct(),
+ call.isApproximate(),
+ call.ignoreNulls(),
+ Collections.singletonList(rewrittenArgColumn),
+ call.filterArg,
+ call.distinctKeys,
+ call.getCollation(),
+ aggregate.getGroupCount(),
+ extendedInput,
+ /* type, null to re-derive over fp64 */ null,
+ call.getName()));
+ }
+ }
+
+ org.apache.calcite.rel.core.Aggregate rewrittenAggregate =
+ aggregate.copy(
+ aggregate.getTraitSet(),
+ extendedInput,
+ aggregate.getGroupSet(),
+ aggregate.getGroupSets(),
+ rewrittenCalls);
+
+ // Cast the (now fp64) statistical results back to the types Calcite originally inferred,
+ // preserving the aggregate's original output row type. Group keys and unaffected measures pass
+ // through unchanged.
+ RelDataType originalRowType = aggregate.getRowType();
+ List outputProjects = new ArrayList<>(originalRowType.getFieldCount());
+ for (int i = 0; i < originalRowType.getFieldCount(); i++) {
+ RelDataType targetType = originalRowType.getFieldList().get(i).getType();
+ RexNode ref = rexBuilder.makeInputRef(rewrittenAggregate, i);
+ outputProjects.add(
+ ref.getType().equals(targetType) ? ref : rexBuilder.makeCast(targetType, ref));
+ }
+ return LogicalProject.create(
+ rewrittenAggregate, Collections.emptyList(), outputProjects, originalRowType);
+ }
+
/**
* Converts a Calcite {@link org.apache.calcite.rel.core.Match}.
*
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java
index 6f8918179..0b9b7e63d 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java
@@ -5,11 +5,13 @@
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
+import io.substrait.expression.StatisticalDistribution;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.SubstraitRelVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
@@ -22,6 +24,7 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
/**
@@ -76,9 +79,19 @@ public AggregateFunctionConverter(
/**
* Builds a Substrait aggregate invocation from the matched call and arguments.
*
- * @param call wrapped aggregate call
- * @param function matched Substrait function variant
- * @param arguments converted arguments
+ * This method constructs an {@link AggregateFunctionInvocation} with appropriate configuration
+ * including sort fields and invocation type (DISTINCT or ALL).
+ *
+ *
Statistical Functions: For standard deviation and variance functions (STDDEV_POP,
+ * STDDEV_SAMP, VAR_POP, VAR_SAMP), the population/sample distinction is carried by a leading
+ * {@code distribution} {@link io.substrait.expression.EnumArg} argument. That argument is
+ * synthesized as an operand in {@link #convert} so that the generic matcher resolves the enum-arg
+ * function variant ({@code std_dev:req_*} / {@code variance:req_*}) and constructs the {@link
+ * io.substrait.expression.EnumArg}; no special handling is required here.
+ *
+ * @param call wrapped aggregate call containing the Calcite aggregate information
+ * @param function matched Substrait function variant from the extension catalog
+ * @param arguments converted function arguments
* @param outputType result type of the invocation
* @return aggregate function invocation
*/
@@ -100,6 +113,7 @@ protected AggregateFunctionInvocation generateBinding(
agg.isDistinct()
? Expression.AggregationInvocation.DISTINCT
: Expression.AggregationInvocation.ALL;
+
return ExpressionCreator.aggregateFunction(
function,
outputType,
@@ -128,14 +142,50 @@ public Optional convert(
if (m == null) {
return Optional.empty();
}
- if (!m.allowedArgCount(call.getArgList().size())) {
+
+ // For statistical aggregates (std_dev/variance) the SAMPLE/POPULATION distinction is carried
+ // by a leading "distribution" enum argument. Synthesize it as an operand so the generic matcher
+ // resolves the enum-arg function variant and builds the EnumArg.
+ List leadingArgs = leadingEnumArgs(call);
+ if (!m.allowedArgCount(call.getArgList().size() + leadingArgs.size())) {
return Optional.empty();
}
- WrappedAggregateCall wrapped = new WrappedAggregateCall(call, input, rexBuilder, inputType);
+ WrappedAggregateCall wrapped =
+ new WrappedAggregateCall(call, leadingArgs, input, rexBuilder, inputType);
return m.attemptMatch(wrapped, topLevelConverter);
}
+ /**
+ * Computes the synthetic leading operands to prepend to a Calcite aggregate call before matching.
+ *
+ * For standard deviation and variance functions, Substrait carries the population/sample
+ * distinction as a leading {@code distribution} enum argument, whereas Calcite encodes it in the
+ * operator's {@link SqlKind}. This returns the matching {@link StatisticalDistribution} flag so
+ * the generic matcher selects the {@code std_dev:req_*} / {@code variance:req_*} variant and
+ * constructs the corresponding {@link io.substrait.expression.EnumArg}.
+ *
+ * @param call the Calcite aggregate call
+ * @return the leading enum operands (a single distribution flag for statistical functions, empty
+ * otherwise)
+ */
+ private List leadingEnumArgs(AggregateCall call) {
+ List leadingArgs = new ArrayList<>();
+ switch (call.getAggregation().getKind()) {
+ case STDDEV_SAMP:
+ case VAR_SAMP:
+ leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.SAMPLE));
+ break;
+ case STDDEV_POP:
+ case VAR_POP:
+ leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.POPULATION));
+ break;
+ default:
+ break;
+ }
+ return leadingArgs;
+ }
+
/**
* Resolves the appropriate function finder, applying Substrait-specific variants when needed.
*
@@ -160,6 +210,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) {
/** Lightweight wrapper around {@link AggregateCall} providing operands and type access. */
static class WrappedAggregateCall implements FunctionConverter.GenericCall {
private final AggregateCall call;
+ private final List leadingArgs;
private final RelNode input;
private final RexBuilder rexBuilder;
private final Type.Struct inputType;
@@ -168,26 +219,36 @@ static class WrappedAggregateCall implements FunctionConverter.GenericCall {
* Creates a new wrapped aggregate call.
*
* @param call underlying Calcite aggregate call
+ * @param leadingArgs synthetic operands (e.g. a {@code distribution} enum flag) prepended ahead
+ * of the field arguments during matching
* @param input input relational node
* @param rexBuilder Rex builder for operand construction
* @param inputType Substrait input struct type
*/
private WrappedAggregateCall(
- AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) {
+ AggregateCall call,
+ List leadingArgs,
+ RelNode input,
+ RexBuilder rexBuilder,
+ Type.Struct inputType) {
this.call = call;
+ this.leadingArgs = leadingArgs;
this.input = input;
this.rexBuilder = rexBuilder;
this.inputType = inputType;
}
/**
- * Returns operands as input references over the argument list.
+ * Returns operands as the synthetic leading operands followed by input references over the
+ * argument list.
*
* @return stream of RexNode operands
*/
@Override
public Stream getOperands() {
- return call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r));
+ return Stream.concat(
+ leadingArgs.stream(),
+ call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r)));
}
/**
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java
index 2f95ca276..2fef9f683 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java
@@ -1,6 +1,7 @@
package io.substrait.isthmus.expression;
import io.substrait.expression.EnumArg;
+import io.substrait.expression.StatisticalDistribution;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.extension.SimpleExtension.Argument;
@@ -78,6 +79,20 @@ public class EnumConverter {
calciteEnumMap.put(
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1),
ExtractIndexing.class);
+
+ // std_dev and variance carry the SAMPLE/POPULATION distinction as a leading enum argument.
+ calciteEnumMap.put(
+ argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp32", 0),
+ StatisticalDistribution.class);
+ calciteEnumMap.put(
+ argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp64", 0),
+ StatisticalDistribution.class);
+ calciteEnumMap.put(
+ argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp32", 0),
+ StatisticalDistribution.class);
+ calciteEnumMap.put(
+ argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp64", 0),
+ StatisticalDistribution.class);
}
private static Optional> constructValue(
@@ -90,6 +105,10 @@ private static Optional> constructValue(
return option.get().map(SqlTrimFunction.Flag::valueOf);
}
+ if (cls.isAssignableFrom(StatisticalDistribution.class)) {
+ return option.get().map(StatisticalDistribution::valueOf);
+ }
+
// ExtractIndexing does not need to be converted here. Calcite
// doesn't have the concept of the indexing. It's date
// functions are all indexed from 1
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
index 34655ccb5..3896fa897 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java
@@ -7,9 +7,11 @@
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
+import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
+import io.substrait.expression.StatisticalDistribution;
import io.substrait.extension.SimpleExtension;
import io.substrait.extension.SimpleExtension.Argument;
import io.substrait.function.ParameterizedType;
@@ -41,6 +43,7 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -181,6 +184,27 @@ public FunctionConverter(
* @return matching {@link SqlOperator}, or empty if none
*/
public Optional getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
+ return getSqlOperatorFromSubstraitFunc(key, outputType, java.util.Collections.emptyList());
+ }
+
+ /**
+ * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction).
+ *
+ * Given a Substrait function key (e.g., "std_dev:req_fp64"), output type, and function
+ * arguments (which may include a {@code distribution} {@link io.substrait.expression.EnumArg}),
+ * this method finds the corresponding Calcite {@link SqlOperator}. When multiple operators match,
+ * the output type and the {@code distribution} enum argument are used to disambiguate.
+ *
+ *
For example, both STDDEV_POP and STDDEV_SAMP map to "std_dev:req_fp64", but differ in the
+ * {@code distribution} enum argument (POPULATION vs SAMPLE).
+ *
+ * @param key the Substrait function key (function name with type signature)
+ * @param outputType the expected output type
+ * @param arguments the function arguments (used to read the {@code distribution} enum argument)
+ * @return the matching {@link SqlOperator}, or empty if no match found
+ */
+ public Optional getSqlOperatorFromSubstraitFunc(
+ String key, Type outputType, List arguments) {
Map resolver = getTypeBasedResolver();
Collection operators = substraitFuncKeyToSqlOperatorMap.get(key);
if (operators.isEmpty()) {
@@ -192,27 +216,101 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou
return Optional.of(operators.iterator().next());
}
- // at least 2 operators. Use output type to resolve SqlOperator.
+ // First, filter by output type to ensure type compatibility
String outputTypeStr = outputType.accept(ToTypeString.INSTANCE);
- List resolvedOperators =
+ List typeFilteredOperators =
operators.stream()
.filter(
operator ->
resolver.containsKey(operator)
&& resolver.get(operator).types().contains(outputTypeStr))
.collect(Collectors.toList());
+
+ // If type filtering resolved to a single operator, return it
+ if (typeFilteredOperators.size() == 1) {
+ return Optional.of(typeFilteredOperators.get(0));
+ }
+
+ // If still ambiguous and a distribution enum argument is present, disambiguate by it.
+ // Both the population and sample operators share one key (e.g. variance:req_fp32), since the
+ // SAMPLE/POPULATION value lives in the argument, not the signature.
+ Optional distribution = distributionArgument(arguments);
+ List resolvedOperators = typeFilteredOperators;
+ if (distribution.isPresent()) {
+ List candidates =
+ typeFilteredOperators.isEmpty() ? List.copyOf(operators) : typeFilteredOperators;
+ resolvedOperators = filterByDistribution(candidates, distribution.get());
+ }
+
// only one SqlOperator is possible
if (resolvedOperators.size() == 1) {
return Optional.of(resolvedOperators.get(0));
} else if (resolvedOperators.size() > 1) {
throw new IllegalStateException(
String.format(
- "Found %d SqlOperators: %s for ScalarFunction %s: ",
+ "Found %d SqlOperators: %s for function %s",
resolvedOperators.size(), resolvedOperators, key));
}
return Optional.empty();
}
+ /**
+ * Extracts the value of the {@code distribution} enum argument, if present.
+ *
+ * This returns the value of the first {@link EnumArg} in the argument list. It assumes the
+ * only enum argument that disambiguates between operators sharing a key is the {@code
+ * distribution} argument of {@code std_dev}/{@code variance} — the only enum-argument aggregate
+ * functions currently mapped. {@link #filterByDistribution} rejects values it does not recognize.
+ *
+ * @param arguments the Substrait function arguments
+ * @return the distribution value (e.g. {@code SAMPLE} / {@code POPULATION}) if an {@link EnumArg}
+ * is present
+ */
+ private static Optional distributionArgument(List arguments) {
+ if (arguments == null) {
+ return Optional.empty();
+ }
+ return arguments.stream()
+ .filter(arg -> arg instanceof EnumArg)
+ .map(arg -> (EnumArg) arg)
+ .flatMap(arg -> arg.value().stream())
+ .findFirst();
+ }
+
+ /**
+ * Filters SqlOperators based on the {@code distribution} enum argument.
+ *
+ * For statistical functions like STDDEV and VAR, the {@code distribution} argument determines
+ * whether to use the population or sample variant:
+ *
+ *
+ * - distribution=POPULATION → STDDEV_POP, VAR_POP
+ *
- distribution=SAMPLE → STDDEV_SAMP, VAR_SAMP
+ *
+ *
+ * @param operators the list of candidate SqlOperators
+ * @param distributionValue the distribution value from the Substrait enum argument
+ * @return filtered list of SqlOperators matching the distribution
+ */
+ private List filterByDistribution(
+ List operators, String distributionValue) {
+ return operators.stream()
+ .filter(
+ operator -> {
+ SqlKind kind = operator.getKind();
+ // Match distribution value to SqlKind
+ if (StatisticalDistribution.POPULATION.name().equals(distributionValue)) {
+ return kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP;
+ } else if (StatisticalDistribution.SAMPLE.name().equals(distributionValue)) {
+ return kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP;
+ }
+ throw new IllegalArgumentException(
+ String.format(
+ "Unknown distribution value '%s' for operator %s", distributionValue, kind));
+ })
+ .collect(Collectors.toList());
+ }
+
/**
* Returns the resolver used to disambiguate Calcite operators by output type.
*
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java
index 85f44a312..4c7ca3023 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java
@@ -155,7 +155,19 @@ public class FunctionMappings {
s(AggregateFunctions.SUM0, "sum0"),
s(SqlStdOperatorTable.COUNT, "count"),
s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"),
- s(AggregateFunctions.AVG, "avg"))
+ s(AggregateFunctions.AVG, "avg"),
+ /*
+ * Substrait std_dev and variance functions use a leading 'distribution' enum
+ * argument (SAMPLE or POPULATION) to distinguish between population and sample
+ * variants. AggregateFunctionConverter synthesizes that argument based on the SqlKind.
+ *
+ * Note: Standard Calcite operators (SqlStdOperatorTable.STDDEV_SAMP, etc.) are
+ * automatically converted to these Substrait variants via toSubstraitAggVariant().
+ */
+ s(AggregateFunctions.STDDEV_POP, "std_dev"),
+ s(AggregateFunctions.STDDEV_SAMP, "std_dev"),
+ s(AggregateFunctions.VAR_POP, "variance"),
+ s(AggregateFunctions.VAR_SAMP, "variance"))
.build();
/** Window function signatures (including supported aggregates) mapped to Substrait names. */
diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java
index e3082f55b..04459ae7e 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java
@@ -45,6 +45,14 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) {
return sb.sum0(input, field);
case "avg":
return sb.avg(input, field);
+ case "stddev_pop":
+ return sb.stddevPopulation(input, field);
+ case "stddev_samp":
+ return sb.stddevSample(input, field);
+ case "var_pop":
+ return sb.variancePopulation(input, field);
+ case "var_samp":
+ return sb.varianceSample(input, field);
default:
throw new UnsupportedOperationException(
String.format("no function is associated with %s", fname));
@@ -54,14 +62,41 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) {
// Create one function call per numeric type column
private List functions(Rel input, String fname) {
// first column is for grouping, skip it
+ // Statistical functions (stddev_*, var_*) only support floating-point types in Substrait.
+ // This filtering ensures we only test with fp32 and fp64 types for these functions,
+ // avoiding type mismatch errors during round-trip conversion.
+ boolean isStatisticalFunction = fname.startsWith("stddev_") || fname.startsWith("var_");
return IntStream.range(1, tableTypes.size())
.boxed()
+ .filter(
+ index -> {
+ if (!isStatisticalFunction) {
+ return true; // All numeric types for non-statistical functions
+ }
+ // Only floating-point types for statistical functions
+ Type type = tableTypes.get(index);
+ return type.equals(R.FP32)
+ || type.equals(R.FP64)
+ || type.equals(N.FP32)
+ || type.equals(N.FP64);
+ })
.map(index -> functionPicker(input, index, fname))
.collect(Collectors.toList());
}
@ParameterizedTest
- @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"})
+ @ValueSource(
+ strings = {
+ "max",
+ "min",
+ "sum",
+ "sum0",
+ "avg",
+ "stddev_pop",
+ "stddev_samp",
+ "var_pop",
+ "var_samp"
+ })
void emptyGrouping(String aggFunction) {
Aggregate rel =
sb.aggregate(
@@ -70,7 +105,18 @@ void emptyGrouping(String aggFunction) {
}
@ParameterizedTest
- @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"})
+ @ValueSource(
+ strings = {
+ "max",
+ "min",
+ "sum",
+ "sum0",
+ "avg",
+ "stddev_pop",
+ "stddev_samp",
+ "var_pop",
+ "var_samp"
+ })
void withGrouping(String aggFunction) {
Aggregate rel =
sb.aggregate(
diff --git a/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java
new file mode 100644
index 000000000..7dd03b5f5
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java
@@ -0,0 +1,114 @@
+package io.substrait.isthmus;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.substrait.expression.AggregateFunctionInvocation;
+import io.substrait.expression.EnumArg;
+import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
+import io.substrait.isthmus.sql.SubstraitSqlToCalcite;
+import io.substrait.plan.Plan;
+import io.substrait.relation.Aggregate;
+import io.substrait.relation.Rel;
+import java.util.List;
+import java.util.Optional;
+import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.rel.RelRoot;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+
+/**
+ * Verifies that the SQL statistical aggregates (STDDEV_POP/SAMP, VAR_POP/SAMP) map to the Substrait
+ * {@code std_dev} / {@code variance} functions using the non-deprecated enum-argument signatures
+ * ({@code std_dev:req_fp64} etc.), carrying the population/sample distinction as a {@code
+ * distribution} {@link EnumArg} rather than a function option.
+ */
+class StatisticalFunctionTest extends PlanTestBase {
+
+ static final String CREATES =
+ "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)";
+
+ @ParameterizedTest
+ @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"})
+ void roundTrip(String fn) throws Exception {
+ assertFullRoundTrip(String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fn, fn), CREATES);
+ }
+
+ // Integer arguments are cast to fp64 (and the result cast back) since std_dev/variance only have
+ // fp32/fp64 signatures. This rewrite (castStatisticalAggregatesToFloatingPoint) inserts a cast
+ // projection that Calcite normalizes (project merge/column pruning) on the first round trip, so
+ // these use the identity-projection workaround, which asserts stability after normalization.
+
+ @ParameterizedTest
+ @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"})
+ void roundTripIntegerInput(String fn) throws Exception {
+ assertFullRoundTripWithIdentityProjectionWorkaround(
+ String.format("SELECT %s(i32) FROM numbers", fn),
+ SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES));
+ }
+
+ @Test
+ void roundTripIntegerInputSharedWithOtherAggregate() throws Exception {
+ // The integer column is shared by SUM (which must keep operating on the integer) and STDDEV_POP
+ // (which is cast to fp64); the cast must be appended, not applied in place.
+ assertFullRoundTripWithIdentityProjectionWorkaround(
+ "SELECT SUM(i32), STDDEV_POP(i32) FROM numbers",
+ SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES));
+ }
+
+ @Test
+ void roundTripIntegerInputWithGrouping() throws Exception {
+ assertFullRoundTripWithIdentityProjectionWorkaround(
+ "SELECT i8, VAR_POP(i32) FROM numbers GROUP BY i8",
+ SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES));
+ }
+
+ @ParameterizedTest
+ @CsvSource({
+ "STDDEV_POP, std_dev, POPULATION",
+ "STDDEV_SAMP, std_dev, SAMPLE",
+ "VAR_POP, variance, POPULATION",
+ "VAR_SAMP, variance, SAMPLE",
+ })
+ void usesEnumArgSignature(String sqlFn, String substraitFn, String distribution)
+ throws Exception {
+ Prepare.CatalogReader catalog =
+ SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES);
+ RelRoot calcite =
+ SubstraitSqlToCalcite.convertQuery(
+ String.format("SELECT %s(fp64) FROM numbers", sqlFn),
+ catalog,
+ converterProvider.getSqlOperatorTable());
+ Plan.Root root = SubstraitRelVisitor.convert(calcite, converterProvider);
+
+ AggregateFunctionInvocation function = firstMeasure(root.getInput()).getFunction();
+
+ // The non-deprecated enum-arg variant is used (note the "req" enum argument in the key).
+ assertEquals(substraitFn + ":req_fp64", function.declaration().key());
+
+ // The distribution is carried as a leading EnumArg, not as a function option.
+ List args = function.arguments();
+ EnumArg distributionArg = assertInstanceOf(EnumArg.class, args.get(0));
+ assertEquals(Optional.of(distribution), distributionArg.value());
+ assertTrue(function.options().isEmpty(), "expected no function options");
+ }
+
+ /** Recursively finds the first {@link Aggregate} measure in the relation tree. */
+ private static Aggregate.Measure firstMeasure(Rel rel) {
+ if (rel instanceof Aggregate) {
+ Aggregate aggregate = (Aggregate) rel;
+ if (!aggregate.getMeasures().isEmpty()) {
+ return aggregate.getMeasures().get(0);
+ }
+ }
+ for (Rel input : rel.getInputs()) {
+ Aggregate.Measure measure = firstMeasure(input);
+ if (measure != null) {
+ return measure;
+ }
+ }
+ return null;
+ }
+}
diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java
index d39cc2084..142eb78bf 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java
@@ -9,6 +9,8 @@
import io.substrait.isthmus.expression.FunctionConverter.FunctionFinder;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import org.junit.jupiter.api.Test;
@@ -34,4 +36,84 @@ void testFunctionFinderMatch() {
assertEquals("sum0", functionFinder.getSubstraitName());
assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator());
}
+
+ @Test
+ void testStddevPopFunctionFinderMatch() {
+ AggregateFunctionConverter converter =
+ new AggregateFunctionConverter(
+ extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);
+
+ FunctionFinder functionFinder =
+ converter.getFunctionFinder(
+ AggregateCall.create(
+ new SqlAvgAggFunction(SqlKind.STDDEV_POP),
+ false,
+ List.of(0),
+ -1,
+ typeFactory.createSqlType(SqlTypeName.DOUBLE),
+ null));
+ assertNotNull(functionFinder);
+ assertEquals("std_dev", functionFinder.getSubstraitName());
+ assertEquals(AggregateFunctions.STDDEV_POP, functionFinder.getOperator());
+ }
+
+ @Test
+ void testStddevSampFunctionFinderMatch() {
+ AggregateFunctionConverter converter =
+ new AggregateFunctionConverter(
+ extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);
+
+ FunctionFinder functionFinder =
+ converter.getFunctionFinder(
+ AggregateCall.create(
+ new SqlAvgAggFunction(SqlKind.STDDEV_SAMP),
+ false,
+ List.of(0),
+ -1,
+ typeFactory.createSqlType(SqlTypeName.DOUBLE),
+ null));
+ assertNotNull(functionFinder);
+ assertEquals("std_dev", functionFinder.getSubstraitName());
+ assertEquals(AggregateFunctions.STDDEV_SAMP, functionFinder.getOperator());
+ }
+
+ @Test
+ void testVarPopFunctionFinderMatch() {
+ AggregateFunctionConverter converter =
+ new AggregateFunctionConverter(
+ extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);
+
+ FunctionFinder functionFinder =
+ converter.getFunctionFinder(
+ AggregateCall.create(
+ new SqlAvgAggFunction(SqlKind.VAR_POP),
+ false,
+ List.of(0),
+ -1,
+ typeFactory.createSqlType(SqlTypeName.DOUBLE),
+ null));
+ assertNotNull(functionFinder);
+ assertEquals("variance", functionFinder.getSubstraitName());
+ assertEquals(AggregateFunctions.VAR_POP, functionFinder.getOperator());
+ }
+
+ @Test
+ void testVarSampFunctionFinderMatch() {
+ AggregateFunctionConverter converter =
+ new AggregateFunctionConverter(
+ extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);
+
+ FunctionFinder functionFinder =
+ converter.getFunctionFinder(
+ AggregateCall.create(
+ new SqlAvgAggFunction(SqlKind.VAR_SAMP),
+ false,
+ List.of(0),
+ -1,
+ typeFactory.createSqlType(SqlTypeName.DOUBLE),
+ null));
+ assertNotNull(functionFinder);
+ assertEquals("variance", functionFinder.getSubstraitName());
+ assertEquals(AggregateFunctions.VAR_SAMP, functionFinder.getOperator());
+ }
}