Skip to content

Commit 1be7d1e

Browse files
igerberclaude
andcommitted
Address AI review P1/P2 findings for survey Phase 5
- Return composed ω_eff (not raw ω) in SyntheticDiDResults.unit_weights so returned weights match the estimator actually used under survey - Add NaN finite guard in TROP local _fit_with_fixed_lambda() and Rust bootstrap to skip non-finite treated outcomes (match main fit contract) - Add finite guard on bootstrap ATT accumulator - Add regression tests for effective weight semantics and NaN bootstrap Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d23b8c3 commit 1be7d1e

4 files changed

Lines changed: 89 additions & 4 deletions

File tree

diff_diff/synthetic_did.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,10 @@ def fit( # type: ignore[override]
508508
else:
509509
p_value = p_value_analytical
510510

511-
# Create weight dictionaries (store original ω, not composed)
512-
unit_weights_dict = {unit_id: w for unit_id, w in zip(control_units, unit_weights)}
511+
# Create weight dictionaries. When survey weights are active, store
512+
# the effective (composed) weights that were actually used for the ATT
513+
# so that results.unit_weights matches the estimator.
514+
unit_weights_dict = {unit_id: w for unit_id, w in zip(control_units, omega_eff)}
513515
time_weights_dict = {period: w for period, w in zip(pre_periods, time_weights)}
514516

515517
# Store results

diff_diff/trop_local.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,8 @@ def _bootstrap_variance(
949949
optimal_lambda,
950950
survey_design=survey_design,
951951
)
952-
bootstrap_estimates_list.append(att)
952+
if np.isfinite(att):
953+
bootstrap_estimates_list.append(att)
953954
except (ValueError, np.linalg.LinAlgError, KeyError):
954955
continue
955956

@@ -1032,6 +1033,10 @@ def _fit_with_fixed_lambda(
10321033
tau_values = []
10331034
tau_weights = []
10341035
for t, i in treated_observations:
1036+
# Skip non-finite outcomes (match main fit NaN contract)
1037+
if not np.isfinite(Y[t, i]):
1038+
continue
1039+
10351040
# Compute observation-specific weights for this (i, t)
10361041
weight_matrix = self._compute_observation_weights(
10371042
Y, D, i, t, lambda_time, lambda_unit, control_unit_idx, n_units, n_periods
@@ -1048,6 +1053,8 @@ def _fit_with_fixed_lambda(
10481053
if local_weight_arr is not None:
10491054
tau_weights.append(local_weight_arr[i])
10501055

1051-
if local_weight_arr is not None and tau_values:
1056+
if not tau_values:
1057+
return float("nan")
1058+
if local_weight_arr is not None:
10521059
return float(np.average(tau_values, weights=tau_weights))
10531060
return float(np.mean(tau_values))

rust/src/trop.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,11 @@ pub fn bootstrap_trop_variance<'py>(
10351035
let mut tau_count = 0usize;
10361036

10371037
for (t, i) in boot_treated {
1038+
// Skip non-finite outcomes (match main fit NaN contract)
1039+
if !y_boot[[t, i]].is_finite() {
1040+
continue;
1041+
}
1042+
10381043
let weight_matrix = compute_weight_matrix(
10391044
&y_boot.view(),
10401045
&d_boot.view(),

tests/test_survey_phase5.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,39 @@ def test_covariates_with_survey(self, sdid_survey_data, survey_design_weights):
350350
assert np.isfinite(result.att)
351351
assert result.survey_metadata is not None
352352

353+
def test_effective_weights_returned(self, sdid_survey_data, survey_design_weights):
354+
"""unit_weights returns composed ω_eff (not raw ω) under survey weighting."""
355+
est = SyntheticDiD(variance_method="placebo", n_bootstrap=50, seed=42)
356+
result = est.fit(
357+
sdid_survey_data,
358+
outcome="outcome",
359+
treatment="treated",
360+
unit="unit",
361+
time="time",
362+
post_periods=[6, 7, 8, 9],
363+
survey_design=survey_design_weights,
364+
)
365+
weights = result.unit_weights
366+
# Effective weights should sum to 1 (renormalized)
367+
assert sum(weights.values()) == pytest.approx(1.0, abs=1e-10)
368+
# With non-uniform survey weights, effective weights should differ
369+
# from what uniform survey weights would produce
370+
sdid_survey_data_u = sdid_survey_data.copy()
371+
sdid_survey_data_u["uniform_w"] = 1.0
372+
result_u = est.fit(
373+
sdid_survey_data_u,
374+
outcome="outcome",
375+
treatment="treated",
376+
unit="unit",
377+
time="time",
378+
post_periods=[6, 7, 8, 9],
379+
survey_design=SurveyDesign(weights="uniform_w"),
380+
)
381+
# Non-uniform weights should change the returned weight distribution
382+
eff_vals = sorted(weights.values(), reverse=True)
383+
uni_vals = sorted(result_u.unit_weights.values(), reverse=True)
384+
assert eff_vals != pytest.approx(uni_vals, abs=1e-6)
385+
353386

354387
# =============================================================================
355388
# TROP Survey Tests
@@ -577,3 +610,41 @@ def test_to_dict_includes_survey(self, trop_survey_data, survey_design_weights):
577610
d = result.to_dict()
578611
assert "weight_type" in d
579612
assert d["weight_type"] == "pweight"
613+
614+
def test_local_bootstrap_nan_treated_outcomes(self, trop_survey_data):
615+
"""Bootstrap handles NaN treated outcomes without poisoning SE."""
616+
trop_survey_data = trop_survey_data.copy()
617+
# Set some treated post-treatment outcomes to NaN
618+
mask = (trop_survey_data["D"] == 1) & (trop_survey_data["time"] == 7)
619+
trop_survey_data.loc[mask, "outcome"] = np.nan
620+
621+
est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5)
622+
result = est.fit(
623+
trop_survey_data,
624+
outcome="outcome",
625+
treatment="D",
626+
unit="unit",
627+
time="time",
628+
)
629+
# Point estimate should use finite cells only
630+
assert np.isfinite(result.att)
631+
# SE should remain finite (not poisoned by NaN)
632+
assert np.isfinite(result.se)
633+
634+
def test_local_bootstrap_nan_with_survey(self, trop_survey_data, survey_design_weights):
635+
"""Bootstrap + survey handles NaN treated outcomes correctly."""
636+
trop_survey_data = trop_survey_data.copy()
637+
mask = (trop_survey_data["D"] == 1) & (trop_survey_data["time"] == 8)
638+
trop_survey_data.loc[mask, "outcome"] = np.nan
639+
640+
est = TROP(method="local", n_bootstrap=10, seed=42, max_iter=5)
641+
result = est.fit(
642+
trop_survey_data,
643+
outcome="outcome",
644+
treatment="D",
645+
unit="unit",
646+
time="time",
647+
survey_design=survey_design_weights,
648+
)
649+
assert np.isfinite(result.att)
650+
assert np.isfinite(result.se)

0 commit comments

Comments
 (0)