Skip to content

Commit eac680e

Browse files
igerberclaude
andcommitted
Match R's H/n, asy_rep/n, colMeans convention for panel PS corrections; fix VCV index subsetting
Panel IPW/DR PS corrections: restructure to match R's std_ipw_did_panel / drdid_panel convention: H = X'WX/n, asy_lin_rep = score @ solve(H) / n, M2 = colMeans(). Algebraically equivalent but mirrors R source literally. HonestDiD VCV subsetting: store event_study_vcov_index (the exact event-time ordering matching VCV columns) so subsetting works correctly even when universal base period injects a reference row into event_study_effects. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7e127fb commit eac680e

4 files changed

Lines changed: 47 additions & 40 deletions

File tree

diff_diff/honest_did.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -665,15 +665,21 @@ def _extract_event_study_params(
665665
# otherwise fall back to diagonal from SEs
666666
if hasattr(results, "event_study_vcov") and results.event_study_vcov is not None:
667667
vcov = results.event_study_vcov
668-
# VCV is indexed by ALL event times from aggregation;
669-
# rel_times may be a filtered subset (NaN-SE times dropped).
670-
# Subset VCV to match the surviving rel_times.
671-
all_event_times = sorted(results.event_study_effects.keys())
672-
if vcov.shape[0] == len(all_event_times) and len(rel_times) < len(all_event_times):
673-
idx = [all_event_times.index(t) for t in rel_times]
674-
sigma = vcov[np.ix_(idx, idx)]
675-
else:
668+
# VCV is indexed by the aggregated event times (stored in
669+
# event_study_vcov_index), NOT by event_study_effects keys
670+
# (which may include an injected reference period).
671+
# Subset to match the surviving rel_times.
672+
vcov_index = getattr(results, "event_study_vcov_index", None)
673+
if vcov_index is not None and len(rel_times) < len(vcov_index):
674+
idx = [vcov_index.index(t) for t in rel_times if t in vcov_index]
675+
if len(idx) == len(rel_times):
676+
sigma = vcov[np.ix_(idx, idx)]
677+
else:
678+
sigma = np.diag(np.array(ses) ** 2)
679+
elif vcov.shape[0] == len(rel_times):
676680
sigma = vcov
681+
else:
682+
sigma = np.diag(np.array(ses) ** 2)
677683
else:
678684
sigma = np.diag(np.array(ses) ** 2)
679685

diff_diff/staggered.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,8 +1775,10 @@ def fit(
17751775
# Clear it when bootstrap overwrites event-study SEs to prevent
17761776
# HonestDiD from mixing analytical VCV with bootstrap SEs.
17771777
event_study_vcov = getattr(self, "_event_study_vcov", None)
1778+
event_study_vcov_index = getattr(self, "_event_study_vcov_index", None)
17781779
if bootstrap_results is not None and event_study_vcov is not None:
17791780
event_study_vcov = None
1781+
event_study_vcov_index = None
17801782

17811783
self.results_ = CallawaySantAnnaResults(
17821784
group_time_effects=group_time_effects,
@@ -1800,6 +1802,7 @@ def fit(
18001802
pscore_trim=self.pscore_trim,
18011803
survey_metadata=survey_metadata,
18021804
event_study_vcov=event_study_vcov,
1805+
event_study_vcov_index=event_study_vcov_index,
18031806
panel=self.panel,
18041807
)
18051808

@@ -2032,35 +2035,29 @@ def _ipw_estimation(
20322035
X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
20332036
pscore_all = np.concatenate([pscore_treated, pscore_control])
20342037

2035-
# Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i')
2038+
# PS IF correction — matches R's std_ipw_did_panel convention:
2039+
# H = X'WX / n, asy_lin_rep = score @ solve(H) / n, M2 = colMeans
2040+
n_all_panel = n_t + n_c
20362041
W_ps = pscore_all * (1 - pscore_all)
20372042
if sw_all is not None:
20382043
W_ps = W_ps * sw_all
2039-
H = X_all_int.T @ (W_ps[:, None] * X_all_int)
2040-
try:
2041-
H_inv = np.linalg.solve(H, np.eye(H.shape[0]))
2042-
except np.linalg.LinAlgError:
2043-
H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0]
2044+
H = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
2045+
H_inv = _safe_inv(H)
20442046

2045-
# PS score: w_i * (D_i - pi_i) * X_i
20462047
D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
20472048
score_ps = (D_all - pscore_all)[:, None] * X_all_int
20482049
if sw_all is not None:
20492050
score_ps = score_ps * sw_all[:, None]
2050-
asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p)
2051+
asy_lin_rep_ps = score_ps @ H_inv / n_all_panel
20512052

2052-
# M2: gradient of ATT w.r.t. PS parameters
2053-
# R convention: colMeans over ALL n obs (zero for treated rows)
20542053
att_control_weighted = np.sum(weights_control_norm * control_change)
2055-
M2 = np.sum(
2054+
M2 = np.mean(
20562055
(weights_control_norm * (control_change - att_control_weighted))[:, None]
20572056
* X_all_int[n_t:],
20582057
axis=0,
2059-
) / (n_t + n_c)
2058+
)
20602059

2061-
# PS correction to influence function
2062-
inf_ps_correction = asy_lin_rep_ps @ M2
2063-
inf_func = inf_func + inf_ps_correction
2060+
inf_func = inf_func + asy_lin_rep_ps @ M2
20642061

20652062
# SE from influence function variance
20662063
var_psi = np.sum(inf_func**2)
@@ -2295,29 +2292,26 @@ def _doubly_robust(
22952292
)
22962293
pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
22972294

2298-
# Survey-weighted PS Hessian
2295+
# PS IF correction — R convention: H/n, asy_rep/n, colMeans
2296+
n_all_panel = n_t + n_c
22992297
W_ps = pscore_all * (1 - pscore_all)
23002298
if sw_all is not None:
23012299
W_ps = W_ps * sw_all
2302-
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
2300+
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
23032301
H_ps_inv = _safe_inv(H_ps)
23042302

2305-
# PS score
23062303
D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
23072304
score_ps = (D_all - pscore_all)[:, None] * X_all_int
23082305
if sw_all is not None:
23092306
score_ps = score_ps * sw_all[:, None]
2310-
asy_lin_rep_ps = score_ps @ H_ps_inv # (n_t+n_c, p+1)
2307+
asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel
23112308

2312-
# M2_dr: dATT/dgamma — gradient of DR ATT w.r.t. PS parameters
2313-
# Only the control augmentation term depends on PS via w_ipw
2314-
# R convention: colMeans over ALL n obs (zero for treated rows)
23152309
dr_resid_control = m_control - control_change
2316-
M2_dr = np.sum(
2310+
M2_dr = np.mean(
23172311
((weights_control / sw_t_sum) * dr_resid_control)[:, None]
23182312
* X_all_int[n_t:],
23192313
axis=0,
2320-
) / (n_t + n_c)
2314+
)
23212315
inf_func = inf_func + asy_lin_rep_ps @ M2_dr
23222316

23232317
# --- OR IF correction ---
@@ -2358,27 +2352,27 @@ def _doubly_robust(
23582352
inf_func = np.concatenate([psi_treated, psi_control])
23592353

23602354
if X_treated is not None and X_control is not None and X_treated.shape[1] > 0:
2361-
# --- PS IF correction ---
2362-
X_all_int = np.column_stack([np.ones(n_t + n_c), X_all])
2355+
# --- PS IF correction — R convention: H/n, asy_rep/n, colMeans ---
2356+
n_all_panel = n_t + n_c
2357+
X_all_int = np.column_stack([np.ones(n_all_panel), X_all])
23632358
pscore_treated_clipped = np.clip(
23642359
pscore[:n_t], self.pscore_trim, 1 - self.pscore_trim
23652360
)
23662361
pscore_all = np.concatenate([pscore_treated_clipped, pscore_control])
23672362

23682363
W_ps = pscore_all * (1 - pscore_all)
2369-
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int)
2364+
H_ps = X_all_int.T @ (W_ps[:, None] * X_all_int) / n_all_panel
23702365
H_ps_inv = _safe_inv(H_ps)
23712366

23722367
D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)])
23732368
score_ps = (D_all - pscore_all)[:, None] * X_all_int
2374-
asy_lin_rep_ps = score_ps @ H_ps_inv
2369+
asy_lin_rep_ps = score_ps @ H_ps_inv / n_all_panel
23752370

2376-
# R convention: colMeans over ALL n obs (zero for treated rows)
23772371
dr_resid_control = m_control - control_change
2378-
M2_dr = np.sum(
2372+
M2_dr = np.mean(
23792373
((weights_control / n_t) * dr_resid_control)[:, None] * X_all_int[n_t:],
23802374
axis=0,
2381-
) / (n_t + n_c)
2375+
)
23822376
inf_func = inf_func + asy_lin_rep_ps @ M2_dr
23832377

23842378
# --- OR IF correction ---

diff_diff/staggered_aggregation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,12 @@ def _aggregate_event_study(
751751
except (ValueError, np.linalg.LinAlgError):
752752
pass # Fall back to diagonal (None)
753753

754+
# Store the event-time index that matches VCV columns (for subsetting
755+
# in HonestDiD when some event times are filtered out)
756+
self._event_study_vcov_index = (
757+
[e for e, _ in sorted_periods] if event_study_vcov is not None else None
758+
)
759+
754760
# Attach VCV to self for CallawaySantAnna to pick up
755761
self._event_study_vcov = event_study_vcov
756762

diff_diff/staggered_results.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,9 @@ class CallawaySantAnnaResults:
115115
event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
116116
group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
117117
influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
118-
# Full event-study VCV matrix (Phase 7d): indexed by sorted relative times
118+
# Full event-study VCV matrix (Phase 7d): indexed by event_study_vcov_index
119119
event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False)
120+
event_study_vcov_index: Optional[list] = field(default=None, repr=False)
120121
bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False)
121122
cband_crit_value: Optional[float] = None
122123
pscore_trim: float = 0.01

0 commit comments

Comments
 (0)