11import os
22from functools import lru_cache
3- from importlib import import_module , metadata
3+ from importlib import import_module
44from importlib .resources import files
55from pathlib import Path
66
77import requests
88from pydantic import BaseModel , Field
99
1010HF_REQUEST_TIMEOUT_SECONDS = 30
11-
12-
class DataReleaseManifestUnavailable(ValueError):
    """Signal that the data release manifest could not be obtained.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working while newer callers can catch this narrower type.
    """
# Per-country hints for locating a sibling data-repo checkout on disk.
# Each value is (model module name, data repo directory name, data package name).
LOCAL_DATA_REPO_HINTS = {
    country: (
        f"policyengine_{country}",
        f"policyengine-{country}-data",
        f"policyengine_{country}_data",
    )
    for country in ("us", "uk")
}
1515
1616
1717class PackageVersion (BaseModel ):
@@ -126,28 +126,6 @@ def build_hf_uri(repo_id: str, path_in_repo: str, revision: str) -> str:
126126 return f"hf://{ repo_id } /{ path_in_repo } @{ revision } "
127127
128128
129- def get_runtime_model_build_metadata (package_name : str ) -> dict [str , str | None ]:
130- installed_version = metadata .version (package_name )
131- module_name = package_name .replace ("-" , "_" )
132-
133- try :
134- build_metadata_module = import_module (f"{ module_name } .build_metadata" )
135- except Exception :
136- return {
137- "name" : package_name ,
138- "version" : installed_version ,
139- "git_sha" : None ,
140- "data_build_fingerprint" : None ,
141- }
142-
143- build_metadata = build_metadata_module .get_data_build_metadata ()
144- build_metadata .setdefault ("name" , package_name )
145- build_metadata .setdefault ("version" , installed_version )
146- build_metadata .setdefault ("git_sha" , None )
147- build_metadata .setdefault ("data_build_fingerprint" , None )
148- return build_metadata
149-
150-
151129@lru_cache
152130def get_release_manifest (country_id : str ) -> CountryReleaseManifest :
153131 manifest_path = files ("policyengine" ).joinpath (
@@ -183,15 +161,10 @@ def get_data_release_manifest(country_id: str) -> DataReleaseManifest:
183161 timeout = HF_REQUEST_TIMEOUT_SECONDS ,
184162 )
185163 if response .status_code in (401 , 403 ):
186- raise DataReleaseManifestUnavailable (
164+ raise ValueError (
187165 "Could not fetch the data release manifest from Hugging Face. "
188166 "If this country uses a private data repo, set HUGGING_FACE_TOKEN."
189167 )
190- if response .status_code == 404 :
191- raise DataReleaseManifestUnavailable (
192- "Could not find the data release manifest on Hugging Face for "
193- f"{ data_package .repo_id } @{ data_package .version } ."
194- )
195168 response .raise_for_status ()
196169 return DataReleaseManifest .model_validate_json (response .text )
197170
@@ -208,7 +181,17 @@ def certify_data_release_compatibility(
208181 runtime_data_build_fingerprint : str | None = None ,
209182) -> DataCertification :
210183 country_manifest = get_release_manifest (country_id )
211- data_release_manifest = get_data_release_manifest (country_id )
184+ try :
185+ data_release_manifest = get_data_release_manifest (country_id )
186+ except Exception as exc :
187+ bundled_certification = country_manifest .certification
188+ if (
189+ bundled_certification is not None
190+ and bundled_certification .certified_for_model_version
191+ == runtime_model_version
192+ ):
193+ return bundled_certification
194+ raise exc
212195 built_with_model = (
213196 data_release_manifest .build .built_with_model_package
214197 if data_release_manifest .build is not None
@@ -295,37 +278,6 @@ def certify_data_release_compatibility(
295278 )
296279
297280
def resolve_runtime_data_certification(
    country_id: str,
    runtime_model_version: str,
    runtime_data_build_fingerprint: str | None = None,
    bundled_certification: DataCertification | None = None,
) -> DataCertification:
    """Certify the runtime against the data release, with a bundled fallback.

    The live certification is attempted first. Only when the data release
    manifest is unavailable is ``bundled_certification`` considered, and it
    is returned only if it was certified for ``runtime_model_version`` and
    does not contradict the runtime data build fingerprint. In every other
    case the original ``DataReleaseManifestUnavailable`` is re-raised.
    """
    try:
        return certify_data_release_compatibility(
            country_id=country_id,
            runtime_model_version=runtime_model_version,
            runtime_data_build_fingerprint=runtime_data_build_fingerprint,
        )
    except DataReleaseManifestUnavailable:
        # No fallback was supplied.
        if bundled_certification is None:
            raise
        # The fallback was certified for a different model version.
        if bundled_certification.certified_for_model_version != runtime_model_version:
            raise

        expected_fingerprint = bundled_certification.data_build_fingerprint
        fingerprint_conflict = (
            bundled_certification.compatibility_basis
            == "matching_data_build_fingerprint"
            and expected_fingerprint is not None
            and runtime_data_build_fingerprint is not None
            and expected_fingerprint != runtime_data_build_fingerprint
        )
        # A fingerprint-based certification must not be reused when both
        # fingerprints are known and they disagree.
        if fingerprint_conflict:
            raise
        return bundled_certification
327-
328-
329281def resolve_dataset_reference (country_id : str , dataset : str ) -> str :
330282 if "://" in dataset :
331283 return dataset
@@ -350,6 +302,82 @@ def resolve_dataset_reference(country_id: str, dataset: str) -> str:
350302 return artifact .uri
351303
352304
def resolve_managed_dataset_reference(
    country_id: str,
    dataset: str | None = None,
    *,
    allow_unmanaged: bool = False,
) -> str:
    """Resolve a dataset reference under policyengine.py bundle enforcement.

    In managed mode the dataset choice is pinned to the bundled
    `policyengine.py` release manifest:

    - with `dataset` omitted, the bundle's default dataset URI is returned
    - a logical dataset name is resolved via `resolve_dataset_reference`
    - an explicit URI (anything containing `://`) is accepted only when it
      equals the bundle's default dataset URI or `allow_unmanaged=True`

    Raises:
        ValueError: If an explicit, non-default URI is passed without
            `allow_unmanaged=True`.
    """
    bundle_manifest = get_release_manifest(country_id)
    if dataset is None:
        return bundle_manifest.default_dataset_uri

    # Logical names carry no scheme and are resolved through the manifests.
    if "://" not in dataset:
        return resolve_dataset_reference(country_id, dataset)

    is_bundle_default = dataset == bundle_manifest.default_dataset_uri
    if is_bundle_default or allow_unmanaged:
        return dataset

    raise ValueError(
        "Explicit dataset URIs bypass the policyengine.py release bundle. "
        "Pass a manifest dataset name or omit `dataset` to use the certified "
        "default dataset. Set `allow_unmanaged=True` only if you intend to "
        "bypass bundle enforcement."
    )
340+
341+
def resolve_local_managed_dataset_source(country_id: str, dataset_uri: str) -> str:
    """Resolve a local mirror of a managed dataset when available.

    This preserves the bundled dataset URI for provenance while allowing local
    development environments with sibling data-repo checkouts to load the
    exact certified artifact from disk rather than re-downloading it.

    Args:
        country_id: Key into ``LOCAL_DATA_REPO_HINTS``.
        dataset_uri: Dataset reference; only ``hf://owner/repo/path[@rev]``
            URIs are candidates for local resolution.

    Returns:
        The path of the local copy as a string if one exists, otherwise
        ``dataset_uri`` unchanged. This function never raises for an
        unresolvable URI; it simply falls back to the input.
    """
    if not dataset_uri.startswith("hf://"):
        return dataset_uri

    local_hint = LOCAL_DATA_REPO_HINTS.get(country_id)
    if local_hint is None:
        return dataset_uri

    # Strip the scheme and the trailing "@revision", leaving owner/repo/path.
    path_without_revision = dataset_uri.removeprefix("hf://").rsplit("@", 1)[0]
    parts = path_without_revision.split("/", 2)
    # An empty in-repo path (e.g. "hf://owner/repo/") would otherwise resolve
    # to the storage directory itself and be returned as if it were a dataset.
    if len(parts) != 3 or not parts[2]:
        return dataset_uri
    path_in_repo = parts[2]

    model_module_name, data_repo_name, data_package_name = local_hint
    try:
        model_module = import_module(model_module_name)
    except ImportError:
        return dataset_uri

    # Namespace packages have no __file__; there is then no checkout location
    # to derive a sibling repo from.
    module_file = getattr(model_module, "__file__", None)
    if module_file is None:
        return dataset_uri

    # parents[1] is the directory containing the model package; the data repo
    # is looked up as a sibling of that directory.
    repo_root = Path(module_file).resolve().parents[1]
    local_path = (
        repo_root.with_name(data_repo_name)
        / data_package_name
        / "storage"
        / path_in_repo
    )
    if local_path.exists():
        return str(local_path)
    return dataset_uri
379+
380+
def dataset_logical_name(dataset: str) -> str:
    """Return the file stem of a dataset reference.

    Everything after the last ``@`` (if any) is dropped first, then the final
    path component is taken without its last suffix.
    """
    reference, separator, _revision = dataset.rpartition("@")
    if not separator:
        # No "@" present: the whole string is the reference.
        reference = dataset
    return Path(reference).stem
355383
0 commit comments