Skip to content

Commit ba01c58

Browse files
anth-volkclaude
andcommitted
Add changelog entry and format code
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 5c5c97e commit ba01c58

3 files changed

Lines changed: 63 additions & 19 deletions

File tree

changelog_entry.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Entity relationship approach for simulation filtering that preserves household integrity
5+
- Reusable variable validation functions (`get_variable`, `validate_variable_entity`, `validate_household_variable`)
6+
changed:
7+
- Refactored `_filter_simulation_by_household_variable` to use explicit entity relationship mapping
8+
- Place-level filtering now builds entity_rel DataFrame for cleaner filtering logic

policyengine/simulation.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def _build_entity_relationships(
477477
Returns:
478478
A DataFrame indexed by person with columns for each entity ID.
479479
"""
480-
entity_rel = pd.DataFrame({"person_id": simulation.calculate("person_id").values})
480+
entity_rel = pd.DataFrame(
481+
{"person_id": simulation.calculate("person_id").values}
482+
)
481483

482484
# Add household relationship (required for all countries)
483485
entity_rel["household_id"] = simulation.calculate(
@@ -486,7 +488,12 @@ def _build_entity_relationships(
486488

487489
# Add country-specific entity relationships
488490
tbs = simulation.tax_benefit_system
489-
optional_entities = ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]
491+
optional_entities = [
492+
"tax_unit_id",
493+
"spm_unit_id",
494+
"family_id",
495+
"marital_unit_id",
496+
]
490497

491498
for entity_id in optional_entities:
492499
if entity_id in tbs.variables:
@@ -525,7 +532,9 @@ def _filter_simulation_by_household_variable(
525532
Raises:
526533
ValueError: If the variable is not a household-level variable.
527534
"""
528-
validate_household_variable(simulation.tax_benefit_system, variable_name)
535+
validate_household_variable(
536+
simulation.tax_benefit_system, variable_name
537+
)
529538

530539
# Build entity relationships
531540
entity_rel = self._build_entity_relationships(simulation)
@@ -536,7 +545,9 @@ def _filter_simulation_by_household_variable(
536545

537546
# Create mask for matching households, handling bytes encoding
538547
if isinstance(variable_value, str):
539-
hh_mask = (hh_values == variable_value) | (hh_values == variable_value.encode())
548+
hh_mask = (hh_values == variable_value) | (
549+
hh_values == variable_value.encode()
550+
)
540551
else:
541552
hh_mask = hh_values == variable_value
542553

tests/fixtures/country/us_places.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@
129129
# =============================================================================
130130

131131

132-
def create_mock_tax_benefit_system(household_variables: list[str] | None = None) -> Mock:
132+
def create_mock_tax_benefit_system(
133+
household_variables: list[str] | None = None,
134+
) -> Mock:
133135
"""Create a mock tax benefit system with variable entity information.
134136
135137
Args:
@@ -152,12 +154,21 @@ def create_mock_tax_benefit_system(household_variables: list[str] | None = None)
152154
mock_tbs.variables[var_name] = mock_var
153155

154156
# Add standard entity ID variables
155-
for entity_id in ["person_id", "household_id", "tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]:
157+
for entity_id in [
158+
"person_id",
159+
"household_id",
160+
"tax_unit_id",
161+
"spm_unit_id",
162+
"family_id",
163+
"marital_unit_id",
164+
]:
156165
mock_var = Mock()
157166
mock_var.entity = Mock()
158167
# Entity IDs belong to their respective entities
159168
entity_name = entity_id.replace("_id", "")
160-
mock_var.entity.key = entity_name if entity_name != "person" else "person"
169+
mock_var.entity.key = (
170+
entity_name if entity_name != "person" else "person"
171+
)
161172
mock_tbs.variables[entity_id] = mock_var
162173

163174
return mock_tbs
@@ -218,7 +229,12 @@ def mock_calculate(variable_name, map_to=None, period=None):
218229
result.values = np.array(person_household_ids)
219230
else:
220231
result.values = np.array(household_ids)
221-
elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]:
232+
elif variable_name in [
233+
"tax_unit_id",
234+
"spm_unit_id",
235+
"family_id",
236+
"marital_unit_id",
237+
]:
222238
# For simplicity, use household_id as proxy for other entity IDs
223239
if map_to == "person":
224240
result.values = np.array(person_household_ids)
@@ -231,11 +247,13 @@ def mock_calculate(variable_name, map_to=None, period=None):
231247
mock_sim.calculate = mock_calculate
232248

233249
# Mock to_input_dataframe to return person-level DataFrame
234-
df = pd.DataFrame({
235-
"person_id__2024": person_ids,
236-
"household_id__2024": person_household_ids,
237-
"place_fips__2024": person_place_fips,
238-
})
250+
df = pd.DataFrame(
251+
{
252+
"person_id__2024": person_ids,
253+
"household_id__2024": person_household_ids,
254+
"place_fips__2024": person_place_fips,
255+
}
256+
)
239257
mock_sim.to_input_dataframe.return_value = df
240258

241259
return mock_sim
@@ -287,7 +305,12 @@ def mock_calculate(variable_name, map_to=None, period=None):
287305
result.values = np.array(person_household_ids)
288306
else:
289307
result.values = np.array(household_ids)
290-
elif variable_name in ["tax_unit_id", "spm_unit_id", "family_id", "marital_unit_id"]:
308+
elif variable_name in [
309+
"tax_unit_id",
310+
"spm_unit_id",
311+
"family_id",
312+
"marital_unit_id",
313+
]:
291314
if map_to == "person":
292315
result.values = np.array(person_household_ids)
293316
else:
@@ -298,11 +321,13 @@ def mock_calculate(variable_name, map_to=None, period=None):
298321

299322
mock_sim.calculate = mock_calculate
300323

301-
df = pd.DataFrame({
302-
"person_id__2024": person_ids,
303-
"household_id__2024": person_household_ids,
304-
"place_fips__2024": person_place_fips,
305-
})
324+
df = pd.DataFrame(
325+
{
326+
"person_id__2024": person_ids,
327+
"household_id__2024": person_household_ids,
328+
"place_fips__2024": person_place_fips,
329+
}
330+
)
306331
mock_sim.to_input_dataframe.return_value = df
307332

308333
return mock_sim

0 commit comments

Comments
 (0)