Skip to content

Commit 6c6494f

Browse files
igerberclaude
andcommitted
Validate replicate scale/rscales, combined_weights contract, fix df metadata
- Validate replicate_scale > 0 and replicate_rscales finite non-negative in SurveyDesign.__post_init__ - Validate combined_weights=True contract in resolve(): reject w_r > 0 where w_full == 0 (malformed design) - Fix CS IPW/DR path: pass survey df to safe_inference_batch - Fix ContinuousDiD/EfficientDiD: don't propagate df=0 sentinel to survey_metadata (keep as None for display) - Add TWFE, StackedDiD rejection tests + scale/rscales validation tests - Update survey-roadmap.md: CS now has full survey support, accurate replicate limitation descriptions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 150c9c0 commit 6c6494f

6 files changed

Lines changed: 105 additions & 10 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,10 @@ def fit(
531531
raw_w_unit = _unit_resolved.weights
532532
survey_metadata = compute_survey_metadata(_unit_resolved, raw_w_unit)
533533

534-
# Propagate replicate df override to survey_metadata for display consistency
535-
if _survey_df is not None and survey_metadata is not None:
534+
# Propagate replicate df override to survey_metadata for display
535+
# (but not the df=0 sentinel — keep metadata as None for undefined df)
536+
if (_survey_df is not None and _survey_df != 0
537+
and survey_metadata is not None):
536538
if survey_metadata.df_survey != _survey_df:
537539
survey_metadata.df_survey = _survey_df
538540

diff_diff/efficient_did.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,9 @@ def _recompute_unit_survey_metadata(self, panel_metadata):
10811081
self._unit_resolved_survey.weights,
10821082
)
10831083
# Propagate effective replicate df if available
1084-
if self._survey_df is not None and meta.df_survey != self._survey_df:
1084+
# (but not the df=0 sentinel — keep metadata as None for undefined df)
1085+
if (self._survey_df is not None and self._survey_df != 0
1086+
and meta.df_survey != self._survey_df):
10851087
meta.df_survey = self._survey_df
10861088
return meta
10871089
return panel_metadata

diff_diff/staggered.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1214,8 +1214,16 @@ def _compute_all_att_gt_covariate_reg(
12141214

12151215
# Batch inference
12161216
if task_keys:
1217+
# Use survey df for replicate designs (propagated from precomputed)
1218+
_ipw_dr_df = precomputed.get("df_survey") if precomputed is not None else None
1219+
# Guard: replicate design with undefined df → NaN inference
1220+
if (_ipw_dr_df is None and precomputed is not None
1221+
and precomputed.get("resolved_survey_unit") is not None
1222+
and hasattr(precomputed["resolved_survey_unit"], 'uses_replicate_variance')
1223+
and precomputed["resolved_survey_unit"].uses_replicate_variance):
1224+
_ipw_dr_df = 0
12171225
t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch(
1218-
np.array(atts), np.array(ses), alpha=self.alpha
1226+
np.array(atts), np.array(ses), alpha=self.alpha, df=_ipw_dr_df
12191227
)
12201228
for idx, key in enumerate(task_keys):
12211229
group_time_effects[key]["t_stat"] = float(t_stats[idx])

diff_diff/survey.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,24 @@ def __post_init__(self):
131131
f"replicate_strata length ({len(self.replicate_strata)}) must "
132132
f"match replicate_weights length ({len(self.replicate_weights)})"
133133
)
134-
# Validate rscales length
134+
# Validate scale/rscales values and length
135+
if self.replicate_scale is not None:
136+
if not (np.isfinite(self.replicate_scale) and self.replicate_scale > 0):
137+
raise ValueError(
138+
f"replicate_scale must be a positive finite number, "
139+
f"got {self.replicate_scale}"
140+
)
135141
if self.replicate_rscales is not None and self.replicate_weights is not None:
136142
if len(self.replicate_rscales) != len(self.replicate_weights):
137143
raise ValueError(
138144
f"replicate_rscales length ({len(self.replicate_rscales)}) must "
139145
f"match replicate_weights length ({len(self.replicate_weights)})"
140146
)
147+
rscales_arr = np.asarray(self.replicate_rscales, dtype=float)
148+
if not np.all(np.isfinite(rscales_arr)):
149+
raise ValueError("replicate_rscales must be finite")
150+
if np.any(rscales_arr < 0):
151+
raise ValueError("replicate_rscales must be non-negative")
141152

142153
def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
143154
"""
@@ -214,6 +225,26 @@ def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
214225
raise ValueError("Replicate weights contain Inf values")
215226
if np.any(rep_arr < 0):
216227
raise ValueError("Replicate weights must be non-negative")
228+
# Validate combined_weights contract: when True, replicate columns
229+
# include the full-sample weight, so w_r > 0 with w_full == 0 is
230+
# malformed (observation excluded from full sample but included in
231+
# a replicate).
232+
combined = (
233+
self.combined_weights
234+
if self.combined_weights is not None
235+
else True
236+
)
237+
if combined:
238+
zero_full = weights == 0
239+
if np.any(zero_full):
240+
rep_positive_on_zero = np.any(rep_arr[zero_full] > 0, axis=1)
241+
if np.any(rep_positive_on_zero):
242+
raise ValueError(
243+
"Malformed combined_weights=True design: some "
244+
"replicate columns have positive weight where "
245+
"full-sample weight is zero. Either fix the "
246+
"replicate columns or use combined_weights=False."
247+
)
217248
# Do NOT normalize replicate columns — the IF path uses w_r/w_full
218249
# ratios that must reflect the true replicate design, not rescaled sums
219250
n_rep = rep_arr.shape[1]

docs/survey-roadmap.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ message pointing to the planned phase or describing the limitation.
4646
|-----------|------|----------------|-------|
4747
| ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred |
4848
| TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred |
49-
| CallawaySantAnna | `staggered.py` | Weights-only | Weights-only SurveyDesign (strata/PSU/FPC rejected); reg supports covariates, IPW/DR no-covariate only; survey-weighted WIF in aggregation; full design SEs, covariates+IPW/DR, and bootstrap+survey deferred |
49+
| CallawaySantAnna | `staggered.py` | Full | Full SurveyDesign (strata/PSU/FPC/replicate weights); reg supports covariates, IPW/DR no-covariate only; survey-weighted WIF in aggregation; replicate IF variance for analytical SEs |
5050

5151
**Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights
5252
enter the IRLS working weights as `w_survey * mu * (1 - mu)`. This also unblocked
@@ -100,10 +100,11 @@ JKn requires explicit `replicate_strata` (per-replicate stratum assignment).
100100
- Dispatch in `LinearRegression.fit()` and `staggered_aggregation.py`
101101
- Replicate weights mutually exclusive with strata/PSU/FPC
102102
- Survey df = rank(replicate_weights) - 1, matching R's `survey::degf()`
103-
- **Limitations**: SunAbraham rejects replicate-weight designs (weighted
104-
within-transformation must be recomputed per replicate — not yet implemented).
105-
ContinuousDiD and EfficientDiD reject replicate weights + `n_bootstrap > 0`
106-
(replicate variance is analytical, not bootstrap-compatible).
103+
- **Limitations**: Supported in CallawaySantAnna, ContinuousDiD, EfficientDiD,
104+
TripleDifference (analytical only, no bootstrap). Rejected with
105+
`NotImplementedError` in DifferenceInDifferences, TwoWayFixedEffects,
106+
MultiPeriodDiD, StackedDiD, SunAbraham, ImputationDiD, TwoStageDiD,
107+
SyntheticDiD, TROP.
107108

108109
### DEFF Diagnostics ✅ (2026-03-26)
109110
Per-coefficient design effects comparing survey vcov to SRS (HC1) vcov.

tests/test_survey_phase6.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,57 @@ def test_callaway_santanna_replicate_bootstrap_rejected(self):
13571357
survey_design=sd,
13581358
)
13591359

1360+
def test_twfe_replicate_rejected(self):
1361+
"""TwoWayFixedEffects should reject replicate-weight designs."""
1362+
from diff_diff.twfe import TwoWayFixedEffects
1363+
1364+
data, rep_cols = TestEstimatorReplicateWeights._make_staggered_replicate_data()
1365+
sd = SurveyDesign(
1366+
weights="weight", replicate_weights=rep_cols,
1367+
replicate_method="JK1",
1368+
)
1369+
with pytest.raises(NotImplementedError, match="TwoWayFixedEffects"):
1370+
TwoWayFixedEffects().fit(
1371+
data, outcome="outcome", treatment="first_treat",
1372+
unit="unit", time="time", survey_design=sd,
1373+
)
1374+
1375+
def test_stacked_did_replicate_rejected(self):
1376+
"""StackedDiD should reject replicate-weight designs."""
1377+
from diff_diff import StackedDiD
1378+
1379+
data, rep_cols = TestEstimatorReplicateWeights._make_staggered_replicate_data()
1380+
sd = SurveyDesign(
1381+
weights="weight", replicate_weights=rep_cols,
1382+
replicate_method="JK1",
1383+
)
1384+
with pytest.raises(NotImplementedError, match="StackedDiD"):
1385+
StackedDiD().fit(
1386+
data, outcome="outcome", unit="unit", time="time",
1387+
first_treat="first_treat", survey_design=sd,
1388+
)
1389+
1390+
def test_invalid_replicate_scale_rejected(self):
1391+
"""Negative or zero replicate_scale should be rejected."""
1392+
with pytest.raises(ValueError, match="positive finite"):
1393+
SurveyDesign(
1394+
weights="w", replicate_weights=["r1", "r2"],
1395+
replicate_method="JK1", replicate_scale=-1.0,
1396+
)
1397+
with pytest.raises(ValueError, match="positive finite"):
1398+
SurveyDesign(
1399+
weights="w", replicate_weights=["r1", "r2"],
1400+
replicate_method="JK1", replicate_scale=0.0,
1401+
)
1402+
1403+
def test_invalid_replicate_rscales_rejected(self):
1404+
"""Negative replicate_rscales should be rejected."""
1405+
with pytest.raises(ValueError, match="non-negative"):
1406+
SurveyDesign(
1407+
weights="w", replicate_weights=["r1", "r2"],
1408+
replicate_method="JK1", replicate_rscales=[-1.0, 1.0],
1409+
)
1410+
13601411

13611412
# =============================================================================
13621413
# Effective-sample and d.f. consistency tests

0 commit comments

Comments
 (0)