4747SubsampleType = Optional [int ]
4848
4949
50+ # =============================================================================
51+ # Variable Validation Functions
52+ # =============================================================================
53+
54+
def get_variable(tax_benefit_system: Any, variable_name: str) -> Any:
    """Look up a variable in the tax-benefit system, raising if it is absent.

    Args:
        tax_benefit_system: Object exposing a ``variables`` mapping keyed by
            variable name.
        variable_name: The name of the variable to find.

    Returns:
        The entry stored under ``variable_name`` in
        ``tax_benefit_system.variables``.

    Raises:
        ValueError: If ``variable_name`` is not a key of
            ``tax_benefit_system.variables``.
    """
    variables = tax_benefit_system.variables
    if variable_name not in variables:
        # Quote the name tightly so the message shows exactly what was
        # requested (no padding inside the quotes).
        raise ValueError(
            f"Variable '{variable_name}' not found in tax-benefit system"
        )
    return variables[variable_name]
73+
74+
def validate_variable_entity(
    tax_benefit_system: Any,
    variable_name: str,
    expected_entity: str,
) -> None:
    """Check that a variable is defined on the expected entity.

    Args:
        tax_benefit_system: The tax-benefit system containing the variable.
        variable_name: The name of the variable to validate.
        expected_entity: The entity key the variable must have
            (e.g., "household", "person").

    Raises:
        ValueError: If the variable is not found (raised by ``get_variable``)
            or if its ``entity.key`` differs from ``expected_entity``.
    """
    variable = get_variable(tax_benefit_system, variable_name)
    actual_entity = variable.entity.key

    if actual_entity != expected_entity:
        # Keep the placeholders flush against the surrounding text so the
        # message reads "'name' is a person-level variable".
        raise ValueError(
            f"Variable '{variable_name}' is a {actual_entity}-level variable, "
            f"not a {expected_entity}-level variable."
        )
98+
99+
def validate_household_variable(
    tax_benefit_system: Any,
    variable_name: str,
) -> None:
    """Check that a variable is defined on the "household" entity.

    Args:
        tax_benefit_system: The tax-benefit system containing the variable.
        variable_name: The name of the variable to validate.

    Raises:
        ValueError: If the variable is not found (raised by ``get_variable``)
            or if its ``entity.key`` is not ``"household"``.
    """
    variable = get_variable(tax_benefit_system, variable_name)
    entity_key = variable.entity.key

    if entity_key != "household":
        # Only the first fragment interpolates values; the remaining
        # fragments are plain literals and need no f-prefix.
        raise ValueError(
            f"Variable '{variable_name}' is a {entity_key}-level variable, "
            "not a household-level variable. Only household-level variables can be "
            "used for filtering to preserve household integrity."
        )
121+
122+
50123class SimulationOptions (BaseModel ):
51124 country : CountryType = Field (..., description = "The country to simulate." )
52125 scope : ScopeType = Field (..., description = "The scope of the simulation." )
@@ -388,6 +461,108 @@ def _apply_us_region_to_simulation(
388461 )
389462 return simulation
390463
def _build_entity_relationships(
    self,
    simulation: CountryMicrosimulation,
) -> pd.DataFrame:
    """Map every person to the IDs of the group entities that contain them.

    The frame always carries ``person_id`` and ``household_id``. The columns
    ``tax_unit_id``, ``spm_unit_id``, ``family_id`` and ``marital_unit_id``
    are added only when the simulation's tax-benefit system defines a
    variable of that name.

    Args:
        simulation: The microsimulation to read entity IDs from.

    Returns:
        A DataFrame with one row per value of ``person_id`` and one column
        per entity ID, using pandas' default positional index.
    """
    # Collect columns first, then build the frame in a single step.
    columns = {
        "person_id": simulation.calculate("person_id").values,
        # Household membership is calculated unconditionally.
        "household_id": simulation.calculate(
            "household_id", map_to="person"
        ).values,
    }

    # Remaining group entities are only added when the model defines them.
    known_variables = simulation.tax_benefit_system.variables
    for id_column in (
        "tax_unit_id",
        "spm_unit_id",
        "family_id",
        "marital_unit_id",
    ):
        if id_column not in known_variables:
            continue
        columns[id_column] = simulation.calculate(
            id_column, map_to="person"
        ).values

    return pd.DataFrame(columns)
505+
def _filter_simulation_by_household_variable(
    self,
    simulation: CountryMicrosimulation,
    simulation_type: type,
    variable_name: str,
    variable_value: Any,
    reform: ReformType | None,
) -> CountrySimulation:
    """Build a new simulation restricted to households matching a value.

    Households whose ``variable_name`` equals ``variable_value`` are
    selected, and every person row belonging to one of those households is
    kept, so households are never split.

    Args:
        simulation: The microsimulation to filter.
        simulation_type: Callable used to build the result; invoked as
            ``simulation_type(dataset=..., reform=...)``.
        variable_name: Name of the variable to filter on. Must be a
            household-level variable.
        variable_value: The value to match. When it is a ``str``, values
            equal to its encoded ``bytes`` form also count as matches
            (to cover data that stores strings as bytes).
        reform: Optional reform passed through to the new simulation.

    Returns:
        The object returned by ``simulation_type`` for the filtered input
        DataFrame.

    Raises:
        ValueError: If the variable is missing or is not household-level
            (raised by ``validate_household_variable``).
    """
    validate_household_variable(simulation.tax_benefit_system, variable_name)

    # Person-to-entity map; its default index lines up with row positions.
    person_entities = self._build_entity_relationships(simulation)

    household_values = simulation.calculate(variable_name).values
    household_ids = simulation.calculate("household_id").values

    # Start from the direct comparison, then widen it for str targets so
    # that bytes-encoded values also match.
    matches = household_values == variable_value
    if isinstance(variable_value, str):
        matches = matches | (household_values == variable_value.encode())

    selected_households = set(household_ids[matches])

    # Keep every person whose household was selected.
    in_selected = person_entities["household_id"].isin(selected_households)
    kept_persons = person_entities[in_selected]

    # Slice the input frame by the kept persons' row positions.
    input_frame = simulation.to_input_dataframe()
    return simulation_type(
        dataset=input_frame.iloc[kept_persons.index],
        reform=reform,
    )
565+
391566 def _filter_us_simulation_by_place (
392567 self ,
393568 simulation : CountryMicrosimulation ,
@@ -409,16 +584,14 @@ def _filter_us_simulation_by_place(
409584 from policyengine .utils .data .datasets import parse_us_place_region
410585
411586 _ , place_fips_code = parse_us_place_region (region )
412- df = simulation .to_input_dataframe ()
413- # Get place_fips at person level since to_input_dataframe() is person-level
414- person_place_fips = simulation .calculate (
415- "place_fips" , map_to = "person"
416- ).values
417- # place_fips may be stored as bytes in HDF5; handle both str and bytes
418- mask = (person_place_fips == place_fips_code ) | (
419- person_place_fips == place_fips_code .encode ()
587+
588+ return self ._filter_simulation_by_household_variable (
589+ simulation = simulation ,
590+ simulation_type = simulation_type ,
591+ variable_name = "place_fips" ,
592+ variable_value = place_fips_code ,
593+ reform = reform ,
420594 )
421- return simulation_type (dataset = df [mask ], reform = reform )
422595
423596 def check_model_version (self ) -> None :
424597 """
0 commit comments