-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathconfig_test.py
More file actions
290 lines (217 loc) · 8.84 KB
/
config_test.py
File metadata and controls
290 lines (217 loc) · 8.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""Unit tests for config module."""
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
from aws_durable_execution_sdk_python.config import (
BatchedInput,
CallbackConfig,
CheckpointMode,
ChildConfig,
CompletionConfig,
Duration,
InvokeConfig,
ItemBatcher,
ItemsPerBatchUnit,
MapConfig,
ParallelConfig,
StepConfig,
StepFuture,
StepSemantics,
TerminationMode,
)
from aws_durable_execution_sdk_python.waits import (
WaitForConditionConfig,
WaitForConditionDecision,
)
def test_batched_input():
    """BatchedInput keeps the batch-level input and the item list it was given."""
    batched = BatchedInput("batch", [1, 2, 3])
    assert batched.items == [1, 2, 3]
    assert batched.batch_input == "batch"
def test_completion_config_defaults():
    """A bare CompletionConfig leaves every completion threshold unset (None)."""
    default_config = CompletionConfig()
    assert default_config.tolerated_failure_percentage is None
    assert default_config.tolerated_failure_count is None
    assert default_config.min_successful is None
def test_completion_config_first_completed():
    """Test CompletionConfig.first_completed factory method.

    The ``first_completed`` factory is currently commented out, so this test
    is skipped explicitly instead of passing vacuously with an empty body.
    Restore real assertions once the factory is re-enabled.
    """
    import unittest

    # pytest converts unittest.SkipTest into a reported skip, so the missing
    # coverage is visible in test output without importing pytest here.
    raise unittest.SkipTest(
        "CompletionConfig.first_completed is commented out; nothing to test yet"
    )
def test_completion_config_first_successful():
    """first_successful() needs one success and leaves failure tolerances unset."""
    cfg = CompletionConfig.first_successful()
    assert cfg.tolerated_failure_percentage is None
    assert cfg.tolerated_failure_count is None
    assert cfg.min_successful == 1
def test_completion_config_all_completed():
    """all_completed() sets no success minimum and no failure tolerances."""
    cfg = CompletionConfig.all_completed()
    assert cfg.tolerated_failure_percentage is None
    assert cfg.tolerated_failure_count is None
    assert cfg.min_successful is None
def test_completion_config_all_successful():
    """all_successful() tolerates zero failures, by count and by percentage."""
    cfg = CompletionConfig.all_successful()
    assert cfg.tolerated_failure_percentage == 0
    assert cfg.tolerated_failure_count == 0
    assert cfg.min_successful is None
def test_termination_mode_enum():
    """Each TerminationMode member's value is its name spelled as a string."""
    expectations = [
        (TerminationMode.TERMINATE, "TERMINATE"),
        (TerminationMode.CANCEL, "CANCEL"),
        (TerminationMode.WAIT, "WAIT"),
        (TerminationMode.ABANDON, "ABANDON"),
    ]
    for member, expected_value in expectations:
        assert member.value == expected_value
def test_parallel_config_defaults():
    """ParallelConfig defaults: no concurrency cap and a CompletionConfig instance."""
    parallel = ParallelConfig()
    assert isinstance(parallel.completion_config, CompletionConfig)
    assert parallel.max_concurrency is None
def test_wait_for_condition_decision_continue():
    """continue_waiting() keeps polling and exposes the delay in whole seconds."""
    thirty_seconds = Duration.from_seconds(30)
    outcome = WaitForConditionDecision.continue_waiting(thirty_seconds)
    assert outcome.delay_seconds == 30
    assert outcome.should_continue is True
def test_wait_for_condition_decision_stop():
    """stop_polling() ends the wait and reports a zero delay."""
    outcome = WaitForConditionDecision.stop_polling()
    assert outcome.delay_seconds == 0
    assert outcome.should_continue is False
def test_wait_for_condition_config():
    """WaitForConditionConfig stores the strategy, initial state and serdes it is given."""

    def poll_every_ten_seconds(state, attempt):
        # Always asks to keep waiting; only its identity matters for this test.
        return WaitForConditionDecision.continue_waiting(Duration.from_seconds(10))

    fake_serdes = Mock()
    wait_config = WaitForConditionConfig(
        wait_strategy=poll_every_ten_seconds,
        initial_state="test_state",
        serdes=fake_serdes,
    )
    assert wait_config.serdes is fake_serdes
    assert wait_config.initial_state == "test_state"
    assert wait_config.wait_strategy is poll_every_ten_seconds
def test_step_semantics_enum():
    """StepSemantics members carry their own names as string values."""
    for member, expected_value in (
        (StepSemantics.AT_MOST_ONCE_PER_RETRY, "AT_MOST_ONCE_PER_RETRY"),
        (StepSemantics.AT_LEAST_ONCE_PER_RETRY, "AT_LEAST_ONCE_PER_RETRY"),
    ):
        assert member.value == expected_value
def test_step_config_defaults():
    """StepConfig defaults to at-least-once semantics with no retry strategy or serdes."""
    step_cfg = StepConfig()
    assert step_cfg.serdes is None
    assert step_cfg.step_semantics == StepSemantics.AT_LEAST_ONCE_PER_RETRY
    assert step_cfg.retry_strategy is None
def test_step_config_with_values():
    """StepConfig keeps the exact retry strategy, semantics and serdes passed in."""
    fake_retry = Mock()
    fake_serdes = Mock()
    semantics = StepSemantics.AT_MOST_ONCE_PER_RETRY
    step_cfg = StepConfig(
        retry_strategy=fake_retry,
        step_semantics=semantics,
        serdes=fake_serdes,
    )
    assert step_cfg.serdes is fake_serdes
    assert step_cfg.step_semantics == StepSemantics.AT_MOST_ONCE_PER_RETRY
    assert step_cfg.retry_strategy is fake_retry
def test_checkpoint_mode_enum():
    """Pin the current values of the CheckpointMode members."""
    # NOTE(review): two members are expected to have one-element *tuple*
    # values while CHECKPOINT_AT_START_AND_FINISH is a plain string. This
    # looks like stray trailing commas in the enum definition
    # (``NAME = "NAME",``) -- confirm in the config module. If that is fixed
    # there, these expectations must become plain strings as well.
    start_and_finish = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
    assert start_and_finish.value == "CHECKPOINT_AT_START_AND_FINISH"
    assert CheckpointMode.CHECKPOINT_AT_FINISH.value == ("CHECKPOINT_AT_FINISH",)
    assert CheckpointMode.NO_CHECKPOINT.value == ("NO_CHECKPOINT",)
def test_child_config_defaults():
    """A bare ChildConfig has neither a sub_type nor a serdes."""
    child_cfg = ChildConfig()
    assert child_cfg.sub_type is None
    assert child_cfg.serdes is None
def test_child_config_with_serdes():
    """Passing only serdes to ChildConfig leaves sub_type unset."""
    fake_serdes = Mock()
    child_cfg = ChildConfig(serdes=fake_serdes)
    assert child_cfg.sub_type is None
    assert child_cfg.serdes is fake_serdes
def test_child_config_with_sub_type():
    """Passing only sub_type to ChildConfig leaves serdes unset."""
    fake_sub_type = Mock()
    child_cfg = ChildConfig(sub_type=fake_sub_type)
    assert child_cfg.sub_type is fake_sub_type
    assert child_cfg.serdes is None
def test_child_config_with_summary_generator():
    """ChildConfig stores a caller-supplied summary_generator and leaves other fields unset."""

    def describe(result):
        return f"Summary of {result}"

    child_cfg = ChildConfig(summary_generator=describe)
    assert child_cfg.summary_generator is describe
    assert child_cfg.sub_type is None
    assert child_cfg.serdes is None
    # Invoke the stored callable to confirm it is usable as-is.
    assert child_cfg.summary_generator("test_data") == "Summary of test_data"
def test_items_per_batch_unit_enum():
    """Pin the current values of the ItemsPerBatchUnit members."""
    # NOTE(review): COUNT is expected to be a one-element tuple while BYTES is
    # a plain string -- likely a stray trailing comma on COUNT in the enum
    # definition. Confirm in the config module; if fixed there, update this
    # expectation to "COUNT".
    assert ItemsPerBatchUnit.BYTES.value == "BYTES"
    assert ItemsPerBatchUnit.COUNT.value == ("COUNT",)
def test_item_batcher_defaults():
    """ItemBatcher defaults: zero item/byte limits and no batch input."""
    default_batcher = ItemBatcher()
    assert default_batcher.batch_input is None
    assert default_batcher.max_item_bytes_per_batch == 0
    assert default_batcher.max_items_per_batch == 0
def test_item_batcher_with_values():
    """ItemBatcher keeps the explicit limits and batch input it is constructed with."""
    custom_batcher = ItemBatcher(
        max_items_per_batch=100,
        max_item_bytes_per_batch=1024,
        batch_input="test_input",
    )
    assert custom_batcher.batch_input == "test_input"
    assert custom_batcher.max_item_bytes_per_batch == 1024
    assert custom_batcher.max_items_per_batch == 100
def test_map_config_defaults():
    """MapConfig defaults: no concurrency cap or serdes, default batcher and completion config."""
    map_cfg = MapConfig()
    assert map_cfg.serdes is None
    assert isinstance(map_cfg.completion_config, CompletionConfig)
    assert isinstance(map_cfg.item_batcher, ItemBatcher)
    assert map_cfg.max_concurrency is None
def test_callback_config_defaults():
    """CallbackConfig defaults: zero-second timeouts and no serdes."""
    callback_cfg = CallbackConfig()
    assert callback_cfg.serdes is None
    assert callback_cfg.heartbeat_timeout_seconds == 0
    assert callback_cfg.timeout_seconds == 0
def test_callback_config_with_values():
    """CallbackConfig exposes Duration timeouts as whole seconds and keeps serdes."""
    fake_serdes = Mock()
    overall_timeout = Duration.from_seconds(30)
    heartbeat = Duration.from_seconds(10)
    callback_cfg = CallbackConfig(
        timeout=overall_timeout,
        heartbeat_timeout=heartbeat,
        serdes=fake_serdes,
    )
    assert callback_cfg.serdes is fake_serdes
    assert callback_cfg.heartbeat_timeout_seconds == 10
    assert callback_cfg.timeout_seconds == 30
def test_step_future():
    """StepFuture built with a name returns the wrapped future's result."""

    def produce():
        return "test_result"

    with ThreadPoolExecutor(max_workers=1) as pool:
        named_future = StepFuture(pool.submit(produce), "test_step")
        value = named_future.result()
    assert value == "test_result"
def test_step_future_with_timeout():
    """StepFuture.result accepts a timeout_seconds argument and still returns the value."""

    def produce():
        return "test_result"

    with ThreadPoolExecutor(max_workers=1) as pool:
        unnamed_future = StepFuture(pool.submit(produce))
        value = unnamed_future.result(timeout_seconds=1)
    assert value == "test_result"
def test_step_future_without_name():
    """StepFuture works without a name and passes through non-string results."""

    def produce():
        return 42

    with ThreadPoolExecutor(max_workers=1) as pool:
        unnamed_future = StepFuture(pool.submit(produce))
        value = unnamed_future.result()
    assert value == 42
def test_invoke_config_defaults():
    """A bare InvokeConfig carries no tenant_id."""
    assert InvokeConfig().tenant_id is None
def test_invoke_config_with_tenant_id():
    """InvokeConfig stores an explicitly supplied tenant_id unchanged."""
    tenant = "test-tenant"
    invoke_cfg = InvokeConfig(tenant_id=tenant)
    assert invoke_cfg.tenant_id == "test-tenant"