Skip to content

Commit 3bed0ab

Browse files
author
Alex Wang
committed
feat: add replay aware state and logging
- check whether the execution is replaying in ExecutionState
- add replay-aware log suppression in the default logger

feat: update logging extra field names to camelCase
1 parent 636d757 commit 3bed0ab

10 files changed

Lines changed: 260 additions & 82 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def _execute_item_in_child_context(
381381
executor_context._parent_id, # noqa: SLF001
382382
name,
383383
)
384+
child_context.state.track_replay(operation_id=operation_id)
384385

385386
def run_in_child_handler():
386387
return self.execute_item(child_context, executable)

src/aws_durable_execution_sdk_python/context.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def __init__(
176176
self._step_counter: OrderedCounter = OrderedCounter()
177177

178178
log_info = LogInfo(
179-
execution_arn=state.durable_execution_arn, parent_id=parent_id
179+
execution_state=state,
180+
parent_id=parent_id,
180181
)
181182
self._log_info = log_info
182183
self.logger: Logger = logger or Logger.from_log_info(
@@ -205,7 +206,8 @@ def create_child_context(self, parent_id: str) -> DurableContext:
205206
parent_id=parent_id,
206207
logger=self.logger.with_log_info(
207208
LogInfo(
208-
execution_arn=self.state.durable_execution_arn, parent_id=parent_id
209+
execution_state=self.state,
210+
parent_id=parent_id,
209211
)
210212
),
211213
)
@@ -269,6 +271,7 @@ def create_callback(
269271
if not config:
270272
config = CallbackConfig()
271273
operation_id: str = self._create_step_id()
274+
self.state.track_replay(operation_id=operation_id)
272275
callback_id: str = create_callback_handler(
273276
state=self.state,
274277
operation_identifier=OperationIdentifier(
@@ -302,12 +305,14 @@ def invoke(
302305
Returns:
303306
The result of the invoked function
304307
"""
308+
operation_id = self._create_step_id()
309+
self.state.track_replay(operation_id=operation_id)
305310
return invoke_handler(
306311
function_name=function_name,
307312
payload=payload,
308313
state=self.state,
309314
operation_identifier=OperationIdentifier(
310-
operation_id=self._create_step_id(),
315+
operation_id=operation_id,
311316
parent_id=self._parent_id,
312317
name=name,
313318
),
@@ -325,6 +330,7 @@ def map(
325330
map_name: str | None = self._resolve_step_name(name, func)
326331

327332
operation_id = self._create_step_id()
333+
self.state.track_replay(operation_id=operation_id)
328334
operation_identifier = OperationIdentifier(
329335
operation_id=operation_id, parent_id=self._parent_id, name=map_name
330336
)
@@ -367,6 +373,7 @@ def parallel(
367373
"""Execute multiple callables in parallel."""
368374
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
369375
operation_id = self._create_step_id()
376+
self.state.track_replay(operation_id=operation_id)
370377
parallel_context = self.create_child_context(parent_id=operation_id)
371378
operation_identifier = OperationIdentifier(
372379
operation_id=operation_id, parent_id=self._parent_id, name=name
@@ -420,6 +427,7 @@ def run_in_child_context(
420427
step_name: str | None = self._resolve_step_name(name, func)
421428
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
422429
operation_id = self._create_step_id()
430+
self.state.track_replay(operation_id=operation_id)
423431

424432
def callable_with_child_context():
425433
return func(self.create_child_context(parent_id=operation_id))
@@ -441,13 +449,15 @@ def step(
441449
) -> T:
442450
step_name = self._resolve_step_name(name, func)
443451
logger.debug("Step name: %s", step_name)
452+
operation_id = self._create_step_id()
453+
self.state.track_replay(operation_id=operation_id)
444454

445455
return step_handler(
446456
func=func,
447457
config=config,
448458
state=self.state,
449459
operation_identifier=OperationIdentifier(
450-
operation_id=self._create_step_id(),
460+
operation_id=operation_id,
451461
parent_id=self._parent_id,
452462
name=step_name,
453463
),
@@ -465,11 +475,13 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
465475
if seconds < 1:
466476
msg = "duration must be at least 1 second"
467477
raise ValidationError(msg)
478+
operation_id = self._create_step_id()
479+
self.state.track_replay(operation_id=operation_id)
468480
wait_handler(
469481
seconds=seconds,
470482
state=self.state,
471483
operation_identifier=OperationIdentifier(
472-
operation_id=self._create_step_id(),
484+
operation_id=operation_id,
473485
parent_id=self._parent_id,
474486
name=name,
475487
),
@@ -515,12 +527,14 @@ def wait_for_condition(
515527
msg = "`config` is required for wait_for_condition"
516528
raise ValidationError(msg)
517529

530+
operation_id = self._create_step_id()
531+
self.state.track_replay(operation_id=operation_id)
518532
return wait_for_condition_handler(
519533
check=check,
520534
config=config,
521535
state=self.state,
522536
operation_identifier=OperationIdentifier(
523-
operation_id=self._create_step_id(),
537+
operation_id=operation_id,
524538
parent_id=self._parent_id,
525539
name=name,
526540
),

src/aws_durable_execution_sdk_python/execution.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from enum import Enum
1010
from typing import TYPE_CHECKING, Any
1111

12-
from aws_durable_execution_sdk_python.context import DurableContext, ExecutionState
12+
from aws_durable_execution_sdk_python.context import DurableContext
1313
from aws_durable_execution_sdk_python.exceptions import (
1414
BackgroundThreadError,
1515
BotoClientError,
@@ -27,6 +27,7 @@
2727
OperationType,
2828
OperationUpdate,
2929
)
30+
from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus
3031

3132
if TYPE_CHECKING:
3233
from collections.abc import Callable, MutableMapping
@@ -268,6 +269,9 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
268269
initial_checkpoint_token=invocation_input.checkpoint_token,
269270
operations={},
270271
service_client=service_client,
272+
replay_status=ReplayStatus.REPLAY
273+
if len(invocation_input.initial_execution_state.operations) > 0
274+
else ReplayStatus.NEW,
271275
)
272276

273277
execution_state.fetch_paginated_operations(

src/aws_durable_execution_sdk_python/logger.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,30 @@
88
from aws_durable_execution_sdk_python.types import LoggerInterface
99

1010
if TYPE_CHECKING:
11-
from collections.abc import Mapping, MutableMapping
11+
from collections.abc import Callable, Mapping, MutableMapping
1212

13+
from aws_durable_execution_sdk_python.context import ExecutionState
1314
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1415

1516

1617
@dataclass(frozen=True)
1718
class LogInfo:
18-
execution_arn: str
19+
execution_state: ExecutionState
1920
parent_id: str | None = None
2021
operation_id: str | None = None
2122
name: str | None = None
2223
attempt: int | None = None
2324

2425
@classmethod
2526
def from_operation_identifier(
26-
cls, execution_arn: str, op_id: OperationIdentifier, attempt: int | None = None
27+
cls,
28+
execution_state: ExecutionState,
29+
op_id: OperationIdentifier,
30+
attempt: int | None = None,
2731
) -> LogInfo:
2832
"""Create new log info from an execution arn, OperationIdentifier and attempt."""
2933
return cls(
30-
execution_arn=execution_arn,
34+
execution_state=execution_state,
3135
parent_id=op_id.parent_id,
3236
operation_id=op_id.operation_id,
3337
name=op_id.name,
@@ -37,7 +41,7 @@ def from_operation_identifier(
3741
def with_parent_id(self, parent_id: str) -> LogInfo:
3842
"""Clone the log info with a new parent id."""
3943
return LogInfo(
40-
execution_arn=self.execution_arn,
44+
execution_state=self.execution_state,
4145
parent_id=parent_id,
4246
operation_id=self.operation_id,
4347
name=self.name,
@@ -47,25 +51,33 @@ def with_parent_id(self, parent_id: str) -> LogInfo:
4751

4852
class Logger(LoggerInterface):
4953
def __init__(
50-
self, logger: LoggerInterface, default_extra: Mapping[str, object]
54+
self,
55+
logger: LoggerInterface,
56+
default_extra: Mapping[str, object],
57+
execution_state: ExecutionState,
5158
) -> None:
5259
self._logger = logger
5360
self._default_extra = default_extra
61+
self._execution_state = execution_state
5462

5563
@classmethod
5664
def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
5765
"""Create a new logger with the given LogInfo."""
58-
extra: MutableMapping[str, object] = {"execution_arn": info.execution_arn}
66+
extra: MutableMapping[str, object] = {
67+
"executionArn": info.execution_state.durable_execution_arn
68+
}
5969
if info.parent_id:
60-
extra["parent_id"] = info.parent_id
70+
extra["parentId"] = info.parent_id
6171
if info.name:
6272
# Use 'operation_name' instead of 'name' as key because the stdlib LogRecord internally reserved 'name' parameter
63-
extra["operation_name"] = info.name
73+
extra["operationName"] = info.name
6474
if info.attempt is not None:
6575
extra["attempt"] = info.attempt + 1
6676
if info.operation_id:
67-
extra["operation_id"] = info.operation_id
68-
return cls(logger, extra)
77+
extra["operationId"] = info.operation_id
78+
return cls(
79+
logger=logger, default_extra=extra, execution_state=info.execution_state
80+
)
6981

7082
def with_log_info(self, info: LogInfo) -> Logger:
7183
"""Clone the existing logger with new LogInfo."""
@@ -81,29 +93,39 @@ def get_logger(self) -> LoggerInterface:
8193
def debug(
8294
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
8395
) -> None:
84-
merged_extra = {**self._default_extra, **(extra or {})}
85-
self._logger.debug(msg, *args, extra=merged_extra)
96+
self._log(self._logger.debug, msg, *args, extra=extra)
8697

8798
def info(
8899
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
89100
) -> None:
90-
merged_extra = {**self._default_extra, **(extra or {})}
91-
self._logger.info(msg, *args, extra=merged_extra)
101+
self._log(self._logger.info, msg, *args, extra=extra)
92102

93103
def warning(
94104
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
95105
) -> None:
96-
merged_extra = {**self._default_extra, **(extra or {})}
97-
self._logger.warning(msg, *args, extra=merged_extra)
106+
self._log(self._logger.warning, msg, *args, extra=extra)
98107

99108
def error(
100109
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
101110
) -> None:
102-
merged_extra = {**self._default_extra, **(extra or {})}
103-
self._logger.error(msg, *args, extra=merged_extra)
111+
self._log(self._logger.error, msg, *args, extra=extra)
104112

105113
def exception(
106114
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
107115
) -> None:
116+
self._log(self._logger.exception, msg, *args, extra=extra)
117+
118+
def _log(
119+
self,
120+
log_func: Callable,
121+
msg: object,
122+
*args: object,
123+
extra: Mapping[str, object] | None = None,
124+
):
125+
if not self._should_log():
126+
return
108127
merged_extra = {**self._default_extra, **(extra or {})}
109-
self._logger.exception(msg, *args, extra=merged_extra)
128+
log_func(msg, *args, extra=merged_extra)
129+
130+
def _should_log(self) -> bool:
131+
return not self._execution_state.is_replaying()

src/aws_durable_execution_sdk_python/operation/step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def step_handler(
128128
step_context = StepContext(
129129
logger=context_logger.with_log_info(
130130
LogInfo.from_operation_identifier(
131-
execution_arn=state.durable_execution_arn,
131+
execution_state=state,
132132
op_id=operation_identifier,
133133
attempt=attempt,
134134
)

src/aws_durable_execution_sdk_python/operation/wait_for_condition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def wait_for_condition_handler(
133133
check_context = WaitForConditionCheckContext(
134134
logger=context_logger.with_log_info(
135135
LogInfo.from_operation_identifier(
136-
execution_arn=state.durable_execution_arn,
136+
execution_state=state,
137137
op_id=operation_identifier,
138138
attempt=attempt,
139139
)

src/aws_durable_execution_sdk_python/state.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import threading
99
import time
1010
from dataclasses import dataclass
11+
from enum import Enum
1112
from threading import Lock
1213
from typing import TYPE_CHECKING
1314

@@ -210,6 +211,13 @@ def get_next_attempt_timestamp(self) -> datetime.datetime | None:
210211
CHECKPOINT_NOT_FOUND = CheckpointedResult.create_not_found()
211212

212213

214+
class ReplayStatus(Enum):
215+
"""Status indicating whether execution is replaying or executing new operations."""
216+
217+
REPLAY = "replay"
218+
NEW = "new"
219+
220+
213221
class ExecutionState:
214222
"""Get, set and maintain execution state. This is mutable. Create and check checkpoints."""
215223

@@ -220,6 +228,7 @@ def __init__(
220228
operations: MutableMapping[str, Operation],
221229
service_client: DurableServiceClient,
222230
batcher_config: CheckpointBatcherConfig | None = None,
231+
replay_status: ReplayStatus = ReplayStatus.NEW,
223232
):
224233
self.durable_execution_arn: str = durable_execution_arn
225234
self._current_checkpoint_token: str = initial_checkpoint_token
@@ -247,6 +256,8 @@ def __init__(
247256

248257
# Protects parent_to_children and parent_done
249258
self._parent_done_lock: Lock = Lock()
259+
self._replay_status: ReplayStatus = replay_status
260+
self._replay_status_lock: Lock = Lock()
250261

251262
def fetch_paginated_operations(
252263
self,
@@ -277,6 +288,42 @@ def fetch_paginated_operations(
277288
with self._operations_lock:
278289
self.operations.update({op.operation_id: op for op in all_operations})
279290

291+
def track_replay(self, operation_id: str) -> None:
    """Update the replay status based on whether *operation_id* already completed.

    Called before each operation (step, wait, invoke, etc.) to detect the
    replay boundary. While replaying, the first operation that is missing
    from the checkpointed state, or that never reached a terminal status,
    flips the status from REPLAY to NEW, which re-enables logging for all
    subsequent code.

    Args:
        operation_id: The operation ID to check
    """
    terminal_statuses = {
        OperationStatus.SUCCEEDED,
        OperationStatus.FAILED,
        OperationStatus.CANCELLED,
        OperationStatus.STOPPED,
    }
    with self._replay_status_lock:
        if self._replay_status != ReplayStatus.REPLAY:
            return
        existing = self.operations.get(operation_id)
        if existing and existing.status in terminal_statuses:
            # Still inside previously checkpointed history; keep replaying.
            return
        logger.debug(
            "Transitioning from REPLAY to NEW status at operation %s",
            operation_id,
        )
        self._replay_status = ReplayStatus.NEW
317+
318+
def is_replaying(self) -> bool:
    """Check if execution is currently in replay mode.

    Returns:
        True if in REPLAY status, False if in NEW status
    """
    # Snapshot the status under the lock, then compare outside it.
    with self._replay_status_lock:
        status = self._replay_status
    return status is ReplayStatus.REPLAY
326+
280327
def get_checkpoint_result(self, checkpoint_id: str) -> CheckpointedResult:
281328
"""Get checkpoint result.
282329

tests/e2e/execution_int_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,10 @@ def mock_checkpoint(
217217
123,
218218
"str",
219219
extra={
220-
"execution_arn": "test-arn",
221-
"operation_name": "mystep",
220+
"executionArn": "test-arn",
221+
"operationName": "mystep",
222222
"attempt": 1,
223-
"operation_id": operation_id,
223+
"operationId": operation_id,
224224
},
225225
)
226226

0 commit comments

Comments
 (0)