Skip to content

Commit 74acac0

Browse files
Merge branch 'main' into lint_update
2 parents 6a7d69c + 9277fff commit 74acac0

5 files changed

Lines changed: 496 additions & 33 deletions

File tree

.github/workflows/notify_slack.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
steps:
1515
- name: Send issue notification to Slack
1616
if: github.event_name == 'issues'
17-
uses: slackapi/slack-github-action@v3.0.1
17+
uses: slackapi/slack-github-action@af78098f536edbc4de71162a307590698245be95 # v3.0.1
1818
with:
1919
webhook: ${{ secrets.SLACK_WEBHOOK_URL_ISSUE }}
2020
webhook-type: incoming-webhook
@@ -27,7 +27,7 @@ jobs:
2727
2828
- name: Send pull request notification to Slack
2929
if: github.event_name == 'pull_request_target'
30-
uses: slackapi/slack-github-action@v3.0.1
30+
uses: slackapi/slack-github-action@af78098f536edbc4de71162a307590698245be95 # v3.0.1
3131
with:
3232
webhook: ${{ secrets.SLACK_WEBHOOK_URL_PR }}
3333
webhook-type: incoming-webhook
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# SPDX-FileCopyrightText: 2025-present Amazon.com, Inc. or its affiliates.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
__version__ = "1.3.0"
4+
__version__ = "1.4.0"

src/aws_durable_execution_sdk_python/state.py

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -592,15 +592,21 @@ def checkpoint_batches_forever(self) -> None:
592592
batch: list[QueuedOperation] = self._collect_checkpoint_batch()
593593

594594
if batch:
595-
# Extract OperationUpdates from QueuedOperations for API call
596-
updates: list[OperationUpdate] = [
597-
q.operation_update for q in batch if q.operation_update is not None
598-
]
595+
# Extract OperationUpdates, excluding empty checkpoints from API call
596+
updates: list[OperationUpdate] = []
597+
empty_count = 0
598+
599+
for q in batch:
600+
if q.operation_update is not None:
601+
updates.append(q.operation_update)
602+
else:
603+
empty_count += 1
599604

600605
logger.debug(
601-
"Processing checkpoint batch with %d operations (%d non-empty)",
602-
len(batch),
606+
"Sending %d OperationUpdates out of %d operations, excluding %d empty checkpoints",
603607
len(updates),
608+
len(batch),
609+
empty_count,
604610
)
605611

606612
try:
@@ -687,26 +693,43 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
687693
operation if queues are empty, then collects additional operations within the time
688694
window.
689695
696+
Empty checkpoints (operation_update=None) are coalesced: the first empty checkpoint
697+
counts toward the batch operation limit, but subsequent empty checkpoints do not.
698+
All empty checkpoints remain in the batch so their completion events are signaled.
699+
This avoids unnecessary batches when many concurrent map/parallel branches resume
700+
simultaneously and each queues an empty checkpoint.
701+
690702
Returns:
691703
List of QueuedOperation objects ready for batch processing. Returns empty list
692704
if no operations are available.
693705
"""
694706
batch: list[QueuedOperation] = []
707+
has_empty_checkpoint = False
695708
total_size = 0
709+
effective_operation_count = 0 # Operations that count toward batch limit
696710

697711
# First, drain overflow queue (FIFO order preserved)
698712
try:
699-
while len(batch) < self._batcher_config.max_batch_operations:
713+
while effective_operation_count < self._batcher_config.max_batch_operations:
700714
overflow_op = self._overflow_queue.get_nowait()
701-
op_size = self._calculate_operation_size(overflow_op)
702-
703-
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
704-
# Put back and stop
705-
self._overflow_queue.put(overflow_op)
706-
break
707715

708-
batch.append(overflow_op)
709-
total_size += op_size
716+
if overflow_op.operation_update is None: # Empty checkpoint
717+
batch.append(overflow_op)
718+
if not has_empty_checkpoint:
719+
effective_operation_count += (
720+
1 # First empty counts toward limit
721+
)
722+
has_empty_checkpoint = True
723+
# Subsequent empties don't count toward limit
724+
else:
725+
op_size = self._calculate_operation_size(overflow_op)
726+
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
727+
# Put back and stop
728+
self._overflow_queue.put(overflow_op)
729+
break
730+
batch.append(overflow_op)
731+
total_size += op_size
732+
effective_operation_count += 1
710733
except queue.Empty:
711734
pass
712735

@@ -720,7 +743,13 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
720743
) # Check stop signal every 100ms
721744
self._checkpoint_queue.task_done()
722745
batch.append(first_op)
723-
total_size += self._calculate_operation_size(first_op)
746+
747+
if first_op.operation_update is None:
748+
has_empty_checkpoint = True
749+
else:
750+
total_size += self._calculate_operation_size(first_op)
751+
752+
effective_operation_count = 1
724753
break
725754
except queue.Empty:
726755
continue
@@ -735,7 +764,7 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
735764
# Collect additional operations within the time window
736765
while (
737766
time.time() < batch_deadline
738-
and len(batch) < self._batcher_config.max_batch_operations
767+
and effective_operation_count < self._batcher_config.max_batch_operations
739768
and not self._checkpointing_stopped.is_set()
740769
):
741770
remaining_time = min(
@@ -749,26 +778,39 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
749778
try:
750779
additional_op = self._checkpoint_queue.get(timeout=remaining_time)
751780
self._checkpoint_queue.task_done()
752-
op_size = self._calculate_operation_size(additional_op)
753-
754-
# Check if adding this operation would exceed size limit
755-
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
756-
# Put in overflow queue for next batch
757-
self._overflow_queue.put(additional_op)
758-
logger.debug(
759-
"Batch size limit reached, moving operation to overflow queue"
760-
)
761-
break
762781

763-
batch.append(additional_op)
764-
total_size += op_size
782+
if additional_op.operation_update is None: # Empty checkpoint
783+
batch.append(additional_op)
784+
if not has_empty_checkpoint:
785+
effective_operation_count += (
786+
1 # First empty counts toward limit
787+
)
788+
has_empty_checkpoint = True
789+
# Subsequent empties don't count toward limit
790+
else:
791+
op_size = self._calculate_operation_size(additional_op)
792+
# Check if adding this operation would exceed size limit
793+
if total_size + op_size > self._batcher_config.max_batch_size_bytes:
794+
# Put in overflow queue for next batch
795+
self._overflow_queue.put(additional_op)
796+
logger.debug(
797+
"Batch size limit reached, moving operation to overflow queue"
798+
)
799+
break
800+
batch.append(additional_op)
801+
total_size += op_size
802+
effective_operation_count += 1
765803

766804
except queue.Empty:
767805
break
768806

807+
empty_count = sum(1 for q in batch if q.operation_update is None)
769808
logger.debug(
770-
"Collected batch of %d operations, total size: %d bytes",
809+
"Collected batch of %d operations (%d effective, %d non-empty, %d empty), total size: %d bytes",
771810
len(batch),
811+
effective_operation_count,
812+
len(batch) - empty_count,
813+
empty_count,
772814
total_size,
773815
)
774816
return batch
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""Integration test: empty checkpoint coalescing with concurrent map + wait.
2+
3+
Python equivalent of the Java MapWithConditionAndCallbackExample referenced in
4+
issue #325. Verifies that when many concurrent map branches resume from timed
5+
wait operations simultaneously, the empty checkpoints produced by the
6+
resubmitter (executor.py) are coalesced into minimal API calls instead of
7+
being split across multiple batches.
8+
9+
Background
10+
----------
11+
When a map branch suspends via TimedSuspendExecution and later resumes, the
12+
ConcurrentExecutor resubmitter calls::
13+
14+
execution_state.create_checkpoint() # empty checkpoint
15+
16+
before resubmitting the branch. In high-concurrency scenarios (300+ branches)
17+
all resuming at the same time, 300+ empty checkpoints flood the checkpoint
18+
queue.
19+
20+
Without the coalescing optimization (issue #325), the 250-operation batch limit
21+
causes these to be split across multiple batches → multiple API calls.
22+
With the optimization, all subsequent empty checkpoints beyond the first do
23+
NOT count toward the batch limit, so they are coalesced into a single batch
24+
and a single API call.
25+
26+
These tests directly simulate that concurrent-checkpoint pattern by launching
27+
many threads that each call ``create_checkpoint()`` simultaneously, mirroring
28+
what the map resubmitter does when all branches resume at once.
29+
"""
30+
31+
from __future__ import annotations
32+
33+
import threading
34+
from concurrent.futures import ThreadPoolExecutor
35+
36+
37+
from aws_durable_execution_sdk_python.lambda_service import (
38+
CheckpointOutput,
39+
CheckpointUpdatedExecutionState,
40+
LambdaClient,
41+
OperationAction,
42+
OperationUpdate,
43+
OperationType,
44+
)
45+
from aws_durable_execution_sdk_python.state import (
46+
CheckpointBatcherConfig,
47+
ExecutionState,
48+
QueuedOperation,
49+
)
50+
from aws_durable_execution_sdk_python.threading import CompletionEvent
51+
52+
from unittest.mock import Mock
53+
54+
55+
def _make_state(
    mock_client: Mock,
    batch_time: float = 5.0,
    max_ops: int = 250,
) -> ExecutionState:
    """Construct an ExecutionState backed by *mock_client*.

    The batcher is configured with a generous 10 MiB size limit so that the
    operation-count limit (*max_ops*) and time window (*batch_time*) are the
    only constraints exercised by the tests.
    """
    batcher_config = CheckpointBatcherConfig(
        max_batch_size_bytes=10 * 1024 * 1024,
        max_batch_time_seconds=batch_time,
        max_batch_operations=max_ops,
    )
    state = ExecutionState(
        durable_execution_arn="test-arn",
        initial_checkpoint_token="token-0",  # noqa: S106
        operations={},
        service_client=mock_client,
        batcher_config=batcher_config,
    )
    return state
72+
73+
74+
def _make_tracking_client() -> tuple[Mock, list]:
    """Return a (mock LambdaClient, checkpoint_calls list) pair.

    Every ``checkpoint`` invocation is recorded: the list of OperationUpdates
    it received is appended to the returned ``calls`` list, so tests can count
    API calls and inspect exactly which updates were sent in each batch.
    """
    calls: list[list] = []
    mock_client = Mock(spec=LambdaClient)

    def _record_checkpoint(
        durable_execution_arn, checkpoint_token, updates, client_token=None
    ):
        # Snapshot the updates (the caller may reuse its list) and hand back a
        # fresh token so subsequent batches look like normal service responses.
        calls.append(list(updates))
        return CheckpointOutput(
            checkpoint_token=f"token_{len(calls)}",
            new_execution_state=CheckpointUpdatedExecutionState(),
        )

    mock_client.checkpoint = _record_checkpoint
    return mock_client, calls
90+
91+
92+
def test_map_with_concurrent_waits_coalesces_empty_checkpoints():
    """300 concurrent branches all create empty checkpoints simultaneously.

    Simulates the Java MapWithConditionAndCallbackExample scenario: 300 map
    branches all resuming from a wait operation at the same time, each calling
    the resubmitter which enqueues an empty checkpoint.

    Without the coalescing optimization, the 250-op batch limit splits 300
    empty checkpoints into 2 batches (250 + 50) → 2 API calls.
    With the optimization (effective_operation_count stays 1 for empties),
    all 300 are collected in a single batch → 1 API call.
    """
    mock_client, calls = _make_tracking_client()
    state = _make_state(mock_client, batch_time=5.0, max_ops=250)

    batcher = ThreadPoolExecutor(max_workers=1)
    batcher.submit(state.checkpoint_batches_forever)

    # 300 branches all call create_checkpoint() concurrently, each blocking
    # until the batch is processed — mirrors the resubmitter pattern.
    branch_count = 300
    start_barrier = threading.Barrier(branch_count)
    errors: list[Exception] = []

    def branch_work():
        try:
            start_barrier.wait()  # all start simultaneously
            state.create_checkpoint()  # empty checkpoint, synchronous
        except Exception as e:  # noqa: BLE001
            errors.append(e)

    threads = [threading.Thread(target=branch_work) for _ in range(branch_count)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=30)

    try:
        # join(timeout=...) returns silently on timeout; fail loudly if any
        # branch is still blocked (a deadlock/regression) instead of letting
        # the assertions below report a confusing partial state.
        stuck = [t for t in threads if t.is_alive()]
        assert not stuck, f"{len(stuck)} branch threads still alive after join timeout"
        assert not errors, f"Branch errors: {errors}"

        # All 300 empty checkpoints should be batched into 1 API call.
        # Without the fix, 300 > 250 limit would produce 2 calls.
        assert len(calls) == 1, (
            f"Expected 1 coalesced API call for {branch_count} concurrent empty "
            f"checkpoints, got {len(calls)}. The 250-op limit must not split empties."
        )
        assert calls[0] == [], "Empty checkpoints should produce an empty updates list"
    finally:
        state.stop_checkpointing()
        batcher.shutdown(wait=True)
142+
143+
144+
def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empties():
    """400 empty checkpoints + 10 real ops → 1 API call with limit=11.

    Demonstrates that the effective batch count is driven by real operations
    (and only the *first* empty), not the total number of empties.

    With limit=11: the first empty counts as effective_op 1, and each of the
    10 real ops increments the count (effective_ops 2–11). The limit is hit
    exactly when the last real op is collected. All 399 remaining empties are
    coalesced in without incrementing the count.

    Result: 1 batch (410 operations, 10 real) → 1 API call.
    """
    mock_client, calls = _make_tracking_client()
    # limit = 1 (first empty) + 10 (real ops) = 11, so all fit in one batch
    state = _make_state(mock_client, batch_time=5.0, max_ops=11)

    completion_events: list[CompletionEvent] = []
    # Create the executor *before* the try block so the finally clause can
    # always reference it; creating it inside the try meant any setup failure
    # raised NameError in finally, masking the real error.
    batcher = ThreadPoolExecutor(max_workers=1)

    try:
        # 400 empty checkpoints (simulating concurrent branch resumes)
        for _ in range(400):
            ev = CompletionEvent()
            completion_events.append(ev)
            state._checkpoint_queue.put(QueuedOperation(None, ev))  # noqa: SLF001

        # 10 real operations alongside the empties
        for i in range(10):
            op = OperationUpdate(
                operation_id=f"op_{i}",
                operation_type=OperationType.STEP,
                action=OperationAction.START,
            )

            ev = CompletionEvent()
            completion_events.append(ev)
            state._checkpoint_queue.put(QueuedOperation(op, ev))  # noqa: SLF001

        batcher.submit(state.checkpoint_batches_forever)

        # Wait for all 410 to be processed
        for ev in completion_events:
            ev.wait()

        # 1 empty (effective=1) + 10 real ops (effective=11) exhaust the batch
        # limit exactly. The 399 remaining empties coalesce in → still 1 API call.
        assert len(calls) == 1, (
            f"Expected 1 API call with 400 empty + 10 real ops (limit=11), "
            f"got {len(calls)}."
        )
        # Only the 10 real ops appear in the updates list; empties are excluded.
        real_op_ids = {u.operation_id for batch in calls for u in batch}
        assert real_op_ids == {f"op_{i}" for i in range(10)}
    finally:
        state.stop_checkpointing()
        batcher.shutdown(wait=True)

0 commit comments

Comments
 (0)