Skip to content

Commit debd520

Browse files
Alex Wangwangyb-A
authored andcommitted
feat: Add SerDes support
- Add SerDes and default JsonSerDes class - Add serialize and deserialize helper method - Add SerDes support for operations: - child - step - wait_for_condition - Add and update unit tests
1 parent b46b73a commit debd520

12 files changed

Lines changed: 585 additions & 102 deletions

File tree

src/aws_durable_functions_sdk_python/config.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from abc import ABC, abstractmethod
65
from dataclasses import dataclass, field
76
from enum import Enum
87
from typing import TYPE_CHECKING, Generic, TypeVar
@@ -18,6 +17,8 @@
1817
from concurrent.futures import Future
1918

2019
from aws_durable_functions_sdk_python.lambda_service import OperationSubType
20+
from aws_durable_functions_sdk_python.serdes import SerDes
21+
2122

2223
Numeric = int | float # deliberately leaving off complex
2324

@@ -82,16 +83,6 @@ class ParallelConfig:
8283
serdes: SerDes | None = None
8384

8485

85-
class SerDes(ABC, Generic[T]):
86-
@abstractmethod
87-
def serialize(self, value: T) -> str:
88-
pass
89-
90-
@abstractmethod
91-
def deserialize(self, data: str) -> T:
92-
pass
93-
94-
9586
class StepSemantics(Enum):
9687
AT_MOST_ONCE_PER_RETRY = "AT_MOST_ONCE_PER_RETRY"
9788
AT_LEAST_ONCE_PER_RETRY = "AT_LEAST_ONCE_PER_RETRY"

src/aws_durable_functions_sdk_python/context.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
import logging
54
from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar
65

@@ -10,7 +9,6 @@
109
ChildConfig,
1110
MapConfig,
1211
ParallelConfig,
13-
SerDes,
1412
StepConfig,
1513
WaitForCallbackConfig,
1614
WaitForConditionConfig,
@@ -39,6 +37,7 @@
3937
from aws_durable_functions_sdk_python.operation.wait_for_condition import (
4038
wait_for_condition_handler,
4139
)
40+
from aws_durable_functions_sdk_python.serdes import SerDes, deserialize
4241
from aws_durable_functions_sdk_python.state import ExecutionState # noqa: TCH001
4342
from aws_durable_functions_sdk_python.threading import OrderedCounter
4443
from aws_durable_functions_sdk_python.types import (
@@ -103,12 +102,12 @@ def __init__(
103102
callback_id: str,
104103
operation_id: str,
105104
state: ExecutionState,
106-
serdes: SerDes | None = None,
105+
serdes: SerDes[T] | None = None,
107106
):
108107
self.callback_id: str = callback_id
109108
self.operation_id: str = operation_id
110109
self.state: ExecutionState = state
111-
self.serdes: SerDes | None = serdes
110+
self.serdes: SerDes[T] | None = serdes
112111

113112
def result(self) -> T | None:
114113
"""Return the result of the future. Will block until result is available.
@@ -132,11 +131,15 @@ def result(self) -> T | None:
132131
checkpointed_result.raise_callable_error()
133132

134133
if checkpointed_result.is_succeeded():
135-
# TODO: serdes
136134
if checkpointed_result.result is None:
137135
return None # type: ignore
138136

139-
return json.loads(checkpointed_result.result)
137+
return deserialize(
138+
serdes=self.serdes,
139+
data=checkpointed_result.result,
140+
operation_id=self.operation_id,
141+
durable_execution_arn=self.state.durable_execution_arn,
142+
)
140143

141144
msg = "Callback must be started before you can await the result."
142145
raise FatalError(msg)
@@ -270,6 +273,8 @@ def create_callback(
270273
Return:
271274
Callback future. Use result() on this future to wait for the callback resuilt.
272275
"""
276+
if not config:
277+
config = CallbackConfig()
273278
operation_id: str = self._create_step_id()
274279
callback_id: str = create_callback_handler(
275280
state=self.state,
@@ -280,7 +285,10 @@ def create_callback(
280285
)
281286

282287
return Callback(
283-
callback_id=callback_id, operation_id=operation_id, state=self.state
288+
callback_id=callback_id,
289+
operation_id=operation_id,
290+
state=self.state,
291+
serdes=config.serdes,
284292
)
285293

286294
def map(

src/aws_durable_functions_sdk_python/operation/child.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
65
import logging
76
from typing import TYPE_CHECKING, TypeVar
87

@@ -13,6 +12,7 @@
1312
OperationSubType,
1413
OperationUpdate,
1514
)
15+
from aws_durable_functions_sdk_python.serdes import deserialize, serialize
1616

1717
if TYPE_CHECKING:
1818
from collections.abc import Callable
@@ -50,8 +50,12 @@ def child_handler(
5050
)
5151
if checkpointed_result.result is None:
5252
return None # type: ignore
53-
return json.loads(checkpointed_result.result)
54-
53+
return deserialize(
54+
serdes=config.serdes,
55+
data=checkpointed_result.result,
56+
operation_id=operation_identifier.operation_id,
57+
durable_execution_arn=state.durable_execution_arn,
58+
)
5559
if checkpointed_result.is_failed():
5660
checkpointed_result.raise_callable_error()
5761
sub_type = (
@@ -67,7 +71,12 @@ def child_handler(
6771

6872
try:
6973
raw_result: T = func()
70-
serialized_result: str = json.dumps(raw_result)
74+
serialized_result: str = serialize(
75+
serdes=config.serdes,
76+
value=raw_result,
77+
operation_id=operation_identifier.operation_id,
78+
durable_execution_arn=state.durable_execution_arn,
79+
)
7180

7281
success_operation = OperationUpdate.create_context_succeed(
7382
identifier=operation_identifier,

src/aws_durable_functions_sdk_python/operation/step.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
65
import logging
76
import time
87
from typing import TYPE_CHECKING, TypeVar
@@ -20,6 +19,7 @@
2019
from aws_durable_functions_sdk_python.lambda_service import ErrorObject, OperationUpdate
2120
from aws_durable_functions_sdk_python.logger import Logger, LogInfo
2221
from aws_durable_functions_sdk_python.retries import RetryPresets
22+
from aws_durable_functions_sdk_python.serdes import deserialize, serialize
2323
from aws_durable_functions_sdk_python.types import StepContext
2424

2525
if TYPE_CHECKING:
@@ -59,11 +59,15 @@ def step_handler(
5959
operation_identifier.operation_id,
6060
operation_identifier.name,
6161
)
62-
# TODO: serdes
6362
if checkpointed_result.result is None:
6463
return None # type: ignore
6564

66-
return json.loads(checkpointed_result.result)
65+
return deserialize(
66+
serdes=config.serdes,
67+
data=checkpointed_result.result,
68+
operation_id=operation_identifier.operation_id,
69+
durable_execution_arn=state.durable_execution_arn,
70+
)
6771

6872
if checkpointed_result.is_failed():
6973
# have to throw the exact same error on replay as the checkpointed failure
@@ -107,7 +111,12 @@ def step_handler(
107111
try:
108112
# this is the actual code provided by the caller to execute durably inside the step
109113
raw_result: T = func(step_context)
110-
serialized_result: str = json.dumps(raw_result)
114+
serialized_result: str = serialize(
115+
serdes=config.serdes,
116+
value=raw_result,
117+
operation_id=operation_identifier.operation_id,
118+
durable_execution_arn=state.durable_execution_arn,
119+
)
111120

112121
success_operation: OperationUpdate = OperationUpdate.create_step_succeed(
113122
identifier=operation_identifier,

src/aws_durable_functions_sdk_python/operation/wait_for_condition.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import json
65
import logging
76
import time
87
from typing import TYPE_CHECKING, TypeVar
@@ -13,6 +12,7 @@
1312
)
1413
from aws_durable_functions_sdk_python.lambda_service import ErrorObject, OperationUpdate
1514
from aws_durable_functions_sdk_python.logger import LogInfo
15+
from aws_durable_functions_sdk_python.serdes import deserialize, serialize
1616
from aws_durable_functions_sdk_python.types import WaitForConditionCheckContext
1717

1818
if TYPE_CHECKING:
@@ -26,6 +26,7 @@
2626
from aws_durable_functions_sdk_python.logger import Logger
2727
from aws_durable_functions_sdk_python.state import ExecutionState
2828

29+
2930
T = TypeVar("T")
3031

3132
logger = logging.getLogger(__name__)
@@ -57,10 +58,14 @@ def wait_for_condition_handler(
5758
operation_identifier.operation_id,
5859
operation_identifier.name,
5960
)
60-
# TODO: use serdes from config
6161
if checkpointed_result.result is None:
6262
return None # type: ignore
63-
return json.loads(checkpointed_result.result)
63+
return deserialize(
64+
serdes=config.serdes,
65+
data=checkpointed_result.result,
66+
operation_id=operation_identifier.operation_id,
67+
durable_execution_arn=state.durable_execution_arn,
68+
)
6469

6570
if checkpointed_result.is_failed():
6671
checkpointed_result.raise_callable_error()
@@ -69,9 +74,13 @@ def wait_for_condition_handler(
6974
if checkpointed_result.is_started_or_ready():
7075
# This is a retry - get state from previous checkpoint
7176
if checkpointed_result.result:
72-
# TODO: serdes here
7377
try:
74-
current_state = json.loads(checkpointed_result.result)
78+
current_state = deserialize(
79+
serdes=config.serdes,
80+
data=checkpointed_result.result,
81+
operation_id=operation_identifier.operation_id,
82+
durable_execution_arn=state.durable_execution_arn,
83+
)
7584
except Exception:
7685
# default to initial state if there's an error getting checkpointed state
7786
logger.exception(
@@ -117,8 +126,12 @@ def wait_for_condition_handler(
117126
# Check if condition is met with the wait strategy
118127
decision: WaitForConditionDecision = config.wait_strategy(new_state, attempt)
119128

120-
# TODO: SerDes here
121-
serialized_state = json.dumps(new_state)
129+
serialized_state = serialize(
130+
serdes=config.serdes,
131+
value=new_state,
132+
operation_id=operation_identifier.operation_id,
133+
durable_execution_arn=state.durable_execution_arn,
134+
)
122135

123136
logger.debug(
124137
"wait_for_condition check completed: %s, name: %s, attempt: %s",
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Serialization and deserialization"""
2+
3+
import json
4+
import logging
5+
from abc import ABC, abstractmethod
6+
from dataclasses import dataclass
7+
from typing import Generic, TypeVar
8+
9+
from aws_durable_functions_sdk_python.exceptions import FatalError
10+
11+
logger = logging.getLogger(__name__)
12+
13+
T = TypeVar("T")
14+
15+
16+
@dataclass(frozen=True)
17+
class SerDesContext:
18+
operation_id: str
19+
durable_execution_arn: str
20+
21+
22+
class SerDes(ABC, Generic[T]):
23+
@abstractmethod
24+
def serialize(self, value: T, serdes_context: SerDesContext) -> str:
25+
pass
26+
27+
@abstractmethod
28+
def deserialize(self, data: str, serdes_context: SerDesContext) -> T:
29+
pass
30+
31+
32+
class JsonSerDes(SerDes[T]):
33+
def serialize(self, value: T, _: SerDesContext) -> str:
34+
return json.dumps(value)
35+
36+
def deserialize(self, data: str, _: SerDesContext) -> T:
37+
return json.loads(data)
38+
39+
40+
_DEFAULT_JSON_SERDES: SerDes = JsonSerDes()
41+
42+
43+
def serialize(
44+
serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str
45+
) -> str:
46+
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
47+
if serdes is None:
48+
serdes = _DEFAULT_JSON_SERDES
49+
try:
50+
return serdes.serialize(value, serdes_context)
51+
except Exception as e:
52+
logger.exception(
53+
"⚠️ Serialization failed for id: %s",
54+
operation_id,
55+
)
56+
msg = f"Serialization failed for id: {operation_id}, error: {e}."
57+
raise FatalError(msg) from e
58+
59+
60+
def deserialize(
61+
serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str
62+
) -> T:
63+
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
64+
if serdes is None:
65+
serdes = _DEFAULT_JSON_SERDES
66+
try:
67+
return serdes.deserialize(data, serdes_context)
68+
except Exception as e:
69+
logger.exception(
70+
"⚠️ Deserialization failed for id: %s",
71+
operation_id,
72+
)
73+
msg = f"Deserialization failed for id: {operation_id}"
74+
raise FatalError(msg) from e

0 commit comments

Comments
 (0)