Skip to content

Commit 5223bb7

Browse files
authored
Merge pull request #225 from PolicyEngine/fix/fix-filtering
Fix household-level filtering to preserve household integrity
2 parents 031871a + ba01c58 commit 5223bb7

5 files changed

Lines changed: 623 additions & 32 deletions

File tree

changelog_entry.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Entity relationship approach for simulation filtering that preserves household integrity
5+
- Reusable variable validation functions (`get_variable`, `validate_variable_entity`, `validate_household_variable`)
6+
changed:
7+
- Refactored `_filter_simulation_by_household_variable` to use explicit entity relationship mapping
8+
- Place-level filtering now builds entity_rel DataFrame for cleaner filtering logic

policyengine/simulation.py

Lines changed: 182 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,79 @@
4747
SubsampleType = Optional[int]
4848

4949

50+
# =============================================================================
51+
# Variable Validation Functions
52+
# =============================================================================
53+
54+
55+
def get_variable(tax_benefit_system: Any, variable_name: str) -> Any:
56+
"""Get a variable from the tax-benefit system, raising if not found.
57+
58+
Args:
59+
tax_benefit_system: The tax-benefit system to search.
60+
variable_name: The name of the variable to find.
61+
62+
Returns:
63+
The variable object from the tax-benefit system.
64+
65+
Raises:
66+
ValueError: If the variable is not found.
67+
"""
68+
if variable_name not in tax_benefit_system.variables:
69+
raise ValueError(
70+
f"Variable '{variable_name}' not found in tax-benefit system"
71+
)
72+
return tax_benefit_system.variables[variable_name]
73+
74+
75+
def validate_variable_entity(
76+
tax_benefit_system: Any,
77+
variable_name: str,
78+
expected_entity: str,
79+
) -> None:
80+
"""Validate that a variable belongs to the expected entity type.
81+
82+
Args:
83+
tax_benefit_system: The tax-benefit system containing the variable.
84+
variable_name: The name of the variable to validate.
85+
expected_entity: The expected entity key (e.g., "household", "person").
86+
87+
Raises:
88+
ValueError: If the variable is not found or belongs to a different entity.
89+
"""
90+
variable = get_variable(tax_benefit_system, variable_name)
91+
actual_entity = variable.entity.key
92+
93+
if actual_entity != expected_entity:
94+
raise ValueError(
95+
f"Variable '{variable_name}' is a {actual_entity}-level variable, "
96+
f"not a {expected_entity}-level variable."
97+
)
98+
99+
100+
def validate_household_variable(
101+
tax_benefit_system: Any,
102+
variable_name: str,
103+
) -> None:
104+
"""Validate that a variable is a household-level variable.
105+
106+
Args:
107+
tax_benefit_system: The tax-benefit system containing the variable.
108+
variable_name: The name of the variable to validate.
109+
110+
Raises:
111+
ValueError: If the variable is not found or is not household-level.
112+
"""
113+
variable = get_variable(tax_benefit_system, variable_name)
114+
115+
if variable.entity.key != "household":
116+
raise ValueError(
117+
f"Variable '{variable_name}' is a {variable.entity.key}-level variable, "
118+
f"not a household-level variable. Only household-level variables can be "
119+
f"used for filtering to preserve household integrity."
120+
)
121+
122+
50123
class SimulationOptions(BaseModel):
51124
country: CountryType = Field(..., description="The country to simulate.")
52125
scope: ScopeType = Field(..., description="The scope of the simulation.")
@@ -388,6 +461,108 @@ def _apply_us_region_to_simulation(
388461
)
389462
return simulation
390463

464+
def _build_entity_relationships(
465+
self,
466+
simulation: CountryMicrosimulation,
467+
) -> pd.DataFrame:
468+
"""Build a DataFrame mapping each person to their containing entities.
469+
470+
Creates an explicit relationship map between persons and all entity
471+
types (household, tax_unit, etc.). This enables filtering at any
472+
entity level while preserving the integrity of all related entities.
473+
474+
Args:
475+
simulation: The microsimulation to extract relationships from.
476+
477+
Returns:
478+
A DataFrame indexed by person with columns for each entity ID.
479+
"""
480+
entity_rel = pd.DataFrame(
481+
{"person_id": simulation.calculate("person_id").values}
482+
)
483+
484+
# Add household relationship (required for all countries)
485+
entity_rel["household_id"] = simulation.calculate(
486+
"household_id", map_to="person"
487+
).values
488+
489+
# Add country-specific entity relationships
490+
tbs = simulation.tax_benefit_system
491+
optional_entities = [
492+
"tax_unit_id",
493+
"spm_unit_id",
494+
"family_id",
495+
"marital_unit_id",
496+
]
497+
498+
for entity_id in optional_entities:
499+
if entity_id in tbs.variables:
500+
entity_rel[entity_id] = simulation.calculate(
501+
entity_id, map_to="person"
502+
).values
503+
504+
return entity_rel
505+
506+
def _filter_simulation_by_household_variable(
507+
self,
508+
simulation: CountryMicrosimulation,
509+
simulation_type: type,
510+
variable_name: str,
511+
variable_value: Any,
512+
reform: ReformType | None,
513+
) -> CountrySimulation:
514+
"""Filter a simulation to only include households where a variable matches a value.
515+
516+
Uses the entity relationship approach: builds an explicit map of all
517+
entity relationships, filters at the household level, and keeps all
518+
persons in matching households to preserve entity integrity.
519+
520+
Args:
521+
simulation: The microsimulation to filter.
522+
simulation_type: The type of simulation to create (e.g., Microsimulation).
523+
variable_name: The name of the variable to filter on. Must be a
524+
household-level variable.
525+
variable_value: The value to match. For string variables that may be
526+
stored as bytes in HDF5, both str and bytes versions are checked.
527+
reform: Optional reform to apply to the filtered simulation.
528+
529+
Returns:
530+
A new simulation containing only households where the variable matches.
531+
532+
Raises:
533+
ValueError: If the variable is not a household-level variable.
534+
"""
535+
validate_household_variable(
536+
simulation.tax_benefit_system, variable_name
537+
)
538+
539+
# Build entity relationships
540+
entity_rel = self._build_entity_relationships(simulation)
541+
542+
# Get household-level variable values
543+
hh_values = simulation.calculate(variable_name).values
544+
hh_ids = simulation.calculate("household_id").values
545+
546+
# Create mask for matching households, handling bytes encoding
547+
if isinstance(variable_value, str):
548+
hh_mask = (hh_values == variable_value) | (
549+
hh_values == variable_value.encode()
550+
)
551+
else:
552+
hh_mask = hh_values == variable_value
553+
554+
matching_hh_ids = set(hh_ids[hh_mask])
555+
556+
# Filter entity_rel to persons in matching households
557+
person_mask = entity_rel["household_id"].isin(matching_hh_ids)
558+
filtered_entity_rel = entity_rel[person_mask]
559+
560+
# Filter the input DataFrame using the filtered person indices
561+
df = simulation.to_input_dataframe()
562+
subset_df = df.iloc[filtered_entity_rel.index]
563+
564+
return simulation_type(dataset=subset_df, reform=reform)
565+
391566
def _filter_us_simulation_by_place(
392567
self,
393568
simulation: CountryMicrosimulation,
@@ -409,16 +584,14 @@ def _filter_us_simulation_by_place(
409584
from policyengine.utils.data.datasets import parse_us_place_region
410585

411586
_, place_fips_code = parse_us_place_region(region)
412-
df = simulation.to_input_dataframe()
413-
# Get place_fips at person level since to_input_dataframe() is person-level
414-
person_place_fips = simulation.calculate(
415-
"place_fips", map_to="person"
416-
).values
417-
# place_fips may be stored as bytes in HDF5; handle both str and bytes
418-
mask = (person_place_fips == place_fips_code) | (
419-
person_place_fips == place_fips_code.encode()
587+
588+
return self._filter_simulation_by_household_variable(
589+
simulation=simulation,
590+
simulation_type=simulation_type,
591+
variable_name="place_fips",
592+
variable_value=place_fips_code,
593+
reform=reform,
420594
)
421-
return simulation_type(dataset=df[mask], reform=reform)
422595

423596
def check_model_version(self) -> None:
424597
"""

0 commit comments

Comments
 (0)