Skip to content

Commit 4b87d75

Browse files
committed
feat: Batch Result serialization
- Adds serialization for batch results in the serdes module. Unfortunately we need to do an ad hoc import, as we are dealing with circular dependencies.
1 parent a950699 commit 4b87d75

5 files changed

Lines changed: 286 additions & 23 deletions

File tree

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class TypeTag(StrEnum):
6262
TUPLE = "t"
6363
LIST = "l"
6464
DICT = "m"
65+
BATCH_RESULT = "br"
6566

6667

6768
@dataclass(frozen=True)
@@ -206,7 +207,17 @@ def dispatcher(self):
206207

207208
def encode(self, obj: Any) -> EncodedValue:
208209
"""Encode container using dispatcher for recursive elements."""
210+
# Import here to avoid circular dependency
211+
# concurrency -> child_handler -> serdes -> concurrency
212+
from aws_durable_execution_sdk_python.concurrency import BatchResult # noqa PLC0415
213+
209214
match obj:
215+
case BatchResult():
216+
# Encode BatchResult as dict with special tag
217+
return EncodedValue(
218+
TypeTag.BATCH_RESULT,
219+
self._wrap(obj.to_dict(), self.dispatcher).value,
220+
)
210221
case list():
211222
return EncodedValue(
212223
TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj]
@@ -230,7 +241,15 @@ def encode(self, obj: Any) -> EncodedValue:
230241

231242
def decode(self, tag: TypeTag, value: Any) -> Any:
232243
"""Decode container using dispatcher for recursive elements."""
244+
# Import here to avoid circular dependency
245+
from aws_durable_execution_sdk_python.concurrency import BatchResult # noqa PLC0415
246+
233247
match tag:
248+
case TypeTag.BATCH_RESULT:
249+
# Decode BatchResult from dict - value is already the dict structure
250+
# First decode it as a dict to unwrap all nested EncodedValues
251+
decoded_dict = self.decode(TypeTag.DICT, value)
252+
return BatchResult.from_dict(decoded_dict)
234253
case TypeTag.LIST:
235254
if not isinstance(value, list):
236255
msg = f"Expected list, got {type(value)}"
@@ -295,6 +314,11 @@ def encode(self, obj: Any) -> EncodedValue:
295314
case list() | tuple() | dict():
296315
return self.container_codec.encode(obj)
297316
case _:
317+
# Check if it's a BatchResult (handled by container_codec)
318+
from aws_durable_execution_sdk_python.concurrency import BatchResult # noqa PLC0415
319+
320+
if isinstance(obj, BatchResult):
321+
return self.container_codec.encode(obj)
298322
msg = f"Unsupported type: {type(obj)}"
299323
raise SerDesError(msg)
300324

@@ -316,7 +340,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
316340
return self.decimal_codec.decode(tag, value)
317341
case TypeTag.DATETIME | TypeTag.DATE:
318342
return self.datetime_codec.decode(tag, value)
319-
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
343+
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT | TypeTag.BATCH_RESULT:
320344
return self.container_codec.decode(tag, value)
321345
case _:
322346
msg = f"Unknown type tag: {tag}"

tests/concurrency_test.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the concurrency module."""
22

3+
import json
34
import random
45
import threading
56
import time
@@ -102,28 +103,6 @@ def test_batch_item_from_dict():
102103
assert item.error is None
103104

104105

105-
def test_batch_item_from_dict_with_error():
106-
"""Test BatchItem from_dict with error object."""
107-
error_data = {
108-
"message": "Test error",
109-
"type": "TestError",
110-
"data": None,
111-
"stackTrace": None,
112-
}
113-
data = {
114-
"index": 1,
115-
"status": "FAILED",
116-
"result": None,
117-
"error": error_data,
118-
}
119-
120-
item = BatchItem.from_dict(data)
121-
assert item.index == 1
122-
assert item.status == BatchItemStatus.FAILED
123-
assert item.result is None
124-
assert item.error is not None
125-
126-
127106
def test_batch_result_creation():
128107
"""Test BatchResult creation."""
129108
items = [
@@ -2676,3 +2655,79 @@ def mock_get_checkpoint_result(operation_id):
26762655
assert len(result.all) == 1
26772656
assert result.all[0].status == BatchItemStatus.SUCCEEDED
26782657
assert result.all[0].result == "re_executed_result"
2658+
2659+
2660+
def test_batch_item_from_dict_with_error():
2661+
"""Test BatchItem.from_dict() with error."""
2662+
data = {
2663+
"index": 3,
2664+
"status": "FAILED",
2665+
"result": None,
2666+
"error": {
2667+
"ErrorType": "ValueError",
2668+
"ErrorMessage": "bad value",
2669+
"StackTrace": [],
2670+
},
2671+
}
2672+
2673+
item = BatchItem.from_dict(data)
2674+
2675+
assert item.index == 3
2676+
assert item.status == BatchItemStatus.FAILED
2677+
assert item.error.type == "ValueError"
2678+
assert item.error.message == "bad value"
2679+
2680+
2681+
def test_batch_result_with_mixed_statuses():
2682+
"""Test BatchResult serialization with mixed item statuses."""
2683+
result = BatchResult(
2684+
all=[
2685+
BatchItem(0, BatchItemStatus.SUCCEEDED, result="success"),
2686+
BatchItem(
2687+
1,
2688+
BatchItemStatus.FAILED,
2689+
error=ErrorObject(message="msg", type="E", data=None, stack_trace=[]),
2690+
),
2691+
BatchItem(2, BatchItemStatus.STARTED),
2692+
],
2693+
completion_reason=CompletionReason.FAILURE_TOLERANCE_EXCEEDED,
2694+
)
2695+
2696+
serialized = json.dumps(result.to_dict())
2697+
deserialized = BatchResult.from_dict(json.loads(serialized))
2698+
2699+
assert len(deserialized.all) == 3
2700+
assert deserialized.all[0].status == BatchItemStatus.SUCCEEDED
2701+
assert deserialized.all[1].status == BatchItemStatus.FAILED
2702+
assert deserialized.all[2].status == BatchItemStatus.STARTED
2703+
assert deserialized.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
2704+
2705+
2706+
def test_batch_result_empty_list():
2707+
"""Test BatchResult serialization with empty items list."""
2708+
result = BatchResult(all=[], completion_reason=CompletionReason.ALL_COMPLETED)
2709+
2710+
serialized = json.dumps(result.to_dict())
2711+
deserialized = BatchResult.from_dict(json.loads(serialized))
2712+
2713+
assert len(deserialized.all) == 0
2714+
assert deserialized.completion_reason == CompletionReason.ALL_COMPLETED
2715+
2716+
2717+
def test_batch_result_complex_nested_data():
2718+
"""Test BatchResult with complex nested data structures."""
2719+
complex_result = {
2720+
"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}],
2721+
"metadata": {"count": 2, "timestamp": "2025-10-31"},
2722+
}
2723+
2724+
result = BatchResult(
2725+
all=[BatchItem(0, BatchItemStatus.SUCCEEDED, result=complex_result)],
2726+
completion_reason=CompletionReason.ALL_COMPLETED,
2727+
)
2728+
2729+
serialized = json.dumps(result.to_dict())
2730+
deserialized = BatchResult.from_dict(json.loads(serialized))
2731+
2732+
assert deserialized.all[0].result == complex_result
2733+
assert deserialized.all[0].result["users"][0]["name"] == "Alice"

tests/operation/map_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for map operation."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
# Mock the executor.execute method
@@ -750,3 +751,45 @@ def get_checkpoint_result(self, operation_id):
750751
# Verify replay was called, execute was not
751752
mock_replay.assert_called_once()
752753
mock_execute.assert_not_called()
754+
755+
756+
def test_map_result_serialization_roundtrip():
757+
"""Test that map operation BatchResult can be serialized and deserialized."""
758+
759+
items = ["a", "b", "c"]
760+
761+
def func(ctx, item, idx, items):
762+
return {"item": item.upper(), "index": idx}
763+
764+
class MockExecutionState:
765+
durable_execution_arn = "arn:test"
766+
767+
def get_checkpoint_result(self, operation_id):
768+
mock_result = Mock()
769+
mock_result.is_succeeded.return_value = False
770+
return mock_result
771+
772+
execution_state = MockExecutionState()
773+
map_context = Mock()
774+
map_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2", "3"]) # noqa SLF001
775+
map_context.create_child_context = Mock(return_value=Mock())
776+
operation_identifier = OperationIdentifier("test_op", "parent", "test_map")
777+
778+
# Execute map
779+
result = map_handler(
780+
items, func, MapConfig(), execution_state, map_context, operation_identifier
781+
)
782+
783+
# Serialize the BatchResult
784+
serialized = json.dumps(result.to_dict())
785+
786+
# Deserialize
787+
deserialized = BatchResult.from_dict(json.loads(serialized))
788+
789+
# Verify all data preserved
790+
assert len(deserialized.all) == 3
791+
assert deserialized.all[0].result == {"item": "A", "index": 0}
792+
assert deserialized.all[1].result == {"item": "B", "index": 1}
793+
assert deserialized.all[2].result == {"item": "C", "index": 2}
794+
assert deserialized.completion_reason == result.completion_reason
795+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

tests/operation/parallel_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the parallel operation module."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
import pytest
@@ -734,3 +735,57 @@ def get_checkpoint_result(self, operation_id):
734735
# Verify replay was called, execute was not
735736
mock_replay.assert_called_once()
736737
mock_execute.assert_not_called()
738+
739+
740+
def test_parallel_result_serialization_roundtrip():
741+
"""Test that parallel operation BatchResult can be serialized and deserialized."""
742+
743+
def func1(ctx):
744+
return [1, 2, 3]
745+
746+
def func2(ctx):
747+
return {"status": "complete", "count": 42}
748+
749+
def func3(ctx):
750+
return "simple string"
751+
752+
callables = [func1, func2, func3]
753+
754+
class MockExecutionState:
755+
durable_execution_arn = "arn:test"
756+
757+
def get_checkpoint_result(self, operation_id):
758+
mock_result = Mock()
759+
mock_result.is_succeeded.return_value = False
760+
return mock_result
761+
762+
execution_state = MockExecutionState()
763+
parallel_context = Mock()
764+
parallel_context._create_step_id_for_logical_step = Mock( # noqa SLF001
765+
side_effect=["1", "2", "3"]
766+
)
767+
parallel_context.create_child_context = Mock(return_value=Mock())
768+
operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel")
769+
770+
# Execute parallel
771+
result = parallel_handler(
772+
callables,
773+
ParallelConfig(),
774+
execution_state,
775+
parallel_context,
776+
operation_identifier,
777+
)
778+
779+
# Serialize the BatchResult
780+
serialized = json.dumps(result.to_dict())
781+
782+
# Deserialize
783+
deserialized = BatchResult.from_dict(json.loads(serialized))
784+
785+
# Verify all data preserved
786+
assert len(deserialized.all) == 3
787+
assert deserialized.all[0].result == [1, 2, 3]
788+
assert deserialized.all[1].result == {"status": "complete", "count": 42}
789+
assert deserialized.all[2].result == "simple string"
790+
assert deserialized.completion_reason == result.completion_reason
791+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

tests/serdes_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88

99
import pytest
1010

11+
from aws_durable_execution_sdk_python.concurrency import (
12+
BatchItem,
13+
BatchItemStatus,
14+
BatchResult,
15+
CompletionReason,
16+
)
1117
from aws_durable_execution_sdk_python.exceptions import (
1218
DurableExecutionsError,
1319
ExecutionError,
1420
SerDesError,
1521
)
22+
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
1623
from aws_durable_execution_sdk_python.serdes import (
1724
BytesCodec,
1825
ContainerCodec,
@@ -894,3 +901,82 @@ def test_all_t_v_nested_dicts():
894901

895902

896903
# endregion
904+
905+
906+
# to_dict() support tests
907+
def test_default_serdes_supports_to_dict_objects():
908+
"""Test that default serdes automatically handles BatchResult serialization/deserialization."""
909+
910+
result = BatchResult(
911+
all=[BatchItem(0, BatchItemStatus.SUCCEEDED, result="test")],
912+
completion_reason=CompletionReason.ALL_COMPLETED,
913+
)
914+
915+
# Default serdes should automatically handle BatchResult
916+
serialized = serialize(
917+
serdes=None,
918+
value=result,
919+
operation_id="test_op",
920+
durable_execution_arn="arn:test",
921+
)
922+
923+
# Deserialize returns BatchResult (not dict)
924+
deserialized = deserialize(
925+
serdes=None,
926+
data=serialized,
927+
operation_id="test_op",
928+
durable_execution_arn="arn:test",
929+
)
930+
931+
assert isinstance(deserialized, BatchResult)
932+
assert deserialized.completion_reason == CompletionReason.ALL_COMPLETED
933+
assert len(deserialized.all) == 1
934+
assert deserialized.all[0].result == "test"
935+
936+
937+
def test_to_dict_output_is_serializable():
938+
"""Test that to_dict() output is serializable by default serdes."""
939+
940+
result = BatchResult(
941+
all=[
942+
BatchItem(0, BatchItemStatus.SUCCEEDED, result={"key": "value"}),
943+
BatchItem(
944+
1,
945+
BatchItemStatus.FAILED,
946+
error=ErrorObject(
947+
message="error", type="TestError", data=None, stack_trace=[]
948+
),
949+
),
950+
],
951+
completion_reason=CompletionReason.ALL_COMPLETED,
952+
)
953+
954+
# Convert to dict
955+
result_dict = result.to_dict()
956+
957+
# Dict should be serializable
958+
serialized = serialize(
959+
serdes=None,
960+
value=result_dict,
961+
operation_id="test_op",
962+
durable_execution_arn="arn:test",
963+
)
964+
965+
# Deserialize
966+
deserialized_dict = deserialize(
967+
serdes=None,
968+
data=serialized,
969+
operation_id="test_op",
970+
durable_execution_arn="arn:test",
971+
)
972+
973+
# Verify structure preserved
974+
assert deserialized_dict["completionReason"] == "ALL_COMPLETED"
975+
assert len(deserialized_dict["all"]) == 2
976+
assert deserialized_dict["all"][0]["result"] == {"key": "value"}
977+
assert deserialized_dict["all"][1]["error"]["ErrorType"] == "TestError"
978+
979+
# Can reconstruct BatchResult
980+
reconstructed = BatchResult.from_dict(deserialized_dict)
981+
assert len(reconstructed.all) == 2
982+
assert reconstructed.completion_reason == CompletionReason.ALL_COMPLETED

0 commit comments

Comments
 (0)