Skip to content

Commit 39c0153

Browse files
author
Rares Polenciuc
committed
feat: add replay-aware logging to suppress logs during execution replay
- Track visited operations in DurableContext logger - Suppress default logger output when replaying existing operations - Add ReplayAwareLogger class for custom replay behavior - Provide is_replay() method to check current execution state - Ignore EXECUTION type operations in replay detection
1 parent 5abdb88 commit 39c0153

3 files changed

Lines changed: 327 additions & 8 deletions

File tree

src/aws_durable_execution_sdk_python/context.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def __init__(
167167
self.logger: Logger = logger or Logger.from_log_info(
168168
logger=logging.getLogger(),
169169
info=log_info,
170+
execution_state=state,
170171
)
171172

172173
# region factories
@@ -212,6 +213,8 @@ def set_logger(self, new_logger: LoggerInterface):
212213
self.logger = Logger.from_log_info(
213214
logger=new_logger,
214215
info=self._log_info,
216+
execution_state=self.state,
217+
visited_operations=self.logger.visited_operations,
215218
)
216219

217220
def _create_step_id(self) -> str:
@@ -248,6 +251,9 @@ def create_callback(
248251
if not config:
249252
config = CallbackConfig()
250253
operation_id: str = self._create_step_id()
254+
# Mark operation as visited before execution
255+
self.logger.mark_operation_visited(operation_id)
256+
251257
callback_id: str = create_callback_handler(
252258
state=self.state,
253259
operation_identifier=OperationIdentifier(
@@ -281,12 +287,16 @@ def invoke(
281287
Returns:
282288
The result of the invoked function
283289
"""
290+
operation_id = self._create_step_id()
291+
# Mark operation as visited before execution
292+
self.logger.mark_operation_visited(operation_id)
293+
284294
return invoke_handler(
285295
function_name=function_name,
286296
payload=payload,
287297
state=self.state,
288298
operation_identifier=OperationIdentifier(
289-
operation_id=self._create_step_id(),
299+
operation_id=operation_id,
290300
parent_id=self._parent_id,
291301
name=name,
292302
),
@@ -361,6 +371,8 @@ def run_in_child_context(
361371
step_name: str | None = self._resolve_step_name(name, func)
362372
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
363373
operation_id = self._create_step_id()
374+
# Mark operation as visited before execution
375+
self.logger.mark_operation_visited(operation_id)
364376

365377
def callable_with_child_context():
366378
return func(self.create_child_context(parent_id=operation_id))
@@ -383,12 +395,16 @@ def step(
383395
step_name = self._resolve_step_name(name, func)
384396
logger.debug("Step name: %s", step_name)
385397

398+
operation_id = self._create_step_id()
399+
# Mark operation as visited before execution
400+
self.logger.mark_operation_visited(operation_id)
401+
386402
return step_handler(
387403
func=func,
388404
config=config,
389405
state=self.state,
390406
operation_identifier=OperationIdentifier(
391-
operation_id=self._create_step_id(),
407+
operation_id=operation_id,
392408
parent_id=self._parent_id,
393409
name=step_name,
394410
),
@@ -405,11 +421,16 @@ def wait(self, seconds: int, name: str | None = None) -> None:
405421
if seconds < 1:
406422
msg = "seconds must be an integer greater than 0"
407423
raise ValidationError(msg)
424+
425+
operation_id = self._create_step_id()
426+
# Mark operation as visited before execution
427+
self.logger.mark_operation_visited(operation_id)
428+
408429
wait_handler(
409430
seconds=seconds,
410431
state=self.state,
411432
operation_identifier=OperationIdentifier(
412-
operation_id=self._create_step_id(),
433+
operation_id=operation_id,
413434
parent_id=self._parent_id,
414435
name=name,
415436
),
@@ -455,12 +476,16 @@ def wait_for_condition(
455476
msg = "`config` is required for wait_for_condition"
456477
raise ValidationError(msg)
457478

479+
operation_id = self._create_step_id()
480+
# Mark operation as visited before execution
481+
self.logger.mark_operation_visited(operation_id)
482+
458483
return wait_for_condition_handler(
459484
check=check,
460485
config=config,
461486
state=self.state,
462487
operation_identifier=OperationIdentifier(
463-
operation_id=self._create_step_id(),
488+
operation_id=operation_id,
464489
parent_id=self._parent_id,
465490
name=name,
466491
),

src/aws_durable_execution_sdk_python/logger.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from dataclasses import dataclass
66
from typing import TYPE_CHECKING
77

8+
from aws_durable_execution_sdk_python.lambda_service import OperationType
89
from aws_durable_execution_sdk_python.types import LoggerInterface
910

1011
if TYPE_CHECKING:
1112
from collections.abc import Mapping, MutableMapping
1213

1314
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
15+
from aws_durable_execution_sdk_python.state import ExecutionState
1416

1517

1618
@dataclass(frozen=True)
@@ -44,13 +46,25 @@ def with_parent_id(self, parent_id: str) -> LogInfo:
4446

4547
class Logger(LoggerInterface):
4648
def __init__(
47-
self, logger: LoggerInterface, default_extra: Mapping[str, object]
49+
self,
50+
logger: LoggerInterface,
51+
default_extra: Mapping[str, object],
52+
execution_state: ExecutionState | None = None,
53+
visited_operations: set[str] | None = None,
4854
) -> None:
4955
self._logger = logger
5056
self._default_extra = default_extra
57+
self._execution_state = execution_state
58+
self._visited_operations = visited_operations or set()
5159

5260
@classmethod
53-
def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
61+
def from_log_info(
62+
cls,
63+
logger: LoggerInterface,
64+
info: LogInfo,
65+
execution_state: ExecutionState | None = None,
66+
visited_operations: set[str] | None = None,
67+
) -> Logger:
5468
"""Create a new logger with the given LogInfo."""
5569
extra: MutableMapping[str, object] = {"execution_arn": info.execution_arn}
5670
if info.parent_id:
@@ -59,45 +73,118 @@ def from_log_info(cls, logger: LoggerInterface, info: LogInfo) -> Logger:
5973
extra["name"] = info.name
6074
if info.attempt:
6175
extra["attempt"] = info.attempt
62-
return cls(logger, extra)
76+
return cls(logger, extra, execution_state, visited_operations)
6377

6478
def with_log_info(self, info: LogInfo) -> Logger:
6579
"""Clone the existing logger with new LogInfo."""
6680
return Logger.from_log_info(
6781
logger=self._logger,
6882
info=info,
83+
execution_state=self._execution_state,
84+
visited_operations=self._visited_operations,
6985
)
7086

7187
def get_logger(self) -> LoggerInterface:
7288
"""Get the underlying logger."""
7389
return self._logger
7490

91+
def is_replay(self) -> bool:
92+
"""Check if we are currently in replay mode.
93+
94+
Returns True if there are operations in the execution state that haven't been visited yet.
95+
This indicates we are replaying previously executed operations.
96+
"""
97+
if not self._execution_state:
98+
return False
99+
100+
# If there are no operations, we're not in replay
101+
if not self._execution_state.operations:
102+
return False
103+
104+
# Check if there are any operations in the execution state that we haven't visited
105+
# Only consider operations that are not EXECUTION type (which are system operations)
106+
for operation_id, operation in self._execution_state.operations.items():
107+
# Skip EXECUTION operations as they are system operations, not user operations
108+
if operation.operation_type == OperationType.EXECUTION:
109+
continue
110+
if operation_id not in self._visited_operations:
111+
return True
112+
return False
113+
114+
def mark_operation_visited(self, operation_id: str) -> None:
115+
"""Mark an operation as visited."""
116+
self._visited_operations.add(operation_id)
117+
118+
def _should_log(self) -> bool:
119+
"""Determine if logging should occur based on replay state."""
120+
# For the default logger, only log when not in replay
121+
return not self.is_replay()
122+
75123
def debug(
76124
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
77125
) -> None:
126+
if not self._should_log():
127+
return
78128
merged_extra = {**self._default_extra, **(extra or {})}
79129
self._logger.debug(msg, *args, extra=merged_extra)
80130

81131
def info(
82132
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
83133
) -> None:
134+
if not self._should_log():
135+
return
84136
merged_extra = {**self._default_extra, **(extra or {})}
85137
self._logger.info(msg, *args, extra=merged_extra)
86138

87139
def warning(
88140
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
89141
) -> None:
142+
if not self._should_log():
143+
return
90144
merged_extra = {**self._default_extra, **(extra or {})}
91145
self._logger.warning(msg, *args, extra=merged_extra)
92146

93147
def error(
94148
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
95149
) -> None:
150+
if not self._should_log():
151+
return
96152
merged_extra = {**self._default_extra, **(extra or {})}
97153
self._logger.error(msg, *args, extra=merged_extra)
98154

99155
def exception(
100156
self, msg: object, *args: object, extra: Mapping[str, object] | None = None
101157
) -> None:
158+
if not self._should_log():
159+
return
102160
merged_extra = {**self._default_extra, **(extra or {})}
103161
self._logger.exception(msg, *args, extra=merged_extra)
162+
163+
@property
164+
def visited_operations(self):
165+
return self._visited_operations
166+
167+
168+
class ReplayAwareLogger(Logger):
169+
"""A logger that provides custom replay behavior for advanced users.
170+
171+
This logger allows users to customize logging behavior during replay by overriding
172+
the _should_log method. By default, it behaves the same as the base Logger.
173+
"""
174+
175+
def _should_log(self) -> bool:
176+
"""Override this method to customize replay logging behavior.
177+
178+
Returns:
179+
bool: True if logging should occur, False otherwise.
180+
181+
Example:
182+
def _should_log(self) -> bool:
183+
# Always log, even during replay
184+
return True
185+
186+
def _should_log(self) -> bool:
187+
# Only log errors during replay
188+
return not self.is_replay() or self._current_log_level == 'error'
189+
"""
190+
return super()._should_log()

0 commit comments

Comments
 (0)