|
| 1 | +"""Integration test: empty checkpoint coalescing with concurrent map + wait. |
| 2 | +
|
| 3 | +Python equivalent of the Java MapWithConditionAndCallbackExample referenced in |
| 4 | +issue #325. Verifies that when many concurrent map branches resume from timed |
| 5 | +wait operations simultaneously, the empty checkpoints produced by the |
| 6 | +resubmitter (executor.py) are coalesced into minimal API calls instead of |
| 7 | +being split across multiple batches. |
| 8 | +
|
| 9 | +Background |
| 10 | +---------- |
| 11 | +When a map branch suspends via TimedSuspendExecution and later resumes, the |
| 12 | +ConcurrentExecutor resubmitter calls:: |
| 13 | +
|
| 14 | + execution_state.create_checkpoint() # empty checkpoint |
| 15 | +
|
| 16 | +before resubmitting the branch. In high-concurrency scenarios (300+ branches) |
| 17 | +all resuming at the same time, 300+ empty checkpoints flood the checkpoint |
| 18 | +queue. |
| 19 | +
|
| 20 | +Without the coalescing optimization (issue #325), the 250-operation batch limit |
| 21 | +causes these to be split across multiple batches → multiple API calls. |
| 22 | +With the optimization, all subsequent empty checkpoints beyond the first do |
| 23 | +NOT count toward the batch limit, so they are coalesced into a single batch |
| 24 | +and a single API call. |
| 25 | +
|
| 26 | +These tests directly simulate that concurrent-checkpoint pattern by launching |
| 27 | +many threads that each call ``create_checkpoint()`` simultaneously, mirroring |
| 28 | +what the map resubmitter does when all branches resume at once. |
| 29 | +""" |
| 30 | + |
| 31 | +from __future__ import annotations |
| 32 | + |
| 33 | +import threading |
| 34 | +from concurrent.futures import ThreadPoolExecutor |
| 35 | + |
| 36 | + |
| 37 | +from aws_durable_execution_sdk_python.lambda_service import ( |
| 38 | + CheckpointOutput, |
| 39 | + CheckpointUpdatedExecutionState, |
| 40 | + LambdaClient, |
| 41 | + OperationAction, |
| 42 | + OperationUpdate, |
| 43 | + OperationType, |
| 44 | +) |
| 45 | +from aws_durable_execution_sdk_python.state import ( |
| 46 | + CheckpointBatcherConfig, |
| 47 | + ExecutionState, |
| 48 | + QueuedOperation, |
| 49 | +) |
| 50 | +from aws_durable_execution_sdk_python.threading import CompletionEvent |
| 51 | + |
| 52 | +from unittest.mock import Mock |
| 53 | + |
| 54 | + |
def _make_state(
    mock_client: Mock,
    batch_time: float = 5.0,
    max_ops: int = 250,
) -> ExecutionState:
    """Build an ExecutionState wired to *mock_client* with the given batch limits.

    The byte limit is fixed high (10 MiB) so that only *batch_time* and
    *max_ops* govern batching behavior in these tests.
    """
    batcher_config = CheckpointBatcherConfig(
        max_batch_size_bytes=10 * 1024 * 1024,
        max_batch_time_seconds=batch_time,
        max_batch_operations=max_ops,
    )
    return ExecutionState(
        durable_execution_arn="test-arn",
        initial_checkpoint_token="token-0",  # noqa: S106
        operations={},
        service_client=mock_client,
        batcher_config=batcher_config,
    )
| 72 | + |
| 73 | + |
def _make_tracking_client() -> tuple[Mock, list]:
    """Return a (mock LambdaClient, checkpoint_calls list) pair.

    The mock's ``checkpoint`` records each batch's updates into the returned
    list and hands back a fresh token so the batcher can keep going.
    """
    recorded_batches: list[list] = []
    client = Mock(spec=LambdaClient)

    def _fake_checkpoint(
        durable_execution_arn, checkpoint_token, updates, client_token=None
    ):
        # Snapshot the updates so later mutation by the caller can't alter history.
        recorded_batches.append(list(updates))
        return CheckpointOutput(
            checkpoint_token=f"token_{len(recorded_batches)}",
            new_execution_state=CheckpointUpdatedExecutionState(),
        )

    client.checkpoint = _fake_checkpoint
    return client, recorded_batches
| 90 | + |
| 91 | + |
def test_map_with_concurrent_waits_coalesces_empty_checkpoints():
    """300 concurrent branches all create empty checkpoints simultaneously.

    Simulates the Java MapWithConditionAndCallbackExample scenario: 300 map
    branches all resuming from a wait operation at the same time, each calling
    the resubmitter which enqueues an empty checkpoint.

    Without the coalescing optimization, the 250-op batch limit splits 300
    empty checkpoints into 2 batches (250 + 50) -> 2 API calls.
    With the optimization (effective_operation_count stays 1 for empties),
    all 300 are collected in a single batch -> 1 API call.
    """
    mock_client, calls = _make_tracking_client()
    state = _make_state(mock_client, batch_time=5.0, max_ops=250)

    batcher = ThreadPoolExecutor(max_workers=1)
    batcher.submit(state.checkpoint_batches_forever)

    # 300 branches all call create_checkpoint() concurrently, each blocking
    # until the batch is processed — mirrors the resubmitter pattern.
    branch_count = 300
    start_barrier = threading.Barrier(branch_count)
    errors: list[Exception] = []

    def branch_work():
        try:
            start_barrier.wait()  # all start simultaneously
            state.create_checkpoint()  # empty checkpoint, synchronous
        except Exception as e:  # noqa: BLE001
            errors.append(e)

    # daemon=True so a wedged branch cannot keep the pytest process alive forever.
    threads = [
        threading.Thread(target=branch_work, daemon=True)
        for _ in range(branch_count)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=30)

    try:
        # join(timeout=...) returns silently on timeout, so explicitly detect
        # stragglers — otherwise a hang surfaces only as a confusing assertion
        # failure (or a deadlock in batcher.shutdown below).
        stuck = [t for t in threads if t.is_alive()]
        assert not stuck, f"{len(stuck)} branch threads did not finish within 30s"
        assert not errors, f"Branch errors: {errors}"

        # All 300 empty checkpoints should be batched into 1 API call.
        # Without the fix, 300 > 250 limit would produce 2 calls.
        assert len(calls) == 1, (
            f"Expected 1 coalesced API call for {branch_count} concurrent empty "
            f"checkpoints, got {len(calls)}. The 250-op limit must not split empties."
        )
        assert calls[0] == [], "Empty checkpoints should produce an empty updates list"
    finally:
        state.stop_checkpointing()
        batcher.shutdown(wait=True)
| 142 | + |
| 143 | + |
def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empties():
    """400 empty checkpoints + 10 real ops -> 1 API call with limit=11.

    Demonstrates that the effective batch count is driven by real operations
    (and only the *first* empty), not the total number of empties.

    With limit=11: the first empty counts as effective_op 1, and each of the
    10 real ops increments the count (effective_ops 2-11). The limit is hit
    exactly when the last real op is collected. All 399 remaining empties are
    coalesced in without incrementing the count.

    Result: 1 batch (410 operations, 10 real) -> 1 API call.
    """
    mock_client, calls = _make_tracking_client()
    # limit = 1 (first empty) + 10 (real ops) = 11, so all fit in one batch
    state = _make_state(mock_client, batch_time=5.0, max_ops=11)

    completion_events: list[CompletionEvent] = []
    # Pre-bind to None: if setup inside the try fails before the executor is
    # created, the finally block must not raise NameError and mask the error.
    batcher: ThreadPoolExecutor | None = None

    try:
        # 400 empty checkpoints (simulating concurrent branch resumes)
        for _ in range(400):
            ev = CompletionEvent()
            completion_events.append(ev)
            state._checkpoint_queue.put(QueuedOperation(None, ev))  # noqa: SLF001

        # 10 real operations alongside the empties
        for i in range(10):
            op = OperationUpdate(
                operation_id=f"op_{i}",
                operation_type=OperationType.STEP,
                action=OperationAction.START,
            )

            ev = CompletionEvent()
            completion_events.append(ev)
            state._checkpoint_queue.put(QueuedOperation(op, ev))  # noqa: SLF001

        # Start the batcher only after the queue is fully loaded, so all 410
        # queued operations are available to be collected into one batch.
        batcher = ThreadPoolExecutor(max_workers=1)
        batcher.submit(state.checkpoint_batches_forever)

        # Wait for all 410 to be processed
        for ev in completion_events:
            ev.wait()

        # 1 empty (effective=1) + 10 real ops (effective=11) exhaust the batch
        # limit exactly. The 399 remaining empties coalesce in -> still 1 API call.
        assert len(calls) == 1, (
            f"Expected 1 API call with 400 empty + 10 real ops (limit=11), "
            f"got {len(calls)}."
        )
        # Only the 10 real ops appear in the updates list; empties are excluded.
        real_op_ids = {u.operation_id for batch in calls for u in batch}
        assert real_op_ids == {f"op_{i}" for i in range(10)}
    finally:
        state.stop_checkpointing()
        if batcher is not None:
            batcher.shutdown(wait=True)
0 commit comments