Skip to content

Commit 3f3712a

Browse files
author
jmarkerink
committed
feat: implemented switch expression operator
1 parent b854cd1 commit 3f3712a

2 files changed

Lines changed: 201 additions & 0 deletions

File tree

core/src/main/java/de/bwaldvogel/mongo/backend/aggregation/Expression.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,6 +1424,85 @@ Object apply(List<?> expressionValue, Document document) {
14241424
}
14251425
},
14261426

1427+
$switch {
1428+
@Override
1429+
Object apply(Object expressionValue, Document document) {
1430+
Document switchDocument = requireDocument(expressionValue, 40060);
1431+
1432+
// Validate that 'branches' field exists
1433+
if (!switchDocument.containsKey("branches")) {
1434+
throw new MongoServerError(40061, "Missing 'branches' parameter to " + name());
1435+
}
1436+
1437+
// Validate unsupported parameters
1438+
List<String> supportedKeys = asList("branches", "default");
1439+
for (String key : switchDocument.keySet()) {
1440+
if (!supportedKeys.contains(key)) {
1441+
throw new MongoServerError(40067, "Unrecognized parameter to " + name() + ": " + key);
1442+
}
1443+
}
1444+
1445+
// Get and validate branches
1446+
Object branchesValue = switchDocument.get("branches");
1447+
if (!(branchesValue instanceof Collection<?>)) {
1448+
throw new MongoServerError(40061, name() + " expected an array for 'branches', found: " + describeType(branchesValue));
1449+
}
1450+
1451+
Collection<?> branches = (Collection<?>) branchesValue;
1452+
if (branches.isEmpty()) {
1453+
throw new MongoServerError(40060, name() + " requires at least one branch");
1454+
}
1455+
1456+
// Evaluate each branch
1457+
for (Object branchValue : branches) {
1458+
if (!(branchValue instanceof Document)) {
1459+
throw new MongoServerError(40062, name() + " expected each branch to be an object, found: " + describeType(branchValue));
1460+
}
1461+
1462+
Document branch = (Document) branchValue;
1463+
1464+
// Validate branch has required fields
1465+
if (!branch.containsKey("case")) {
1466+
throw new MongoServerError(40064, name() + " requires each branch have a 'case' expression");
1467+
}
1468+
if (!branch.containsKey("then")) {
1469+
throw new MongoServerError(40065, name() + " requires each branch have a 'then' expression");
1470+
}
1471+
1472+
// Validate branch has no extra fields
1473+
for (String key : branch.keySet()) {
1474+
if (!asList("case", "then").contains(key)) {
1475+
throw new MongoServerError(40063, name() + " found an unknown argument to a branch: " + key);
1476+
}
1477+
}
1478+
1479+
// Evaluate the case expression
1480+
Object caseExpression = branch.get("case");
1481+
Object caseResult = evaluate(caseExpression, document);
1482+
1483+
// If case is true, evaluate and return the then expression
1484+
if (Utils.isTrue(caseResult)) {
1485+
Object thenExpression = branch.get("then");
1486+
return evaluate(thenExpression, document);
1487+
}
1488+
}
1489+
1490+
// No case matched, check for default
1491+
if (switchDocument.containsKey("default")) {
1492+
Object defaultExpression = switchDocument.get("default");
1493+
return evaluate(defaultExpression, document);
1494+
}
1495+
1496+
// No case matched and no default provided
1497+
throw new MongoServerError(40066, name() + " could not find a matching branch for an input, and no default was specified.");
1498+
}
1499+
1500+
@Override
1501+
Object apply(List<?> expressionValue, Document document) {
1502+
throw new UnsupportedOperationException("must not be invoked");
1503+
}
1504+
},
1505+
14271506
$sqrt {
14281507
@Override
14291508
Object apply(List<?> expressionValue, Document document) {

test-common/src/main/java/de/bwaldvogel/mongo/backend/AbstractAggregationTest.java

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,128 @@ void testProjectWithCondition() throws Exception {
27332733
);
27342734
}
27352735

2736+
@Test
2737+
void testAggregateWithSwitch() throws Exception {
2738+
collection.insertOne(json("_id: 1, name: 'Dave', qty: 1"));
2739+
collection.insertOne(json("_id: 2, name: 'Carol', qty: 5"));
2740+
collection.insertOne(json("_id: 3, name: 'Bob', qty: 10"));
2741+
collection.insertOne(json("_id: 4, name: 'Alice', qty: 20"));
2742+
2743+
List<Document> pipeline = jsonList("""
2744+
$project: {
2745+
name: 1,
2746+
qtyDiscount: {
2747+
$switch: {
2748+
branches: [
2749+
{ case: { $gte: ['$qty', 10] }, then: 0.15 },
2750+
{ case: { $gte: ['$qty', 5] }, then: 0.10 },
2751+
{ case: { $gte: ['$qty', 1] }, then: 0.05 }
2752+
],
2753+
default: 0
2754+
}
2755+
}
2756+
}
2757+
""");
2758+
2759+
assertThat(collection.aggregate(pipeline))
2760+
.containsExactlyInAnyOrder(
2761+
json("_id: 1, name: 'Dave', qtyDiscount: 0.05"),
2762+
json("_id: 2, name: 'Carol', qtyDiscount: 0.10"),
2763+
json("_id: 3, name: 'Bob', qtyDiscount: 0.15"),
2764+
json("_id: 4, name: 'Alice', qtyDiscount: 0.15")
2765+
);
2766+
}
2767+
2768+
@Test
2769+
void testAggregateWithSwitchDefault() throws Exception {
2770+
collection.insertOne(json("_id: 1, status: 'active'"));
2771+
collection.insertOne(json("_id: 2, status: 'inactive'"));
2772+
collection.insertOne(json("_id: 3, status: 'unknown'"));
2773+
2774+
List<Document> pipeline = jsonList("""
2775+
$project: {
2776+
statusCode: {
2777+
$switch: {
2778+
branches: [
2779+
{ case: { $eq: ['$status', 'active'] }, then: 1 },
2780+
{ case: { $eq: ['$status', 'inactive'] }, then: 0 }
2781+
],
2782+
default: -1
2783+
}
2784+
}
2785+
}
2786+
""");
2787+
2788+
assertThat(collection.aggregate(pipeline))
2789+
.containsExactlyInAnyOrder(
2790+
json("_id: 1, statusCode: 1"),
2791+
json("_id: 2, statusCode: 0"),
2792+
json("_id: 3, statusCode: -1")
2793+
);
2794+
}
2795+
2796+
@Test
2797+
void testAggregateWithSwitchMissingDefault() throws Exception {
2798+
collection.insertOne(json("_id: 1, value: 100"));
2799+
2800+
List<Document> pipeline = jsonList("""
2801+
$project: {
2802+
result: {
2803+
$switch: {
2804+
branches: [
2805+
{ case: { $eq: ['$value', 50] }, then: 'fifty' }
2806+
]
2807+
}
2808+
}
2809+
}
2810+
""");
2811+
2812+
assertThatExceptionOfType(MongoCommandException.class)
2813+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2814+
.withMessageContaining("$switch could not find a matching branch for an input, and no default was specified");
2815+
}
2816+
2817+
@Test
2818+
void testAggregateWithSwitchMissingBranches() throws Exception {
2819+
collection.insertOne(json("_id: 1, value: 100"));
2820+
2821+
List<Document> pipeline = jsonList("""
2822+
$project: {
2823+
result: {
2824+
$switch: {
2825+
default: 'none'
2826+
}
2827+
}
2828+
}
2829+
""");
2830+
2831+
assertThatExceptionOfType(MongoCommandException.class)
2832+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2833+
.withMessageContaining("Missing 'branches' parameter to $switch");
2834+
}
2835+
2836+
@Test
2837+
void testAggregateWithSwitchInvalidBranch() throws Exception {
2838+
collection.insertOne(json("_id: 1, value: 100"));
2839+
2840+
List<Document> pipeline = jsonList("""
2841+
$project: {
2842+
result: {
2843+
$switch: {
2844+
branches: [
2845+
{ case: { $eq: ['$value', 100] } }
2846+
],
2847+
default: 'none'
2848+
}
2849+
}
2850+
}
2851+
""");
2852+
2853+
assertThatExceptionOfType(MongoCommandException.class)
2854+
.isThrownBy(() -> collection.aggregate(pipeline).first())
2855+
.withMessageContaining("$switch requires each branch have a 'then' expression");
2856+
}
2857+
27362858
// https://github.com/bwaldvogel/mongo-java-server/issues/138
27372859
@Test
27382860
public void testAggregateWithGeoNear() throws Exception {

0 commit comments

Comments
 (0)