-
Notifications
You must be signed in to change notification settings - Fork 18
feat: add per-context replay status tracking #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,7 +43,7 @@ | |
| SerDes, | ||
| deserialize, | ||
| ) | ||
| from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001 | ||
| from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001 | ||
| from aws_durable_execution_sdk_python.threading import OrderedCounter | ||
| from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol | ||
| from aws_durable_execution_sdk_python.types import ( | ||
|
|
@@ -277,6 +277,42 @@ def result(self) -> T | None: | |
| raise SuspendExecution(msg) | ||
|
|
||
|
|
||
| class OperationIdGenerator: | ||
| def __init__(self, step_id_prefix: str | None, parent_id: str | None) -> None: | ||
| self._operation_counter: OrderedCounter = OrderedCounter() | ||
| self._virtual_operation_counter: OrderedCounter = OrderedCounter() | ||
| # child operations use this to generate deterministic step ids. | ||
| # differs from `parent_id` only for virtual contexts. | ||
| self._step_id_prefix: str | None = ( | ||
| step_id_prefix if step_id_prefix is not None else parent_id | ||
| ) | ||
|
|
||
| def peek_next_step_id(self): | ||
| next_step = self._operation_counter.get_current() + 1 | ||
| return self.create_step_id_for_logical_step(next_step, is_virtual=False) | ||
|
zhongkechen marked this conversation as resolved.
|
||
|
|
||
| def create_step_id(self, is_virtual: bool = False) -> str: | ||
| """Generate a thread-safe step id, incrementing in order of invocation. | ||
|
|
||
| This method is an internal implementation detail. Do not rely the exact format of | ||
| the id generated by this method. It is subject to change without notice. | ||
| """ | ||
| new_counter: int = ( | ||
| self._virtual_operation_counter if is_virtual else self._operation_counter | ||
| ).increment() | ||
| return self.create_step_id_for_logical_step(new_counter, is_virtual=is_virtual) | ||
|
|
||
| def create_step_id_for_logical_step(self, step: int, is_virtual: bool) -> str: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: personally I think it would have been better to keep the virtual child context id tracking in a separate PR |
||
| """ | ||
| Generate a step_id based on the given logical step. | ||
| This allows us to recover operation ids or even look | ||
| forward without changing the internal state of this context. | ||
| """ | ||
| parts = [self._step_id_prefix, "v" if is_virtual else None, step] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this a breaking change?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How SDK generates the operation id is internal to the SDK. Users will only see a hash value.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
[UPDATE] sorry I misunderstood the counter. Currently, a virtual child context in the user context (pre hashing) would look like the following: The only thing that bugs me here is that the virtual id doesn't preserve order within the context as a whole, but I guess it doesn't really matter in python since it's all hashed anyway. |
||
| step_id: str = "-".join([str(part) for part in parts if part]) | ||
| return hashlib.blake2b(step_id.encode()).hexdigest()[:64] | ||
|
|
||
|
|
||
| class DurableContext(DurableContextProtocol): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -286,29 +322,30 @@ def __init__( | |
| parent_id: str | None = None, | ||
| logger: Logger | None = None, | ||
| step_id_prefix: str | None = None, | ||
| replay_status: ReplayStatus = ReplayStatus.REPLAY, | ||
| ) -> None: | ||
| self.state: ExecutionState = state | ||
| self.execution_context: ExecutionContext = execution_context | ||
| self.lambda_context = lambda_context | ||
| # operations inside this context use this id as their parent | ||
| self._parent_id: str | None = parent_id | ||
| # child operations use this to generate deterministic step ids. | ||
| # differs from `parent_id` only for virtual contexts. | ||
| self._step_id_prefix: str | None = ( | ||
| step_id_prefix if step_id_prefix is not None else parent_id | ||
| self._is_virtual: bool = ( | ||
| step_id_prefix is not None and parent_id != step_id_prefix | ||
| ) | ||
| # cached at construction to make invariant even if parent/prefix mutates. | ||
| self._is_virtual: bool = self._parent_id != self._step_id_prefix | ||
| self._step_counter: OrderedCounter = OrderedCounter() | ||
| self._operation_id_generator: OperationIdGenerator = OperationIdGenerator( | ||
| step_id_prefix, parent_id | ||
| ) | ||
| self._replay_status: ReplayStatus = replay_status | ||
| self._track_replay() | ||
|
|
||
| log_info = LogInfo( | ||
| execution_state=state, | ||
| parent_id=parent_id, | ||
| ) | ||
| self._log_info = log_info | ||
| self.logger: Logger = logger or Logger.from_log_info( | ||
| logger=logging.getLogger(), | ||
| info=log_info, | ||
| context=self, | ||
| ) | ||
|
|
||
| @property | ||
|
|
@@ -323,6 +360,11 @@ def is_virtual(self) -> bool: | |
| """ | ||
| return self._is_virtual | ||
|
|
||
| @property | ||
| def is_replaying(self) -> bool: | ||
| """True if this context is in replay mode""" | ||
| return self._replay_status is ReplayStatus.REPLAY | ||
|
|
||
| # region factories | ||
| @staticmethod | ||
| def from_lambda_context( | ||
|
|
@@ -371,9 +413,9 @@ def create_child_context( | |
| lambda_context=self.lambda_context, | ||
| parent_id=child_parent_id, | ||
| step_id_prefix=operation_id, | ||
| replay_status=self._replay_status, | ||
| logger=self.logger.with_log_info( | ||
| LogInfo( | ||
| execution_state=self.state, | ||
| parent_id=child_parent_id, | ||
| ) | ||
| ), | ||
|
|
@@ -396,26 +438,20 @@ def set_logger(self, new_logger: LoggerInterface): | |
| self.logger = Logger.from_log_info( | ||
| logger=new_logger, | ||
| info=self._log_info, | ||
| context=self, | ||
| ) | ||
|
|
||
| def _create_step_id_for_logical_step(self, step: int) -> str: | ||
| """ | ||
| Generate a step_id based on the given logical step. | ||
| This allows us to recover operation ids or even look | ||
| forward without changing the internal state of this context. | ||
| """ | ||
| prefix: str | None = self._step_id_prefix | ||
| step_id: str = f"{prefix}-{step}" if prefix else str(step) | ||
| return hashlib.blake2b(step_id.encode()).hexdigest()[:64] | ||
|
|
||
| def _create_step_id(self) -> str: | ||
| """Generate a thread-safe step id, incrementing in order of invocation. | ||
|
|
||
| This method is an internal implementation detail. Do not rely the exact format of | ||
| the id generated by this method. It is subject to change without notice. | ||
| """ | ||
| new_counter: int = self._step_counter.increment() | ||
| return self._create_step_id_for_logical_step(new_counter) | ||
| def _track_replay(self) -> None: | ||
| """Transition replay status to NEW if the next operation has not been checkpointed""" | ||
| if self._replay_status is ReplayStatus.NEW: | ||
| return | ||
| # check if next operation exists | ||
| next_step_id = self._operation_id_generator.peek_next_step_id() | ||
| if not self.state.get_checkpoint_result(next_step_id).is_existent(): | ||
|
zhongkechen marked this conversation as resolved.
|
||
| # update the context replay status to NEW | ||
| self._replay_status = ReplayStatus.NEW | ||
| # update the execution replay status to NEW | ||
| self.state.transition_replay_status() | ||
|
zhongkechen marked this conversation as resolved.
|
||
|
|
||
| # region Operations | ||
|
|
||
|
|
@@ -438,7 +474,7 @@ def create_callback( | |
| """ | ||
| if not config: | ||
| config = CallbackConfig() | ||
| operation_id: str = self._create_step_id() | ||
| operation_id: str = self._operation_id_generator.create_step_id() | ||
| executor: CallbackOperationExecutor = CallbackOperationExecutor( | ||
| state=self.state, | ||
| operation_identifier=OperationIdentifier( | ||
|
|
@@ -448,14 +484,14 @@ def create_callback( | |
| ), | ||
| config=config, | ||
| ) | ||
| self._track_replay() | ||
| callback_id: str = executor.process() | ||
| result: Callback = Callback( | ||
| callback_id=callback_id, | ||
| operation_id=operation_id, | ||
| state=self.state, | ||
| serdes=config.serdes, | ||
| ) | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def invoke( | ||
|
|
@@ -478,7 +514,7 @@ def invoke( | |
| """ | ||
| if not config: | ||
| config = InvokeConfig[P, R]() | ||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| executor: InvokeOperationExecutor[R] = InvokeOperationExecutor( | ||
| function_name=function_name, | ||
| payload=payload, | ||
|
|
@@ -490,8 +526,8 @@ def invoke( | |
| ), | ||
| config=config, | ||
| ) | ||
| self._track_replay() | ||
| result: R = executor.process() | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def map( | ||
|
|
@@ -504,7 +540,7 @@ def map( | |
| """Execute a callable for each item in parallel.""" | ||
| map_name: str | None = self._resolve_step_name(name, func) | ||
|
|
||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| operation_identifier = OperationIdentifier( | ||
| operation_id=operation_id, | ||
| parent_id=self._parent_id, | ||
|
|
@@ -526,6 +562,7 @@ def map_in_child_context() -> BatchResult[R]: | |
| operation_identifier=operation_identifier, | ||
| ) | ||
|
|
||
| self._track_replay() | ||
| result: BatchResult[R] = child_handler( | ||
| func=map_in_child_context, | ||
| state=self.state, | ||
|
|
@@ -539,7 +576,6 @@ def map_in_child_context() -> BatchResult[R]: | |
| item_serdes=None, | ||
| ), | ||
| ) | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def parallel( | ||
|
|
@@ -550,7 +586,7 @@ def parallel( | |
| ) -> BatchResult[T]: | ||
| """Execute multiple callables in parallel.""" | ||
| # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id | ||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| parallel_context = self.create_child_context(operation_id=operation_id) | ||
| operation_identifier = OperationIdentifier( | ||
| operation_id=operation_id, parent_id=self._parent_id, name=name | ||
|
|
@@ -569,6 +605,7 @@ def parallel_in_child_context() -> BatchResult[T]: | |
| operation_identifier=operation_identifier, | ||
| ) | ||
|
|
||
| self._track_replay() | ||
| result: BatchResult[T] = child_handler( | ||
| func=parallel_in_child_context, | ||
| state=self.state, | ||
|
|
@@ -582,7 +619,6 @@ def parallel_in_child_context() -> BatchResult[T]: | |
| item_serdes=None, | ||
| ), | ||
| ) | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def run_in_child_context( | ||
|
|
@@ -596,18 +632,19 @@ def run_in_child_context( | |
| Use this to nest and group operations. | ||
|
|
||
| Args: | ||
| callable (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it. | ||
| func (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it. | ||
| name (str | None): name for the operation. | ||
| config (ChildConfig | None = None): c | ||
|
|
||
| Returns: | ||
| T: The result of the callable. | ||
| """ | ||
| step_name: str | None = self._resolve_step_name(name, func) | ||
| # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id | ||
| operation_id = self._create_step_id() | ||
|
|
||
| is_virtual: bool = config.is_virtual if config else False | ||
| # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id | ||
| operation_id = self._operation_id_generator.create_step_id( | ||
| is_virtual=is_virtual | ||
| ) | ||
|
|
||
| def callable_with_child_context(): | ||
| return func( | ||
|
|
@@ -616,6 +653,7 @@ def callable_with_child_context(): | |
| ) | ||
| ) | ||
|
|
||
| self._track_replay() | ||
| result: T = child_handler( | ||
| func=callable_with_child_context, | ||
| state=self.state, | ||
|
|
@@ -626,7 +664,6 @@ def callable_with_child_context(): | |
| ), | ||
| config=config, | ||
| ) | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def step( | ||
|
|
@@ -639,7 +676,7 @@ def step( | |
| logger.debug("Step name: %s", step_name) | ||
| if not config: | ||
| config = StepConfig() | ||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| executor: StepOperationExecutor[T] = StepOperationExecutor( | ||
| func=func, | ||
| config=config, | ||
|
|
@@ -651,8 +688,8 @@ def step( | |
| ), | ||
| context_logger=self.logger, | ||
| ) | ||
| self._track_replay() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Same pattern for the other operations. This will result in No customer-observable effect today because the only internal reader of the flag is |
||
| result: T = executor.process() | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
| def wait(self, duration: Duration, name: str | None = None) -> None: | ||
|
|
@@ -666,7 +703,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None: | |
| if seconds < 1: | ||
| msg = "duration must be at least 1 second" | ||
| raise ValidationError(msg) | ||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| wait_seconds = duration.seconds | ||
| executor: WaitOperationExecutor = WaitOperationExecutor( | ||
| seconds=wait_seconds, | ||
|
|
@@ -677,8 +714,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None: | |
| name=name, | ||
| ), | ||
| ) | ||
| self._track_replay() | ||
| executor.process() | ||
| self.state.track_replay(operation_id=operation_id) | ||
|
|
||
| def wait_for_callback( | ||
| self, | ||
|
|
@@ -720,7 +757,7 @@ def wait_for_condition( | |
| msg = "`config` is required for wait_for_condition" | ||
| raise ValidationError(msg) | ||
|
|
||
| operation_id = self._create_step_id() | ||
| operation_id = self._operation_id_generator.create_step_id() | ||
| executor: WaitForConditionOperationExecutor[T] = ( | ||
| WaitForConditionOperationExecutor( | ||
| check=check, | ||
|
|
@@ -734,8 +771,8 @@ def wait_for_condition( | |
| context_logger=self.logger, | ||
| ) | ||
| ) | ||
| self._track_replay() | ||
| result: T = executor.process() | ||
| self.state.track_replay(operation_id=operation_id) | ||
| return result | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to look at
_virtual_operation_counter?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is used to determine the replay status.
_virtual_operation_counteris not used for that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If next operation is a virtual operation then what would
self._operation_counter.get_current() + 1return?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it returns the id of the next "real" operation