Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/aws_durable_execution_sdk_python/concurrency/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,13 @@ def _execute_item_in_child_context(
and execution-order invariant.
"""

operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001
executable.index
is_virtual: bool = self.nesting_type is NestingType.FLAT
operation_id: str = (
executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001
executable.index, is_virtual=is_virtual
)
)
name: str = self.get_iteration_name(executable.index)
is_virtual: bool = self.nesting_type is NestingType.FLAT

child_context: DurableContext = executor_context.create_child_context(
operation_id, is_virtual=is_virtual
Expand Down Expand Up @@ -447,7 +449,6 @@ def run_in_child_handler() -> ResultType:
is_virtual=is_virtual,
),
)
child_context.state.track_replay(operation_id=operation_id)
return result

def replay(self, execution_state: ExecutionState, executor_context: DurableContext):
Expand All @@ -458,10 +459,11 @@ def replay(self, execution_state: ExecutionState, executor_context: DurableConte
This will pre-generate all the operation ids for the children and collect the checkpointed
results.
"""
is_virtual: bool = self.nesting_type is NestingType.FLAT
items: list[BatchItem[ResultType]] = []
for executable in self.executables:
operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001
executable.index
operation_id = executor_context._operation_id_generator.create_step_id_for_logical_step( # noqa: SLF001
executable.index, is_virtual=is_virtual
)
checkpoint = execution_state.get_checkpoint_result(operation_id)

Expand Down
131 changes: 84 additions & 47 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
SerDes,
deserialize,
)
from aws_durable_execution_sdk_python.state import ExecutionState # noqa: TCH001
from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus # noqa: TCH001
from aws_durable_execution_sdk_python.threading import OrderedCounter
from aws_durable_execution_sdk_python.types import Callback as CallbackProtocol
from aws_durable_execution_sdk_python.types import (
Expand Down Expand Up @@ -277,6 +277,42 @@ def result(self) -> T | None:
raise SuspendExecution(msg)


class OperationIdGenerator:
    """Generate deterministic, hashed operation ids for a single context.

    Real and virtual operations are numbered by two independent counters.
    Before hashing, the ids produced in one context therefore look like
    ``v-1``, ``v-2``, ``1``, ``2``, ``v-3`` and, under a prefix of ``3``,
    ``3-v-1``. Virtual ids do not encode their position relative to the real
    operations around them.

    How ids are built is internal to the SDK: callers only ever see the
    resulting hash, and the pre-hash format may change without notice.
    """

    def __init__(self, step_id_prefix: str | None, parent_id: str | None) -> None:
        """Create the counters and resolve the prefix used for generated ids.

        Args:
            step_id_prefix: Prefix for every id generated here. When ``None``,
                ``parent_id`` is used instead.
            parent_id: Fallback prefix when ``step_id_prefix`` is ``None``.
        """
        self._operation_counter: OrderedCounter = OrderedCounter()
        self._virtual_operation_counter: OrderedCounter = OrderedCounter()
        # child operations use this to generate deterministic step ids.
        # differs from `parent_id` only for virtual contexts.
        self._step_id_prefix: str | None = (
            step_id_prefix if step_id_prefix is not None else parent_id
        )

    def peek_next_step_id(self) -> str:
        """Return the id the next real operation will get, without consuming it.

        Only the real-operation counter is consulted. This method is used to
        determine replay status, which the virtual counter plays no part in;
        if the next operation turns out to be virtual, the value returned here
        is still the id of the next *real* operation.
        """
        next_step = self._operation_counter.get_current() + 1
        return self.create_step_id_for_logical_step(next_step, is_virtual=False)

    def create_step_id(self, is_virtual: bool = False) -> str:
        """Generate a thread-safe step id, incrementing in order of invocation.

        This method is an internal implementation detail. Do not rely on the
        exact format of the id generated by this method. It is subject to
        change without notice.

        Args:
            is_virtual: When true, the virtual-operation counter is advanced
                and a virtual id is produced; otherwise the real-operation
                counter is advanced.

        Returns:
            The hashed id for the newly allocated step.
        """
        counter = (
            self._virtual_operation_counter if is_virtual else self._operation_counter
        )
        new_counter: int = counter.increment()
        return self.create_step_id_for_logical_step(new_counter, is_virtual=is_virtual)

    def create_step_id_for_logical_step(self, step: int, is_virtual: bool) -> str:
        """
        Generate a step_id based on the given logical step.
        This allows us to recover operation ids or even look
        forward without changing the internal state of this context.

        Args:
            step: Logical step number. ``0`` is a valid step and is always
                part of the id.
            is_virtual: When true, a ``v`` marker is inserted before the step
                number.

        Returns:
            The first 64 hex characters of the BLAKE2b digest of
            ``[<prefix>-][v-]<step>``.
        """
        parts: list[str] = []
        # An empty or missing prefix contributes nothing to the id.
        if self._step_id_prefix:
            parts.append(self._step_id_prefix)
        if is_virtual:
            parts.append("v")
        # The step is appended unconditionally: filtering on truthiness would
        # silently drop step 0 and produce an id without a step component.
        parts.append(str(step))
        step_id: str = "-".join(parts)
        return hashlib.blake2b(step_id.encode()).hexdigest()[:64]


class DurableContext(DurableContextProtocol):
def __init__(
self,
Expand All @@ -286,29 +322,30 @@ def __init__(
parent_id: str | None = None,
logger: Logger | None = None,
step_id_prefix: str | None = None,
replay_status: ReplayStatus = ReplayStatus.REPLAY,
) -> None:
self.state: ExecutionState = state
self.execution_context: ExecutionContext = execution_context
self.lambda_context = lambda_context
# operations inside this context use this id as their parent
self._parent_id: str | None = parent_id
# child operations use this to generate deterministic step ids.
# differs from `parent_id` only for virtual contexts.
self._step_id_prefix: str | None = (
step_id_prefix if step_id_prefix is not None else parent_id
self._is_virtual: bool = (
step_id_prefix is not None and parent_id != step_id_prefix
)
# cached at construction to make invariant even if parent/prefix mutates.
self._is_virtual: bool = self._parent_id != self._step_id_prefix
self._step_counter: OrderedCounter = OrderedCounter()
self._operation_id_generator: OperationIdGenerator = OperationIdGenerator(
step_id_prefix, parent_id
)
self._replay_status: ReplayStatus = replay_status
self._track_replay()

log_info = LogInfo(
execution_state=state,
parent_id=parent_id,
)
self._log_info = log_info
self.logger: Logger = logger or Logger.from_log_info(
logger=logging.getLogger(),
info=log_info,
context=self,
)

@property
Expand All @@ -323,6 +360,11 @@ def is_virtual(self) -> bool:
"""
return self._is_virtual

@property
def is_replaying(self) -> bool:
    """Report whether this context's replay status is still REPLAY."""
    current_status = self._replay_status
    return current_status is ReplayStatus.REPLAY

# region factories
@staticmethod
def from_lambda_context(
Expand Down Expand Up @@ -371,9 +413,9 @@ def create_child_context(
lambda_context=self.lambda_context,
parent_id=child_parent_id,
step_id_prefix=operation_id,
replay_status=self._replay_status,
logger=self.logger.with_log_info(
LogInfo(
execution_state=self.state,
parent_id=child_parent_id,
)
),
Expand All @@ -396,26 +438,20 @@ def set_logger(self, new_logger: LoggerInterface):
self.logger = Logger.from_log_info(
logger=new_logger,
info=self._log_info,
context=self,
)

def _create_step_id_for_logical_step(self, step: int) -> str:
"""
Generate a step_id based on the given logical step.
This allows us to recover operation ids or even look
forward without changing the internal state of this context.
"""
prefix: str | None = self._step_id_prefix
step_id: str = f"{prefix}-{step}" if prefix else str(step)
return hashlib.blake2b(step_id.encode()).hexdigest()[:64]

def _create_step_id(self) -> str:
"""Generate a thread-safe step id, incrementing in order of invocation.

This method is an internal implementation detail. Do not rely the exact format of
the id generated by this method. It is subject to change without notice.
"""
new_counter: int = self._step_counter.increment()
return self._create_step_id_for_logical_step(new_counter)
def _track_replay(self) -> None:
    """Move this context, and the execution, to NEW once no checkpoint is left to replay.

    A context that is already NEW is left untouched. Otherwise the checkpoint
    for the next real operation id is looked up; when it does not exist, the
    context's replay status becomes NEW and the execution state is asked to
    transition as well.
    """
    still_replaying = self._replay_status is not ReplayStatus.NEW
    if still_replaying:
        # look ahead to the next operation without consuming its id
        upcoming_id = self._operation_id_generator.peek_next_step_id()
        checkpoint = self.state.get_checkpoint_result(upcoming_id)
        if not checkpoint.is_existent():
            # nothing recorded for the next operation: this context is done replaying
            self._replay_status = ReplayStatus.NEW
            # let the execution-level state make the same transition
            self.state.transition_replay_status()
Comment thread
zhongkechen marked this conversation as resolved.

# region Operations

Expand All @@ -438,7 +474,7 @@ def create_callback(
"""
if not config:
config = CallbackConfig()
operation_id: str = self._create_step_id()
operation_id: str = self._operation_id_generator.create_step_id()
executor: CallbackOperationExecutor = CallbackOperationExecutor(
state=self.state,
operation_identifier=OperationIdentifier(
Expand All @@ -448,14 +484,14 @@ def create_callback(
),
config=config,
)
self._track_replay()
callback_id: str = executor.process()
result: Callback = Callback(
callback_id=callback_id,
operation_id=operation_id,
state=self.state,
serdes=config.serdes,
)
self.state.track_replay(operation_id=operation_id)
return result

def invoke(
Expand All @@ -478,7 +514,7 @@ def invoke(
"""
if not config:
config = InvokeConfig[P, R]()
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: InvokeOperationExecutor[R] = InvokeOperationExecutor(
function_name=function_name,
payload=payload,
Expand All @@ -490,8 +526,8 @@ def invoke(
),
config=config,
)
self._track_replay()
result: R = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def map(
Expand All @@ -504,7 +540,7 @@ def map(
"""Execute a callable for each item in parallel."""
map_name: str | None = self._resolve_step_name(name, func)

operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
operation_identifier = OperationIdentifier(
operation_id=operation_id,
parent_id=self._parent_id,
Expand All @@ -526,6 +562,7 @@ def map_in_child_context() -> BatchResult[R]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[R] = child_handler(
func=map_in_child_context,
state=self.state,
Expand All @@ -539,7 +576,6 @@ def map_in_child_context() -> BatchResult[R]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def parallel(
Expand All @@ -550,7 +586,7 @@ def parallel(
) -> BatchResult[T]:
"""Execute multiple callables in parallel."""
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
parallel_context = self.create_child_context(operation_id=operation_id)
operation_identifier = OperationIdentifier(
operation_id=operation_id, parent_id=self._parent_id, name=name
Expand All @@ -569,6 +605,7 @@ def parallel_in_child_context() -> BatchResult[T]:
operation_identifier=operation_identifier,
)

self._track_replay()
result: BatchResult[T] = child_handler(
func=parallel_in_child_context,
state=self.state,
Expand All @@ -582,7 +619,6 @@ def parallel_in_child_context() -> BatchResult[T]:
item_serdes=None,
),
)
self.state.track_replay(operation_id=operation_id)
return result

def run_in_child_context(
Expand All @@ -596,18 +632,19 @@ def run_in_child_context(
Use this to nest and group operations.

Args:
callable (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it.
func (Callable[[DurableContext], T]): Run this callable and pass the child context as the argument to it.
name (str | None): name for the operation.
config (ChildConfig | None = None): c

Returns:
T: The result of the callable.
"""
step_name: str | None = self._resolve_step_name(name, func)
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._create_step_id()

is_virtual: bool = config.is_virtual if config else False
# _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id
operation_id = self._operation_id_generator.create_step_id(
is_virtual=is_virtual
)

def callable_with_child_context():
return func(
Expand All @@ -616,6 +653,7 @@ def callable_with_child_context():
)
)

self._track_replay()
result: T = child_handler(
func=callable_with_child_context,
state=self.state,
Expand All @@ -626,7 +664,6 @@ def callable_with_child_context():
),
config=config,
)
self.state.track_replay(operation_id=operation_id)
return result

def step(
Expand All @@ -639,7 +676,7 @@ def step(
logger.debug("Step name: %s", step_name)
if not config:
config = StepConfig()
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: StepOperationExecutor[T] = StepOperationExecutor(
func=func,
config=config,
Expand All @@ -651,8 +688,8 @@ def step(
),
context_logger=self.logger,
)
self._track_replay()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`_track_replay()` runs after `create_step_id()` increments the counter, so `peek_next_step_id` resolves to N+1 instead of N. The check therefore asks "is the op after the current one checkpointed?" rather than "is the current one?". JS does the inverse order on purpose (`checkAndUpdateReplayMode` runs before `createStepId`).

Same pattern for the other operations.

This will result in is_replaying being incorrect during the last replayed op's processing window.

No customer-observable effect today because the only internal reader of the flag is Logger._should_log.

result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result

def wait(self, duration: Duration, name: str | None = None) -> None:
Expand All @@ -666,7 +703,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
if seconds < 1:
msg = "duration must be at least 1 second"
raise ValidationError(msg)
operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
wait_seconds = duration.seconds
executor: WaitOperationExecutor = WaitOperationExecutor(
seconds=wait_seconds,
Expand All @@ -677,8 +714,8 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
name=name,
),
)
self._track_replay()
executor.process()
self.state.track_replay(operation_id=operation_id)

def wait_for_callback(
self,
Expand Down Expand Up @@ -720,7 +757,7 @@ def wait_for_condition(
msg = "`config` is required for wait_for_condition"
raise ValidationError(msg)

operation_id = self._create_step_id()
operation_id = self._operation_id_generator.create_step_id()
executor: WaitForConditionOperationExecutor[T] = (
WaitForConditionOperationExecutor(
check=check,
Expand All @@ -734,8 +771,8 @@ def wait_for_condition(
context_logger=self.logger,
)
)
self._track_replay()
result: T = executor.process()
self.state.track_replay(operation_id=operation_id)
return result


Expand Down
Loading