Skip to content

Commit 51a7d8f

Browse files
authored
Add managed release bundle runtime (#276)
* Add managed release bundle runtime * Add changelog fragment for managed runtime * Fix lint for managed runtime PR * Format managed runtime files
1 parent 3d70790 commit 51a7d8f

15 files changed

Lines changed: 495 additions & 420 deletions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added managed release-bundle runtime enforcement for bundled US and UK microsimulations, including manifest-backed dataset pinning and runtime bundle metadata.

src/policyengine/core/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
)
2626
from .release_manifest import get_data_release_manifest as get_data_release_manifest
2727
from .release_manifest import get_release_manifest as get_release_manifest
28+
from .release_manifest import (
29+
resolve_managed_dataset_reference as resolve_managed_dataset_reference,
30+
)
2831
from .scoping_strategy import RegionScopingStrategy as RegionScopingStrategy
2932
from .scoping_strategy import RowFilterStrategy as RowFilterStrategy
3033
from .scoping_strategy import ScopingStrategy as ScopingStrategy
@@ -36,13 +39,6 @@
3639
from .tax_benefit_model_version import (
3740
TaxBenefitModelVersion as TaxBenefitModelVersion,
3841
)
39-
from .trace_tro import (
40-
build_trace_tro_from_release_bundle as build_trace_tro_from_release_bundle,
41-
)
42-
from .trace_tro import (
43-
compute_trace_composition_fingerprint as compute_trace_composition_fingerprint,
44-
)
45-
from .trace_tro import serialize_trace_tro as serialize_trace_tro
4642
from .variable import Variable as Variable
4743

4844
# Rebuild models to resolve forward references

src/policyengine/core/release_manifest.py

Lines changed: 93 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import os
22
from functools import lru_cache
3-
from importlib import import_module, metadata
3+
from importlib import import_module
44
from importlib.resources import files
55
from pathlib import Path
66

77
import requests
88
from pydantic import BaseModel, Field
99

1010
HF_REQUEST_TIMEOUT_SECONDS = 30
11-
12-
13-
class DataReleaseManifestUnavailable(ValueError):
14-
pass
11+
LOCAL_DATA_REPO_HINTS = {
12+
"us": ("policyengine_us", "policyengine-us-data", "policyengine_us_data"),
13+
"uk": ("policyengine_uk", "policyengine-uk-data", "policyengine_uk_data"),
14+
}
1515

1616

1717
class PackageVersion(BaseModel):
@@ -126,28 +126,6 @@ def build_hf_uri(repo_id: str, path_in_repo: str, revision: str) -> str:
126126
return f"hf://{repo_id}/{path_in_repo}@{revision}"
127127

128128

129-
def get_runtime_model_build_metadata(package_name: str) -> dict[str, str | None]:
130-
installed_version = metadata.version(package_name)
131-
module_name = package_name.replace("-", "_")
132-
133-
try:
134-
build_metadata_module = import_module(f"{module_name}.build_metadata")
135-
except Exception:
136-
return {
137-
"name": package_name,
138-
"version": installed_version,
139-
"git_sha": None,
140-
"data_build_fingerprint": None,
141-
}
142-
143-
build_metadata = build_metadata_module.get_data_build_metadata()
144-
build_metadata.setdefault("name", package_name)
145-
build_metadata.setdefault("version", installed_version)
146-
build_metadata.setdefault("git_sha", None)
147-
build_metadata.setdefault("data_build_fingerprint", None)
148-
return build_metadata
149-
150-
151129
@lru_cache
152130
def get_release_manifest(country_id: str) -> CountryReleaseManifest:
153131
manifest_path = files("policyengine").joinpath(
@@ -183,15 +161,10 @@ def get_data_release_manifest(country_id: str) -> DataReleaseManifest:
183161
timeout=HF_REQUEST_TIMEOUT_SECONDS,
184162
)
185163
if response.status_code in (401, 403):
186-
raise DataReleaseManifestUnavailable(
164+
raise ValueError(
187165
"Could not fetch the data release manifest from Hugging Face. "
188166
"If this country uses a private data repo, set HUGGING_FACE_TOKEN."
189167
)
190-
if response.status_code == 404:
191-
raise DataReleaseManifestUnavailable(
192-
"Could not find the data release manifest on Hugging Face for "
193-
f"{data_package.repo_id}@{data_package.version}."
194-
)
195168
response.raise_for_status()
196169
return DataReleaseManifest.model_validate_json(response.text)
197170

@@ -208,7 +181,17 @@ def certify_data_release_compatibility(
208181
runtime_data_build_fingerprint: str | None = None,
209182
) -> DataCertification:
210183
country_manifest = get_release_manifest(country_id)
211-
data_release_manifest = get_data_release_manifest(country_id)
184+
try:
185+
data_release_manifest = get_data_release_manifest(country_id)
186+
except Exception as exc:
187+
bundled_certification = country_manifest.certification
188+
if (
189+
bundled_certification is not None
190+
and bundled_certification.certified_for_model_version
191+
== runtime_model_version
192+
):
193+
return bundled_certification
194+
raise exc
212195
built_with_model = (
213196
data_release_manifest.build.built_with_model_package
214197
if data_release_manifest.build is not None
@@ -295,37 +278,6 @@ def certify_data_release_compatibility(
295278
)
296279

297280

298-
def resolve_runtime_data_certification(
299-
country_id: str,
300-
runtime_model_version: str,
301-
runtime_data_build_fingerprint: str | None = None,
302-
bundled_certification: DataCertification | None = None,
303-
) -> DataCertification:
304-
try:
305-
return certify_data_release_compatibility(
306-
country_id=country_id,
307-
runtime_model_version=runtime_model_version,
308-
runtime_data_build_fingerprint=runtime_data_build_fingerprint,
309-
)
310-
except DataReleaseManifestUnavailable:
311-
if (
312-
bundled_certification is not None
313-
and bundled_certification.certified_for_model_version
314-
== runtime_model_version
315-
):
316-
bundled_fingerprint = bundled_certification.data_build_fingerprint
317-
if (
318-
bundled_certification.compatibility_basis
319-
== "matching_data_build_fingerprint"
320-
and bundled_fingerprint is not None
321-
and runtime_data_build_fingerprint is not None
322-
and bundled_fingerprint != runtime_data_build_fingerprint
323-
):
324-
raise
325-
return bundled_certification
326-
raise
327-
328-
329281
def resolve_dataset_reference(country_id: str, dataset: str) -> str:
330282
if "://" in dataset:
331283
return dataset
@@ -350,6 +302,82 @@ def resolve_dataset_reference(country_id: str, dataset: str) -> str:
350302
return artifact.uri
351303

352304

305+
def resolve_managed_dataset_reference(
    country_id: str,
    dataset: str | None = None,
    *,
    allow_unmanaged: bool = False,
) -> str:
    """Choose the dataset reference to use under release-bundle enforcement.

    Resolution rules, in order:

    - ``dataset`` omitted: the bundled manifest's ``default_dataset_uri`` is
      returned.
    - ``dataset`` has no ``://``: it is treated as a logical dataset name and
      handed to ``resolve_dataset_reference``.
    - ``dataset`` contains ``://``: it is returned unchanged only if it equals
      the bundled default URI or ``allow_unmanaged`` is true.

    Args:
        country_id: Country whose bundled release manifest is consulted.
        dataset: Logical dataset name, explicit URI, or ``None`` for the
            bundled default.
        allow_unmanaged: Keyword-only opt-in that lets an explicit URI other
            than the bundled default pass through.

    Returns:
        The resolved dataset reference.

    Raises:
        ValueError: If an explicit URI differs from the bundled default and
            ``allow_unmanaged`` is false.
    """
    bundle = get_release_manifest(country_id)

    if dataset is None:
        return bundle.default_dataset_uri

    # Anything without a scheme separator is a logical name, not a URI.
    if "://" not in dataset:
        return resolve_dataset_reference(country_id, dataset)

    if dataset == bundle.default_dataset_uri or allow_unmanaged:
        return dataset

    raise ValueError(
        "Explicit dataset URIs bypass the policyengine.py release bundle. "
        "Pass a manifest dataset name or omit `dataset` to use the certified "
        "default dataset. Set `allow_unmanaged=True` only if you intend to "
        "bypass bundle enforcement."
    )
340+
341+
342+
def resolve_local_managed_dataset_source(country_id: str, dataset_uri: str) -> str:
    """Return a local on-disk copy of a managed dataset when one is found.

    The lookup is best effort: whenever a local candidate cannot be derived
    or does not exist, ``dataset_uri`` is returned unchanged. The candidate
    path is built from the location of the country's model package:

        <two levels above the model module file, renamed to the data repo>
            / <data package> / "storage" / <path in repo>

    The ``@revision`` suffix is discarded when building the candidate, so the
    local file is matched by path only.

    Args:
        country_id: Key into ``LOCAL_DATA_REPO_HINTS``.
        dataset_uri: Dataset reference; only ``hf://`` URIs are considered.

    Returns:
        The local file path as a string if it exists, otherwise
        ``dataset_uri`` unchanged.
    """
    if not dataset_uri.startswith("hf://"):
        return dataset_uri

    local_hint = LOCAL_DATA_REPO_HINTS.get(country_id)
    if local_hint is None:
        return dataset_uri

    # Strip the scheme and trailing "@revision". The first two "/"-separated
    # segments are discarded; the remainder is the path inside the repo.
    path_without_revision = dataset_uri.removeprefix("hf://").rsplit("@", 1)[0]
    parts = path_without_revision.split("/", 2)
    if len(parts) != 3:
        return dataset_uri
    path_in_repo = parts[2]

    model_module_name, data_repo_name, data_package_name = local_hint
    try:
        model_module = import_module(model_module_name)
    except ImportError:
        return dataset_uri

    # Modules without a file location (e.g. namespace packages) have
    # ``__file__`` unset or None; fall back to the URI instead of crashing.
    module_file = getattr(model_module, "__file__", None)
    if module_file is None:
        return dataset_uri

    repo_root = Path(module_file).resolve().parents[1]
    local_path = (
        repo_root.with_name(data_repo_name)
        / data_package_name
        / "storage"
        / path_in_repo
    )
    if local_path.exists():
        return str(local_path)
    return dataset_uri
379+
380+
353381
def dataset_logical_name(dataset: str) -> str:
    """Return the file stem of a dataset reference, ignoring any ``@`` suffix.

    Everything after the last ``@`` is dropped, then the final path
    component is taken without its last extension.
    """
    without_revision, *_ = dataset.rsplit("@", 1)
    return Path(without_revision).stem
355383

src/policyengine/core/tax_benefit_model_version.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,8 @@
44

55
from pydantic import BaseModel, Field
66

7-
from .release_manifest import (
8-
CountryReleaseManifest,
9-
DataCertification,
10-
PackageVersion,
11-
get_data_release_manifest,
12-
)
7+
from .release_manifest import CountryReleaseManifest, DataCertification, PackageVersion
138
from .tax_benefit_model import TaxBenefitModel
14-
from .trace_tro import build_trace_tro_from_release_bundle
159

1610
if TYPE_CHECKING:
1711
from .parameter import Parameter
@@ -207,22 +201,6 @@ def release_bundle(self) -> dict[str, str | None]:
207201
),
208202
}
209203

210-
@property
211-
def trace_tro(self) -> dict:
212-
if self.release_manifest is None:
213-
raise ValueError(
214-
"TRACE TRO export requires a bundled country release manifest."
215-
)
216-
217-
data_release_manifest = get_data_release_manifest(
218-
self.release_manifest.country_id
219-
)
220-
return build_trace_tro_from_release_bundle(
221-
self.release_manifest,
222-
data_release_manifest,
223-
certification=self.data_certification,
224-
)
225-
226204
def __repr__(self) -> str:
227205
# Give the id and version, and the number of variables, parameters, parameter nodes, parameter values
228206
return f"<TaxBenefitModelVersion id={self.id} variables={len(self.variables)} parameters={len(self.parameters)} parameter_nodes={len(self.parameter_nodes)} parameter_values={len(self.parameter_values)}>"

src/policyengine/outputs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
UK_INEQUALITY_INCOME_VARIABLE,
2121
US_INEQUALITY_INCOME_VARIABLE,
2222
Inequality,
23+
USInequalityPreset,
2324
calculate_uk_inequality,
2425
calculate_us_inequality,
2526
)
@@ -76,6 +77,7 @@
7677
"GENDER_GROUPS",
7778
"RACE_GROUPS",
7879
"Inequality",
80+
"USInequalityPreset",
7981
"UK_INEQUALITY_INCOME_VARIABLE",
8082
"US_INEQUALITY_INCOME_VARIABLE",
8183
"calculate_uk_inequality",

0 commit comments

Comments
 (0)