-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathprocessor.py
More file actions
125 lines (101 loc) · 4.53 KB
/
processor.py
File metadata and controls
125 lines (101 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Main checkpoint processor that orchestrates operation transformations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from aws_durable_execution_sdk_python.lambda_service import (
CheckpointOutput,
CheckpointUpdatedExecutionState,
OperationUpdate,
StateOutput,
Operation,
)
from aws_durable_execution_sdk_python_testing.checkpoint.transformer import (
OperationTransformer,
)
from aws_durable_execution_sdk_python_testing.checkpoint.validators.checkpoint import (
CheckpointValidator,
)
from aws_durable_execution_sdk_python_testing.exceptions import (
InvalidParameterValueException,
)
from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier
from aws_durable_execution_sdk_python_testing.token import CheckpointToken
if TYPE_CHECKING:
from aws_durable_execution_sdk_python_testing.execution import Execution
from aws_durable_execution_sdk_python_testing.scheduler import Scheduler
from aws_durable_execution_sdk_python_testing.stores.base import ExecutionStore
class CheckpointProcessor:
    """Handle OperationUpdate transformations and execution state updates.

    Executions are loaded from an ``ExecutionStore``, incoming checkpoint
    updates are validated with ``CheckpointValidator``, converted into
    ``Operation`` objects by an ``OperationTransformer`` and the resulting
    execution is written back to the store.
    """

    def __init__(self, store: ExecutionStore, scheduler: Scheduler):
        """Create a processor bound to ``store`` and ``scheduler``.

        Args:
            store: Persistence layer used to load and update executions.
            scheduler: Scheduler held by the processor (stored on the
                instance; not referenced by the methods in this class).
        """
        self._store = store
        self._scheduler = scheduler
        self._notifier = ExecutionNotifier()
        self._transformer = OperationTransformer()

    def add_execution_observer(self, observer) -> None:
        """Add observer for execution events.

        The observer is registered on this processor's ``ExecutionNotifier``,
        which is handed to the transformer on every processed checkpoint.
        """
        self._notifier.add_observer(observer)

    def process_checkpoint(
        self,
        checkpoint_token: str,
        updates: list[OperationUpdate],
        client_token: str | None,  # noqa: ARG002
    ) -> CheckpointOutput:
        """Process checkpoint updates and return result with updated execution state.

        Args:
            checkpoint_token: Serialized ``CheckpointToken`` identifying the
                execution and the token sequence the caller last received.
            updates: Operation updates to validate and apply.
            client_token: Accepted for API compatibility; currently unused.

        Returns:
            A ``CheckpointOutput`` carrying a freshly issued checkpoint token
            and the execution's navigable operations after the update.

        Raises:
            InvalidParameterValueException: If ``checkpoint_token`` is empty,
                the execution is already complete, or the token's sequence
                does not match the execution's current sequence. Any
                exception raised by ``CheckpointValidator.validate_input``
                propagates unchanged.
        """
        # Same guard as get_execution_state: reject an empty token up front
        # instead of letting CheckpointToken.from_str fail on it.
        if not checkpoint_token:
            msg = "Checkpoint token is required"
            raise InvalidParameterValueException(msg)

        # 1. Get current execution state
        token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
        execution: Execution = self._store.load(token.execution_arn)

        # 2. Validate checkpoint token: only the latest token of a still
        #    running execution may be used to checkpoint.
        if execution.is_complete or token.token_sequence != execution.token_sequence:
            msg = "Invalid checkpoint token"
            raise InvalidParameterValueException(msg)

        # 3. Validate all updates, state transitions are valid, sizes etc.
        CheckpointValidator.validate_input(updates, execution)

        # 4. Transform OperationUpdate -> Operation and schedule future replays
        updated_operations, all_updates = self._transformer.process_updates(
            updates=updates,
            current_operations=execution.operations,
            notifier=self._notifier,
            execution_arn=token.execution_arn,
        )

        # 5. Generate a new checkpoint token and save updated operations
        new_checkpoint_token = execution.get_new_checkpoint_token()
        execution.operations = updated_operations
        execution.updates.extend(all_updates)
        self._store.update(execution)

        # 6. Return checkpoint result (all operations in a single page)
        return CheckpointOutput(
            checkpoint_token=new_checkpoint_token,
            new_execution_state=CheckpointUpdatedExecutionState(
                operations=execution.get_navigable_operations(), next_marker=None
            ),
        )

    def get_execution_state(
        self,
        checkpoint_token: str,
        next_marker: str | None = None,
        max_items: int = 1000,
    ) -> StateOutput:
        """Get current execution state with batched checkpoint token validation and pagination.

        Args:
            checkpoint_token: Serialized ``CheckpointToken`` for the execution.
            next_marker: Opaque marker returned by a previous call; it encodes
                the index of the first operation of the requested page.
                Missing, unparsable or negative markers start at the first
                operation.
            max_items: Maximum number of operations in the returned page;
                must be at least 1.

        Returns:
            A ``StateOutput`` with one page of navigable operations and the
            marker for the next page, or ``None`` when this is the last page.

        Raises:
            InvalidParameterValueException: If ``checkpoint_token`` is empty
                or ``max_items`` is smaller than 1. Errors raised by
                ``execution.validate_checkpoint_token`` propagate unchanged.
        """
        if not checkpoint_token:
            msg = "Checkpoint token is required"
            raise InvalidParameterValueException(msg)
        # A non-positive page size yields empty pages whose next marker
        # points back at the same position, so a client would loop forever.
        if max_items < 1:
            msg = "max_items must be a positive integer"
            raise InvalidParameterValueException(msg)

        token: CheckpointToken = CheckpointToken.from_str(checkpoint_token)
        execution: Execution = self._store.load(token.execution_arn)
        execution.validate_checkpoint_token(checkpoint_token)

        page, next_marker_result = self._paginate(
            execution.get_navigable_operations(), next_marker, max_items
        )
        return StateOutput(operations=page, next_marker=next_marker_result)

    @staticmethod
    def _paginate(
        items: list[Operation],
        next_marker: str | None,
        max_items: int,
    ) -> tuple[list[Operation], str | None]:
        """Return the page of ``items`` selected by ``next_marker`` and the following marker."""
        start_index = 0
        if next_marker:
            try:
                start_index = int(next_marker)
            except ValueError:
                start_index = 0
        # A negative start would make the slice count from the end of the
        # list (and produce a negative next marker); treat it like an
        # unparsable marker instead.
        start_index = max(start_index, 0)

        end_index = start_index + max_items
        next_marker_result = str(end_index) if end_index < len(items) else None
        return items[start_index:end_index], next_marker_result