Skip to content

Commit d418e5d

Browse files
committed
feat: add BatchResult serialization support with dedicated codec
1 parent: a950699 · commit: d418e5d

8 files changed

Lines changed: 607 additions & 22 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,12 @@ class ExecutionCounters:
367367
def __init__(
368368
self,
369369
total_tasks: int,
370-
min_successful: int,
370+
min_successful: int | None,
371371
tolerated_failure_count: int | None,
372372
tolerated_failure_percentage: float | None,
373373
):
374374
self.total_tasks: int = total_tasks
375-
self.min_successful: int = min_successful
375+
self.min_successful: int | None = min_successful
376376
self.tolerated_failure_count: int | None = tolerated_failure_count
377377
self.tolerated_failure_percentage: float | None = tolerated_failure_percentage
378378
self.success_count: int = 0
@@ -421,24 +421,26 @@ def is_complete(self) -> bool:
421421
"""
422422
Check if execution should complete (based on completion criteria).
423423
Matches TypeScript isComplete() logic.
424+
425+
Note: This method only checks completion criteria (all done, or min_successful met).
426+
Failure tolerance is enforced separately by should_continue() and combined in should_complete().
424427
"""
425428
with self._lock:
426429
completed_count = self.success_count + self.failure_count
427430

428431
# All tasks completed
429432
if completed_count == self.total_tasks:
430-
# Complete if no failure tolerance OR no failures OR min successful reached
431-
return (
432-
(
433-
self.tolerated_failure_count is None
434-
and self.tolerated_failure_percentage is None
435-
)
436-
or self.failure_count == 0
437-
or self.success_count >= self.min_successful
438-
)
433+
# If min_successful is explicitly set, check if we met it
434+
# Otherwise, complete when all tasks are done
435+
if self.min_successful is not None:
436+
return self.success_count >= self.min_successful
437+
return True
439438

440-
# when we breach min successful, we've completed
441-
return self.success_count >= self.min_successful
439+
# Early completion: when we breach min_successful (only if explicitly set)
440+
return (
441+
self.min_successful is not None
442+
and self.success_count >= self.min_successful
443+
)
442444

443445
def should_complete(self) -> bool:
444446
"""
@@ -455,7 +457,10 @@ def is_all_completed(self) -> bool:
455457
def is_min_successful_reached(self) -> bool:
456458
"""True if minimum successful tasks reached."""
457459
with self._lock:
458-
return self.success_count >= self.min_successful
460+
return (
461+
self.min_successful is not None
462+
and self.success_count >= self.min_successful
463+
)
459464

460465
def is_failure_tolerance_exceeded(self) -> bool:
461466
"""True if failure tolerance was exceeded."""
@@ -594,7 +599,9 @@ def __init__(
594599
self._suspend_exception: SuspendExecution | None = None
595600

596601
# ExecutionCounters will keep track of completion criteria and on-going counters
597-
min_successful = self.completion_config.min_successful or len(self.executables)
602+
# Note: min_successful should remain None if not explicitly set
603+
# When None, the operation completes when all tasks finish (respecting failure tolerance)
604+
min_successful = self.completion_config.min_successful
598605
tolerated_failure_count = self.completion_config.tolerated_failure_count
599606
tolerated_failure_percentage = (
600607
self.completion_config.tolerated_failure_percentage

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def from_items(
8282
name_prefix="map-item-",
8383
serdes=config.serdes,
8484
summary_generator=config.summary_generator,
85+
item_serdes=config.item_serdes,
8586
)
8687

8788
def execute_item(self, child_context, executable: Executable[Callable]) -> R:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def from_callables(
6969
name_prefix="parallel-branch-",
7070
serdes=config.serdes,
7171
summary_generator=config.summary_generator,
72+
item_serdes=config.item_serdes,
7273
)
7374

7475
def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class TypeTag(StrEnum):
6262
TUPLE = "t"
6363
LIST = "l"
6464
DICT = "m"
65+
BATCH_RESULT = "br"
6566

6667

6768
@dataclass(frozen=True)
@@ -206,7 +207,18 @@ def dispatcher(self):
206207

207208
def encode(self, obj: Any) -> EncodedValue:
208209
"""Encode container using dispatcher for recursive elements."""
210+
# Import here to avoid circular dependency
211+
from aws_durable_execution_sdk_python.concurrency import (
212+
BatchResult,
213+
) # noqa: PLC0415
214+
209215
match obj:
216+
case BatchResult():
217+
# Encode BatchResult as dict with special tag
218+
return EncodedValue(
219+
TypeTag.BATCH_RESULT,
220+
self._wrap(obj.to_dict(), self.dispatcher).value,
221+
)
210222
case list():
211223
return EncodedValue(
212224
TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj]
@@ -230,7 +242,16 @@ def encode(self, obj: Any) -> EncodedValue:
230242

231243
def decode(self, tag: TypeTag, value: Any) -> Any:
232244
"""Decode container using dispatcher for recursive elements."""
245+
# Import here to avoid circular dependency
246+
from aws_durable_execution_sdk_python.concurrency import (
247+
BatchResult,
248+
) # noqa: PLC0415
249+
233250
match tag:
251+
case TypeTag.BATCH_RESULT:
252+
# Decode as dict (handles all recursive unwrapping) then reconstruct
253+
decoded_dict = self.decode(TypeTag.DICT, value)
254+
return BatchResult.from_dict(decoded_dict)
234255
case TypeTag.LIST:
235256
if not isinstance(value, list):
236257
msg = f"Expected list, got {type(value)}"
@@ -281,6 +302,9 @@ def __init__(self):
281302
self.container_codec.set_dispatcher(self)
282303

283304
def encode(self, obj: Any) -> EncodedValue:
305+
# Import here to avoid circular dependency
306+
from aws_durable_execution_sdk_python.concurrency import BatchResult
307+
284308
match obj:
285309
case None | str() | bool() | int() | float():
286310
return self.primitive_codec.encode(obj)
@@ -292,7 +316,7 @@ def encode(self, obj: Any) -> EncodedValue:
292316
return self.decimal_codec.encode(obj)
293317
case datetime() | date():
294318
return self.datetime_codec.encode(obj)
295-
case list() | tuple() | dict():
319+
case BatchResult() | list() | tuple() | dict():
296320
return self.container_codec.encode(obj)
297321
case _:
298322
msg = f"Unsupported type: {type(obj)}"
@@ -301,11 +325,7 @@ def encode(self, obj: Any) -> EncodedValue:
301325
def decode(self, tag: TypeTag, value: Any) -> Any:
302326
match tag:
303327
case (
304-
TypeTag.NONE
305-
| TypeTag.STR
306-
| TypeTag.BOOL
307-
| TypeTag.INT
308-
| TypeTag.FLOAT
328+
TypeTag.NONE | TypeTag.STR | TypeTag.BOOL | TypeTag.INT | TypeTag.FLOAT
309329
):
310330
return self.primitive_codec.decode(tag, value)
311331
case TypeTag.BYTES:
@@ -316,7 +336,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
316336
return self.decimal_codec.decode(tag, value)
317337
case TypeTag.DATETIME | TypeTag.DATE:
318338
return self.datetime_codec.decode(tag, value)
319-
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
339+
case TypeTag.BATCH_RESULT | TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
320340
return self.container_codec.decode(tag, value)
321341
case _:
322342
msg = f"Unknown type tag: {tag}"

tests/concurrency_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,54 @@ def test_execution_counters_zero_total_tasks():
809809
assert not counters.is_failure_tolerance_exceeded()
810810

811811

812+
def test_execution_counters_none_min_successful():
813+
"""Test ExecutionCounters with None min_successful completes when all tasks done."""
814+
counters = ExecutionCounters(5, None, None, None)
815+
816+
# Should not complete early
817+
assert not counters.should_complete()
818+
819+
counters.complete_task()
820+
counters.complete_task()
821+
counters.complete_task()
822+
assert not counters.should_complete()
823+
824+
# Should complete when all tasks are done
825+
counters.complete_task()
826+
counters.complete_task()
827+
assert counters.should_complete()
828+
829+
830+
def test_execution_counters_none_min_successful_with_failures():
831+
"""Test ExecutionCounters with None min_successful and failure tolerance."""
832+
counters = ExecutionCounters(5, None, 2, None)
833+
834+
# Should not complete early even with successes
835+
counters.complete_task()
836+
counters.complete_task()
837+
assert not counters.should_complete()
838+
839+
# Should complete when failure tolerance exceeded
840+
counters.fail_task()
841+
counters.fail_task()
842+
counters.fail_task()
843+
assert counters.should_complete()
844+
845+
846+
def test_execution_counters_is_min_successful_reached_with_none():
847+
"""Test is_min_successful_reached returns False when min_successful is None."""
848+
counters = ExecutionCounters(5, None, None, None)
849+
850+
counters.complete_task()
851+
counters.complete_task()
852+
counters.complete_task()
853+
counters.complete_task()
854+
counters.complete_task()
855+
856+
# Should always return False when min_successful is None
857+
assert not counters.is_min_successful_reached()
858+
859+
812860
def test_execution_counters_failure_percentage_edge_case():
813861
"""Test ExecutionCounters failure percentage at exact threshold."""
814862
counters = ExecutionCounters(10, 5, None, 20.0)

tests/operation/map_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for map operation."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
# Mock the executor.execute method
@@ -750,3 +751,45 @@ def get_checkpoint_result(self, operation_id):
750751
# Verify replay was called, execute was not
751752
mock_replay.assert_called_once()
752753
mock_execute.assert_not_called()
754+
755+
756+
def test_map_result_serialization_roundtrip():
757+
"""Test that map operation BatchResult can be serialized and deserialized."""
758+
759+
items = ["a", "b", "c"]
760+
761+
def func(ctx, item, idx, items):
762+
return {"item": item.upper(), "index": idx}
763+
764+
class MockExecutionState:
765+
durable_execution_arn = "arn:test"
766+
767+
def get_checkpoint_result(self, operation_id):
768+
mock_result = Mock()
769+
mock_result.is_succeeded.return_value = False
770+
return mock_result
771+
772+
execution_state = MockExecutionState()
773+
map_context = Mock()
774+
map_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2", "3"]) # noqa SLF001
775+
map_context.create_child_context = Mock(return_value=Mock())
776+
operation_identifier = OperationIdentifier("test_op", "parent", "test_map")
777+
778+
# Execute map
779+
result = map_handler(
780+
items, func, MapConfig(), execution_state, map_context, operation_identifier
781+
)
782+
783+
# Serialize the BatchResult
784+
serialized = json.dumps(result.to_dict())
785+
786+
# Deserialize
787+
deserialized = BatchResult.from_dict(json.loads(serialized))
788+
789+
# Verify all data preserved
790+
assert len(deserialized.all) == 3
791+
assert deserialized.all[0].result == {"item": "A", "index": 0}
792+
assert deserialized.all[1].result == {"item": "B", "index": 1}
793+
assert deserialized.all[2].result == {"item": "C", "index": 2}
794+
assert deserialized.completion_reason == result.completion_reason
795+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

tests/operation/parallel_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the parallel operation module."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
import pytest
@@ -734,3 +735,57 @@ def get_checkpoint_result(self, operation_id):
734735
# Verify replay was called, execute was not
735736
mock_replay.assert_called_once()
736737
mock_execute.assert_not_called()
738+
739+
740+
def test_parallel_result_serialization_roundtrip():
741+
"""Test that parallel operation BatchResult can be serialized and deserialized."""
742+
743+
def func1(ctx):
744+
return [1, 2, 3]
745+
746+
def func2(ctx):
747+
return {"status": "complete", "count": 42}
748+
749+
def func3(ctx):
750+
return "simple string"
751+
752+
callables = [func1, func2, func3]
753+
754+
class MockExecutionState:
755+
durable_execution_arn = "arn:test"
756+
757+
def get_checkpoint_result(self, operation_id):
758+
mock_result = Mock()
759+
mock_result.is_succeeded.return_value = False
760+
return mock_result
761+
762+
execution_state = MockExecutionState()
763+
parallel_context = Mock()
764+
parallel_context._create_step_id_for_logical_step = Mock( # noqa SLF001
765+
side_effect=["1", "2", "3"]
766+
)
767+
parallel_context.create_child_context = Mock(return_value=Mock())
768+
operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel")
769+
770+
# Execute parallel
771+
result = parallel_handler(
772+
callables,
773+
ParallelConfig(),
774+
execution_state,
775+
parallel_context,
776+
operation_identifier,
777+
)
778+
779+
# Serialize the BatchResult
780+
serialized = json.dumps(result.to_dict())
781+
782+
# Deserialize
783+
deserialized = BatchResult.from_dict(json.loads(serialized))
784+
785+
# Verify all data preserved
786+
assert len(deserialized.all) == 3
787+
assert deserialized.all[0].result == [1, 2, 3]
788+
assert deserialized.all[1].result == {"status": "complete", "count": 42}
789+
assert deserialized.all[2].result == "simple string"
790+
assert deserialized.completion_reason == result.completion_reason
791+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

0 commit comments

Comments (0)