-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconftest.py
More file actions
269 lines (214 loc) · 9.15 KB
/
conftest.py
File metadata and controls
269 lines (214 loc) · 9.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""Pytest configuration and fixtures for durable execution tests."""
import contextlib
import json
import logging
import os
import sys
from enum import StrEnum
from pathlib import Path
from typing import Any
import pytest
from aws_durable_execution_sdk_python.lambda_service import (
ErrorObject,
OperationPayload,
)
from aws_durable_execution_sdk_python.serdes import ExtendedTypeSerDes
from aws_durable_execution_sdk_python_testing.runner import (
DurableFunctionCloudTestRunner,
DurableFunctionTestResult,
DurableFunctionTestRunner,
)
logger = logging.getLogger(__name__)

# The example handler modules live in examples/src. Prepend that directory to
# the import path (only if it is not already present) so that tests can
# import the handlers directly.
examples_src = Path(__file__).parent.parent / "src"
if str(examples_src) not in sys.path:
    sys.path.insert(0, str(examples_src))
def deserialize_operation_payload(
    payload: OperationPayload | None, serdes: ExtendedTypeSerDes | None = None
) -> Any:
    """Turn a raw operation payload back into a Python object.

    Test code receives operation results as serialized strings. This helper
    decodes them with either a caller-supplied serializer or a fresh
    ``ExtendedTypeSerDes``. If that serializer raises, the payload is parsed
    as plain JSON instead, for backwards compatibility.

    Args:
        payload: The serialized payload. ``None`` and empty payloads are both
            treated as "no result".
        serdes: Serializer to decode with. A new ``ExtendedTypeSerDes()`` is
            used when omitted.

    Returns:
        The decoded object, or ``None`` when ``payload`` is falsy.

    Raises:
        json.JSONDecodeError: If the serializer fails and the payload is not
            valid JSON either.
    """
    if not payload:
        return None
    decoder = serdes if serdes is not None else ExtendedTypeSerDes()
    try:
        return decoder.deserialize(payload)
    except Exception:
        # Payloads produced without the extended serializer are plain JSON.
        return json.loads(payload)
class RunnerMode(StrEnum):
"""Runner mode for local or cloud execution."""
LOCAL = "local"
CLOUD = "cloud"
def pytest_addoption(parser):
    """Register the ``--runner-mode`` command-line option with pytest.

    ``local`` (the default) runs the handler in memory; ``cloud`` runs it
    against a deployed Lambda function.
    """
    supported_modes = [RunnerMode.LOCAL, RunnerMode.CLOUD]
    parser.addoption(
        "--runner-mode",
        action="store",
        default=RunnerMode.LOCAL,
        choices=supported_modes,
        help="Test runner mode: local (in-memory) or cloud (deployed Lambda)",
    )
class TestRunnerAdapter:
    """Adapter that provides consistent interface for both local and cloud runners.

    This adapter encapsulates the differences between local and cloud test runners:
    - Local runner: Requires context manager for resource cleanup (scheduler thread)
    - Cloud runner: No resource cleanup needed (stateless boto3 client)

    The adapter ensures proper resource management while providing a unified interface.

    The adapter may be entered more than once in a nested fashion. For example,
    the ``durable_runner`` fixture enters it, and a test may enter it again with
    ``with durable_runner:``. The wrapped runner's ``__enter__`` is called only
    on the outermost entry, and its ``__exit__`` only on the matching outermost
    exit, so the runner is never entered or torn down twice.
    """

    # The class name starts with "Test". Without this flag, pytest would try
    # to collect it as a test class (and warn, since it defines __init__) in
    # any test module that imports it.
    __test__ = False

    def __init__(
        self,
        runner: DurableFunctionTestRunner | DurableFunctionCloudTestRunner,
        mode: str,
    ):
        """Initialize the adapter.

        Args:
            runner: The local or cloud runner that all calls are delegated to.
            mode: The runner mode string (e.g. ``"local"`` or ``"cloud"``),
                exposed through the ``mode`` property.
        """
        self._runner: DurableFunctionTestRunner | DurableFunctionCloudTestRunner = (
            runner
        )
        self._mode: str = mode
        # Number of ``with`` blocks on this adapter that are currently open.
        self._enter_depth: int = 0

    def run(
        self,
        input: str | None = None,  # noqa: A002
        timeout: int = 60,
        skip_time: bool = False,
    ) -> DurableFunctionTestResult:
        """Execute the durable function and return results."""
        return self._runner.run(input=input, timeout=timeout, skip_time=skip_time)

    def run_async(
        self,
        input: str | None = None,  # noqa: A002
        timeout: int = 60,
    ) -> str:
        """Start an execution via the runner's ``run_async`` and return its id.

        NOTE(review): the returned string is presumably the execution ARN that
        ``wait_for_result`` / ``wait_for_callback`` accept; confirm against
        the runner implementation.
        """
        return self._runner.run_async(input=input, timeout=timeout)

    def send_callback_success(
        self, callback_id: str, result: bytes | None = None
    ) -> None:
        """Report success for ``callback_id``, with an optional ``result`` payload."""
        self._runner.send_callback_success(callback_id=callback_id, result=result)

    def send_callback_failure(
        self, callback_id: str, error: ErrorObject | None = None
    ) -> None:
        """Report failure for ``callback_id``, with an optional ``error``."""
        self._runner.send_callback_failure(callback_id=callback_id, error=error)

    def send_callback_heartbeat(self, callback_id: str) -> None:
        """Send a heartbeat for ``callback_id``."""
        self._runner.send_callback_heartbeat(callback_id=callback_id)

    def wait_for_result(
        self, execution_arn: str, timeout: int = 60
    ) -> DurableFunctionTestResult:
        """Wait for the execution ``execution_arn`` and return its result.

        ``timeout`` is passed straight to the runner (units presumably
        seconds; confirm against the runner).
        """
        return self._runner.wait_for_result(
            execution_arn=execution_arn, timeout=timeout
        )

    def wait_for_callback(
        self, execution_arn: str, name: str | None = None, timeout: int = 60
    ) -> str:
        """Wait for a callback in ``execution_arn``, optionally filtered by ``name``.

        NOTE(review): the returned string is presumably the callback id that
        the ``send_callback_*`` methods accept; confirm against the runner.
        """
        return self._runner.wait_for_callback(
            execution_arn=execution_arn, name=name, timeout=timeout
        )

    @property
    def mode(self) -> str:
        """Get the runner mode (local or cloud)."""
        return self._mode

    def __enter__(self):
        """Enter the wrapped runner's context, but only on the outermost entry.

        Runners that are not context managers are never entered.
        """
        if self._enter_depth == 0 and isinstance(
            self._runner, contextlib.AbstractContextManager
        ):
            self._runner.__enter__()
        # Count the entry only after the runner entered successfully, so a
        # failing runner __enter__ leaves the adapter in its un-entered state.
        self._enter_depth += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped runner's context when the outermost block closes.

        Returns:
            The runner's ``__exit__`` result on the outermost exit. ``None``
            (i.e. do not suppress the exception) for inner exits and for
            runners that are not context managers.
        """
        if self._enter_depth > 1:
            # An enclosing ``with`` still owns the runner; keep it open.
            self._enter_depth -= 1
            return None
        self._enter_depth = 0
        if isinstance(self._runner, contextlib.AbstractContextManager):
            return self._runner.__exit__(exc_type, exc_val, exc_tb)
        return None
@pytest.fixture
def durable_runner(request):
    r"""Pytest fixture that provides a local or cloud test runner.

    The test must be marked with ``@pytest.mark.durable_execution``. The
    marker's ``handler`` kwarg is required in local mode, and its
    ``lambda_function_name`` kwarg is required in cloud mode.

    Mode selection (CLI option):
        --runner-mode=local   (default) run the handler in-memory
        --runner-mode=cloud   invoke a deployed Lambda function

    Environment variables (cloud mode only):
        QUALIFIED_FUNCTION_NAME (required): qualified name of the deployed
            function to invoke, e.g. "HelloWorld:$LATEST".
        LAMBDA_FUNCTION_TEST_NAME (required): the ``lambda_function_name``
            that this deployment serves. Tests whose marker names a different
            function (compared case-insensitively) are skipped.
        AWS_REGION (optional): AWS region for Lambda invocation. Defaults to
            us-west-2.
        LAMBDA_ENDPOINT (optional): custom Lambda endpoint URL.

    Example:
        AWS_REGION=us-west-2 \
        QUALIFIED_FUNCTION_NAME='HelloWorld:$LATEST' \
        LAMBDA_FUNCTION_TEST_NAME='hello world' \
        pytest --runner-mode=cloud -k test_hello_world

    Usage in tests:
        @pytest.mark.durable_execution(
            handler=hello_world.handler,
            lambda_function_name="hello world"
        )
        def test_hello_world(durable_runner):
            with durable_runner:
                result = durable_runner.run(input="test", timeout=10)
            assert result.status == InvocationStatus.SUCCEEDED

    Yields:
        A ``TestRunnerAdapter`` wrapping the selected runner. The fixture
        enters it as a context manager for the duration of the test and
        exits it at teardown.
    """
    # Get marker with test configuration
    marker = request.node.get_closest_marker("durable_execution")
    if not marker:
        pytest.fail("Test must be marked with @pytest.mark.durable_execution")

    handler: Any = marker.kwargs.get("handler")
    lambda_function_name: str | None = marker.kwargs.get("lambda_function_name")

    # Get runner mode from CLI option
    runner_mode: str = request.config.getoption("--runner-mode")
    logger.info("Running test in %s mode", runner_mode.upper())

    # Create appropriate runner
    if runner_mode == RunnerMode.CLOUD:
        # Resolve the deployed function. This fails or skips the test when the
        # environment is missing or targets a different function.
        deployed_name = _get_deployed_function_name(request, lambda_function_name)
        region = os.environ.get("AWS_REGION", "us-west-2")
        lambda_endpoint = os.environ.get("LAMBDA_ENDPOINT")
        logger.info("Using AWS region: %s", region)

        # Create cloud runner (no cleanup needed)
        runner = DurableFunctionCloudTestRunner(
            function_name=deployed_name,
            region=region,
            lambda_endpoint=lambda_endpoint,
        )
    else:
        if not handler:
            pytest.fail("handler is required for local mode tests")

        # Create local runner (needs cleanup via context manager)
        runner = DurableFunctionTestRunner(handler=handler)

    # Wrap in adapter and use context manager for proper cleanup
    with TestRunnerAdapter(runner, runner_mode) as adapter:
        yield adapter
def _get_deployed_function_name(
    request: pytest.FixtureRequest,
    lambda_function_name: str | None,
) -> str:
    """Resolve which deployed function a cloud-mode test should invoke.

    Reads two environment variables:
      - QUALIFIED_FUNCTION_NAME: qualified name of the deployed function,
        e.g. "MyFunction:$LATEST". It is returned unchanged on a match.
      - LAMBDA_FUNCTION_TEST_NAME: the ``lambda_function_name`` that this
        deployment serves, matched against the test's marker.

    Args:
        request: The requesting fixture context. Currently unused; it is kept
            so the call signature stays stable.
        lambda_function_name: The ``lambda_function_name`` from the test's
            ``durable_execution`` marker.

    Returns:
        The value of QUALIFIED_FUNCTION_NAME when ``lambda_function_name``
        matches LAMBDA_FUNCTION_TEST_NAME, ignoring case.

    The test is failed via ``pytest.fail`` when the marker name or either
    environment variable is missing. It is skipped via ``pytest.skip`` when
    the names do not match, so only the test for the deployed function runs.
    """
    if not lambda_function_name:
        pytest.fail("lambda_function_name is required for cloud mode tests")

    qualified_name = os.environ.get("QUALIFIED_FUNCTION_NAME")
    env_function_name = os.environ.get("LAMBDA_FUNCTION_TEST_NAME")

    if not qualified_name or not env_function_name:
        pytest.fail(
            "Cloud mode requires both QUALIFIED_FUNCTION_NAME and LAMBDA_FUNCTION_TEST_NAME environment variables\n"
            'Example: QUALIFIED_FUNCTION_NAME="MyFunction:$LATEST" LAMBDA_FUNCTION_TEST_NAME="hello world" pytest --runner-mode=cloud'
        )

    # casefold() is the recommended form for caseless string matching.
    if lambda_function_name.casefold() == env_function_name.casefold():
        logger.info(
            "Using function ARN: %s for lambda function: %s",
            qualified_name,
            env_function_name,
        )
        return qualified_name

    # This test targets a different function than the one deployed; skip it.
    pytest.skip(
        f"Test '{lambda_function_name}' doesn't match LAMBDA_FUNCTION_TEST_NAME '{env_function_name}'"
    )