Skip to content

Commit 3b405b7

Browse files
igerber and claude committed
Fix bootstrap RCS cohort-mass weighting, reset stale event-study VCV
- Bootstrap overall/event-study reaggregation now uses agg_weight (fixed cohort mass) for panel=False, matching the analytical aggregation path
- Reset self._event_study_vcov = None at start of fit() to prevent stale VCV from prior fit leaking into reused estimator objects

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b623dee commit 3b405b7

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

diff_diff/staggered.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1320,6 +1320,9 @@ def fit(
13201320
if not (0 < self.pscore_trim < 0.5):
13211321
raise ValueError(f"pscore_trim must be in (0, 0.5), got {self.pscore_trim}")
13221322

1323+
# Reset stale state from prior fit (prevents leaking event-study VCV)
1324+
self._event_study_vcov = None
1325+
13231326
# Normalize empty covariates list to None
13241327
if covariates is not None and len(covariates) == 0:
13251328
covariates = None

diff_diff/staggered_bootstrap.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -237,8 +237,14 @@ def _run_multiplier_bootstrap(
237237
_cohort_mass_cache[g] = float(np.sum(survey_w[unit_cohorts == g]))
238238
all_n_treated = np.array([_cohort_mass_cache[gt[0]] for gt in gt_pairs], dtype=float)
239239
else:
240+
# Use agg_weight if available (RCS: fixed cohort mass);
241+
# fall back to n_treated for panel data
240242
all_n_treated = np.array(
241-
[group_time_effects[gt]["n_treated"] for gt in gt_pairs], dtype=float
243+
[
244+
group_time_effects[gt].get("agg_weight", group_time_effects[gt]["n_treated"])
245+
for gt in gt_pairs
246+
],
247+
dtype=float,
242248
)
243249
post_n_treated = all_n_treated[post_treatment_mask]
244250

@@ -572,7 +578,10 @@ def _agg_weight(g: Any, t: Any) -> float:
572578
if g not in _cohort_mass:
573579
_cohort_mass[g] = float(np.sum(survey_w[unit_cohorts == g]))
574580
return _cohort_mass[g]
575-
return group_time_effects[(g, t)]["n_treated"]
581+
# Use agg_weight if available (RCS: fixed cohort mass)
582+
return group_time_effects[(g, t)].get(
583+
"agg_weight", group_time_effects[(g, t)]["n_treated"]
584+
)
576585

577586
# Organize by relative time
578587
effects_by_e: Dict[int, List[Tuple[int, float, float]]] = {}

0 commit comments

Comments
 (0)