Skip to content

Commit 864a41a

Browse files
authored
Return resolved bundle metadata from the simulation gateway (#435)
* Return bundle metadata from simulation gateway
* Resolve gateway dataset aliases in bundle metadata
* Preserve unresolved gateway dataset labels
1 parent e6c3378 commit 864a41a

5 files changed

Lines changed: 290 additions & 6 deletions

File tree

projects/policyengine-api-simulation/src/modal/gateway/endpoints.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,57 @@
1414
JobSubmitResponse,
1515
PingRequest,
1616
PingResponse,
17+
PolicyEngineBundle,
1718
SimulationRequest,
1819
)
1920

2021
logger = logging.getLogger(__name__)
2122

2223
router = APIRouter()
24+
# Name of the Modal Dict in which the gateway stores per-job metadata.
JOB_METADATA_DICT_NAME = "simulation-api-job-metadata"

# Repository and release of each country's data package, pinned in exactly
# one place so a release bump cannot leave some aliases on the old version.
_US_DATA_REPO = "hf://policyengine/policyengine-us-data"
_US_DATA_VERSION = "1.77.0"
_UK_DATA_REPO = "hf://policyengine/policyengine-uk-data-private"
_UK_DATA_VERSION = "1.40.3"


def _dataset_uris(repo: str, version: str, aliases: dict[str, str]) -> dict[str, str]:
    """Map every short alias and its canonical file stem to one versioned URI.

    Args:
        repo: Repository prefix, e.g. ``hf://policyengine/policyengine-us-data``.
        version: Release tag appended after ``@``.
        aliases: Short alias -> canonical ``.h5`` file stem.

    Returns:
        A dict in which both the alias and the canonical stem point at
        ``{repo}/{stem}.h5@{version}``.
    """
    uris: dict[str, str] = {}
    for alias, stem in aliases.items():
        uri = f"{repo}/{stem}.h5@{version}"
        uris[alias] = uri
        uris[stem] = uri
    return uris


# Country code -> dataset label -> fully versioned dataset URI.
DATASET_URIS = {
    "us": _dataset_uris(
        _US_DATA_REPO,
        _US_DATA_VERSION,
        {
            "enhanced_cps": "enhanced_cps_2024",
            "cps": "cps_2023",
            "pooled_cps": "pooled_3_year_cps_2023",
        },
    ),
    "uk": _dataset_uris(
        _UK_DATA_REPO,
        _UK_DATA_VERSION,
        {
            "enhanced_frs": "enhanced_frs_2023_24",
            "frs": "frs_2023_24",
        },
    ),
}
41+
42+
43+
def _job_metadata_store():
    """Return a handle to the Modal Dict that holds per-job metadata.

    The Dict is looked up by ``JOB_METADATA_DICT_NAME`` with
    ``create_if_missing=True``, so the first caller creates it.
    """
    store = modal.Dict.from_name(
        JOB_METADATA_DICT_NAME,
        create_if_missing=True,
    )
    return store
45+
46+
47+
def _build_policyengine_bundle(
    country: str, resolved_version: str, payload: dict
) -> PolicyEngineBundle:
    """Describe the model version and dataset a submission will run against.

    Args:
        country: Country code from the request; matched case-insensitively
            against the keys of ``DATASET_URIS``.
        resolved_version: Model version the request was resolved to.
        payload: Job payload; only its ``"data"`` entry is read.

    Returns:
        A ``PolicyEngineBundle`` whose ``dataset`` is:
        * ``None`` when ``payload["data"]`` is absent or not a string,
        * the value unchanged when it already contains ``"://"``,
        * the URI registered for that alias, or the label itself when the
          alias is unknown for this country.
    """
    requested = payload.get("data")

    if not isinstance(requested, str):
        resolved = None
    elif "://" in requested:
        # Already a full URI: report it exactly as submitted.
        resolved = requested
    else:
        country_aliases = DATASET_URIS.get(country.lower(), {})
        # Unknown labels are preserved rather than dropped.
        resolved = country_aliases.get(requested, requested)

    return PolicyEngineBundle(model_version=resolved_version, dataset=resolved)
61+
62+
63+
def _serialize_job_metadata(resolved_app_name: str, bundle: PolicyEngineBundle) -> dict:
    """Build the plain-dict metadata record stored for a submitted job.

    Args:
        resolved_app_name: Name of the app the job was dispatched to.
        bundle: Bundle describing the run; dumped with ``model_dump()``.

    Returns:
        A dict with the keys ``"resolved_app_name"`` and
        ``"policyengine_bundle"``.
    """
    metadata = {"resolved_app_name": resolved_app_name}
    metadata["policyengine_bundle"] = bundle.model_dump()
    return metadata
2368

2469

2570
def get_app_name(country: str, version: Optional[str]) -> tuple[str, str]:
@@ -74,12 +119,18 @@ async def submit_simulation(request: SimulationRequest):
74119
# Spawn the job (returns immediately)
75120
call = sim_func.spawn(payload)
76121

122+
bundle = _build_policyengine_bundle(request.country, resolved_version, payload)
123+
job_metadata = _serialize_job_metadata(app_name, bundle)
124+
_job_metadata_store()[call.object_id] = job_metadata
125+
77126
return JobSubmitResponse(
78127
job_id=call.object_id,
79128
status="submitted",
80129
poll_url=f"/jobs/{call.object_id}",
81130
country=request.country,
82131
version=resolved_version,
132+
resolved_app_name=app_name,
133+
policyengine_bundle=bundle,
83134
)
84135

85136

@@ -99,18 +150,32 @@ async def get_job_status(job_id: str):
99150
except Exception:
100151
raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")
101152

153+
job_metadata = _job_metadata_store().get(job_id)
154+
102155
try:
103156
result = call.get(timeout=0)
104-
return JobStatusResponse(status="complete", result=result)
157+
return JobStatusResponse(
158+
status="complete", result=result, **(job_metadata or {})
159+
)
105160
except TimeoutError:
106161
return JSONResponse(
107162
status_code=202,
108-
content={"status": "running", "result": None, "error": None},
163+
content={
164+
"status": "running",
165+
"result": None,
166+
"error": None,
167+
**(job_metadata or {}),
168+
},
109169
)
110170
except Exception as e:
111171
return JSONResponse(
112172
status_code=500,
113-
content={"status": "failed", "result": None, "error": str(e)},
173+
content={
174+
"status": "failed",
175+
"result": None,
176+
"error": str(e),
177+
**(job_metadata or {}),
178+
},
114179
)
115180

116181

projects/policyengine-api-simulation/src/modal/gateway/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ class SimulationRequest(BaseModel):
1515
model_config = ConfigDict(extra="allow") # Pass through all other fields
1616

1717

18+
class PolicyEngineBundle(BaseModel):
    """Resolved runtime provenance returned by the gateway."""

    # Model version the request was resolved to (required).
    model_version: str
    # NOTE(review): no code visible in this change sets this field, so it is
    # serialized as None -- confirm where it is meant to be populated.
    policyengine_version: Optional[str] = None
    # NOTE(review): likewise never set by the visible gateway code.
    data_version: Optional[str] = None
    # Dataset URI or label for the run; None when no dataset was resolved.
    dataset: Optional[str] = None
25+
26+
1827
class JobSubmitResponse(BaseModel):
1928
"""Response model for job submission."""
2029

@@ -23,6 +32,8 @@ class JobSubmitResponse(BaseModel):
2332
poll_url: str
2433
country: str
2534
version: str
35+
resolved_app_name: str
36+
policyengine_bundle: PolicyEngineBundle
2637

2738

2839
class JobStatusResponse(BaseModel):
@@ -31,6 +42,8 @@ class JobStatusResponse(BaseModel):
3142
status: str
3243
result: Optional[dict] = None
3344
error: Optional[str] = None
45+
resolved_app_name: Optional[str] = None
46+
policyengine_bundle: Optional[PolicyEngineBundle] = None
3447

3548

3649
class PingRequest(BaseModel):

projects/policyengine-api-simulation/tests/fixtures/endpoints.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def __getitem__(self, key: str):
1818
raise KeyError(key)
1919
return self._data[key]
2020

21+
def __setitem__(self, key: str, value):
    """Store ``value`` under ``key`` in the backing ``_data`` mapping."""
    backing = self._data
    backing[key] = value
23+
24+
def get(self, key: str, default=None):
    """Return the value stored under ``key``, or ``default`` if it is absent."""
    backing = self._data
    return backing.get(key, default)
26+
2127
@classmethod
2228
def from_name(cls, name: str):
2329
"""Mock from_name that returns a MockDict based on name."""
@@ -28,8 +34,27 @@ def from_name(cls, name: str):
2834
class MockFunctionCall:
    """Mock for Modal FunctionCall returned by spawn."""

    # Every instance created so far, keyed by object id, so that from_id()
    # can hand the same object back.
    registry = {}

    def __init__(self, object_id: str = "mock-job-id-123"):
        """Create a finished, successful call and record it in the registry."""
        self.object_id = object_id
        self.result = {"budget": {"total": 1000000}}
        self.error = None
        self.running = False
        type(self).registry[object_id] = self

    def get(self, timeout: int = 0):
        """Return the result, or raise to simulate a running or failed job.

        Raises:
            TimeoutError: when ``running`` is true.
            Exception: whatever was assigned to ``error``, if anything.
        """
        if self.running:
            raise TimeoutError()
        failure = self.error
        if failure is None:
            return self.result
        raise failure

    @classmethod
    def from_id(cls, object_id: str):
        """Return the registered call for ``object_id``; raise KeyError if unknown."""
        known = cls.registry
        if object_id in known:
            return known[object_id]
        raise KeyError(object_id)
3358

3459

3560
class MockFunction:
@@ -38,10 +63,12 @@ class MockFunction:
3863
def __init__(self):
    """Start with no recorded payload, ``from_name`` call, or spawned call."""
    self.last_payload = self.last_from_name_call = self.last_call = None
4167

4268
def spawn(self, payload: dict) -> MockFunctionCall:
    """Record ``payload`` and return a fresh mock call, kept as ``last_call``."""
    self.last_payload = payload
    spawned = MockFunctionCall()
    self.last_call = spawned
    return spawned
4572

4673
@classmethod
4774
def from_name(cls, app_name: str, func_name: str):
@@ -73,10 +100,13 @@ def test_something(mock_modal, client):
73100
# Create mock objects
74101
mock_func = MockFunction()
75102
mock_dicts = {}
103+
MockFunctionCall.registry = {}
76104

77105
class MockModalDict:
    """Stand-in for ``modal.Dict`` backed by the fixture's ``mock_dicts``."""

    @staticmethod
    def from_name(name: str, create_if_missing: bool = False):
        """Return a ``MockDict`` over the dict configured for ``name``.

        An unconfigured name raises ``KeyError`` unless ``create_if_missing``
        is true, in which case an empty dict is registered first.
        """
        if name not in mock_dicts:
            if not create_if_missing:
                raise KeyError(f"Mock dict not configured for: {name}")
            mock_dicts[name] = {}
        return MockDict(mock_dicts[name])
@@ -91,6 +121,7 @@ def from_name(app_name: str, func_name: str):
91121
class MockModal:
    # Stand-in for the ``modal`` module; it is patched over ``endpoints.modal``
    # and exposes only the three attributes bound below.
    Dict = MockModalDict
    Function = MockModalFunction
    FunctionCall = MockFunctionCall
94125

95126
# Patch the modal import in the endpoints module
96127
monkeypatch.setattr(endpoints, "modal", MockModal)

projects/policyengine-api-simulation/tests/gateway/test_endpoints.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from fastapi.testclient import TestClient
1010

11-
from tests.fixtures.endpoints import mock_modal # noqa: F401 - pytest fixture
11+
pytest_plugins = ("tests.fixtures.endpoints",)
1212

1313

1414
class TestGetAppName:
@@ -200,3 +200,146 @@ def test__given_submission__then_returns_job_id_and_poll_url(
200200
assert data["job_id"] == "mock-job-id-123"
201201
assert data["poll_url"] == "/jobs/mock-job-id-123"
202202
assert data["status"] == "submitted"
203+
204+
def test__given_submission_with_data__then_returns_resolved_bundle_metadata(
    self, mock_modal, client: TestClient
):
    """
    Given a simulation submission with an explicit data URI
    When the request completes
    Then the response exposes the resolved app and submitted dataset provenance.
    """
    # Given
    app_name = "policyengine-simulation-us1-500-0-uk2-66-0"
    dataset_uri = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
    mock_modal["dicts"]["simulation-api-us-versions"] = {
        "latest": "1.500.0",
        "1.500.0": app_name,
    }

    # When
    response = client.post(
        "/simulate/economy/comparison",
        json={
            "country": "us",
            "scope": "macro",
            "reform": {},
            "data": dataset_uri,
        },
    )

    # Then
    assert response.status_code == 200
    body = response.json()
    assert body["resolved_app_name"] == app_name
    expected_bundle = {
        "model_version": "1.500.0",
        "policyengine_version": None,
        "data_version": None,
        "dataset": dataset_uri,
    }
    assert body["policyengine_bundle"] == expected_bundle
238+
239+
def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved(
    self, mock_modal, client: TestClient
):
    """
    Given a US submission whose data field is the alias "enhanced_cps_2024"
    When the request completes
    Then the bundle dataset is the versioned URI registered for that alias.

    NOTE(review): the method name says "stays_unresolved", but the assertion
    expects the alias to be resolved; the name is kept only to preserve the
    existing test ID -- consider renaming.
    """
    # Given
    mock_modal["dicts"]["simulation-api-us-versions"] = {
        "latest": "1.500.0",
        "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
    }
    submission = {
        "country": "us",
        "scope": "macro",
        "reform": {},
        "data": "enhanced_cps_2024",
    }

    # When
    response = client.post("/simulate/economy/comparison", json=submission)

    # Then
    assert response.status_code == 200
    bundle = response.json()["policyengine_bundle"]
    expected_uri = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
    assert bundle["dataset"] == expected_uri
262+
263+
def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri(
    self, mock_modal, client: TestClient
):
    """
    Given a UK submission whose data field is the short alias "enhanced_frs"
    When the request completes
    Then the bundle dataset is the versioned URI registered for that alias.
    """
    # Given
    mock_modal["dicts"]["simulation-api-uk-versions"] = {
        "latest": "2.66.0",
        "2.66.0": "policyengine-simulation-us1-500-0-uk2-66-0",
    }
    submission = {
        "country": "uk",
        "scope": "macro",
        "reform": {},
        "data": "enhanced_frs",
    }

    # When
    response = client.post("/simulate/economy/comparison", json=submission)

    # Then
    assert response.status_code == 200
    bundle = response.json()["policyengine_bundle"]
    expected_uri = (
        "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3"
    )
    assert bundle["dataset"] == expected_uri
286+
287+
def test__given_submission_with_unknown_alias_data__then_bundle_dataset_is_preserved(
    self, mock_modal, client: TestClient
):
    """
    Given a submission whose data label is not a registered alias
    When the request completes
    Then the bundle dataset echoes the label unchanged.
    """
    # Given
    unknown_label = "custom_dataset_label"
    mock_modal["dicts"]["simulation-api-us-versions"] = {
        "latest": "1.500.0",
        "1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
    }

    # When
    response = client.post(
        "/simulate/economy/comparison",
        json={
            "country": "us",
            "scope": "macro",
            "reform": {},
            "data": unknown_label,
        },
    )

    # Then
    assert response.status_code == 200
    bundle = response.json()["policyengine_bundle"]
    assert bundle["dataset"] == unknown_label
307+
308+
def test__given_submitted_job__then_job_status_includes_bundle_metadata(
    self, mock_modal, client: TestClient
):
    """
    Given a submitted simulation job
    When polling job status
    Then the resolved bundle metadata is returned with the status response.
    """
    # Given
    app_name = "policyengine-simulation-us1-500-0-uk2-66-0"
    dataset_uri = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
    mock_modal["dicts"]["simulation-api-us-versions"] = {
        "latest": "1.500.0",
        "1.500.0": app_name,
    }
    submission = {
        "country": "us",
        "scope": "macro",
        "reform": {},
        "data": dataset_uri,
    }
    submit_response = client.post("/simulate/economy/comparison", json=submission)

    # When
    job_id = submit_response.json()["job_id"]
    response = client.get(f"/jobs/{job_id}")

    # Then
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "complete"
    assert body["resolved_app_name"] == app_name
    expected_bundle = {
        "model_version": "1.500.0",
        "policyengine_version": None,
        "data_version": None,
        "dataset": dataset_uri,
    }
    assert body["policyengine_bundle"] == expected_bundle

0 commit comments

Comments
 (0)