Skip to content

Commit 09ac0da

Browse files
hmgaudeckerclaude
andcommitted
Fix FixedConstraintWithValue.loc type and test expectations
Remove list from loc type union, convert callers to tuple(). Update anchoring test expectations from list to tuple. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bf4faa0 commit 09ac0da

4 files changed

Lines changed: 35 additions & 24 deletions

File tree

CLAUDE.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ simulate_dataset() → Simulate states (with optional policy effects)
9393
computation using Kalman filtering. The debug variant is not jitted and returns
9494
intermediate results (residuals, contributions, filtered states).
9595
- **constraints.py**: Generates parameter constraints (bounds, equalities from stagemap,
96-
fixed values) for optimization. Exports `get_constraints()`, `get_constraints()`,
96+
fixed values) for optimization. Exports `get_constraints()`,
9797
`enforce_fixed_constraints()`, `add_bounds()`, `FixedConstraintWithValue`.
9898
- **parse_params.py**: Converts flat parameter vectors to structured model parameters.
9999
Exports `create_parsing_info()` and `parse_params()`.
@@ -196,7 +196,7 @@ These are not in `__all__` but are imported directly by application projects:
196196
- `skillmodels.types.ProcessedModel`, `EndogenousFactorsInfo`
197197
- `skillmodels.decorators.register_params` — essential for custom transition functions
198198
- `skillmodels.constraints.get_constraints`, `enforce_fixed_constraints`,
199-
`FixedConstraintWithValue`, `_sel`
199+
`FixedConstraintWithValue`, `select_by_loc`
200200
- `skillmodels.utilities.extract_factors`, `update_parameter_values`
201201
- `skillmodels.process_data.pre_process_data`
202202
- `skillmodels.correlation_heatmap.get_measurements_corr`, `get_quasi_scores_corr`,

src/skillmodels/constraints.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class FixedConstraintWithValue(om.FixedConstraint):
3535
DataFrame) and `value` (the value to set before optimization).
3636
"""
3737

38-
loc: pd.MultiIndex | list | tuple | str | None = None
38+
loc: pd.MultiIndex | tuple | str | None = None
3939
"""Parameter location in the params DataFrame."""
4040
value: float | None = None
4141
"""Value to enforce on the parameter."""
@@ -202,7 +202,9 @@ def _get_mixture_weights_constraints(
202202
value=1.0,
203203
),
204204
]
205-
return [om.ProbabilityConstraint(selector=functools.partial(select_by_loc, loc=loc))]
205+
return [
206+
om.ProbabilityConstraint(selector=functools.partial(select_by_loc, loc=loc))
207+
]
206208

207209

208210
def _get_stage_constraints(
@@ -231,12 +233,16 @@ def _get_stage_constraints(
231233
loc_q = [("shock_sds", p) for p in stage_periods]
232234
constraints.append(
233235
om.PairwiseEqualityConstraint(
234-
selectors=[functools.partial(select_by_loc, loc=loc) for loc in loc_trans],
236+
selectors=[
237+
functools.partial(select_by_loc, loc=loc) for loc in loc_trans
238+
],
235239
),
236240
)
237241
constraints.append(
238242
om.PairwiseEqualityConstraint(
239-
selectors=[functools.partial(select_by_loc, loc=loc) for loc in loc_q],
243+
selectors=[
244+
functools.partial(select_by_loc, loc=loc) for loc in loc_q
245+
],
240246
),
241247
)
242248

@@ -292,7 +298,9 @@ def _get_initial_states_constraints(
292298
("initial_states", 0, f"mixture_{emf}", factors[0])
293299
for emf in range(n_mixtures)
294300
]
295-
return [om.IncreasingConstraint(selector=functools.partial(select_by_loc, loc=locs))]
301+
return [
302+
om.IncreasingConstraint(selector=functools.partial(select_by_loc, loc=locs))
303+
]
296304
return []
297305

298306

@@ -352,10 +360,11 @@ def _get_anchoring_constraints( # noqa: C901
352360
for period, meas in anchoring_updates:
353361
locs.append(("controls", period, meas, "constant"))
354362
if locs:
363+
loc = tuple(locs)
355364
constraints.append(
356365
FixedConstraintWithValue(
357-
selector=functools.partial(select_by_loc, loc=locs),
358-
loc=locs,
366+
selector=functools.partial(select_by_loc, loc=loc),
367+
loc=loc,
359368
value=0,
360369
),
361370
)
@@ -366,10 +375,11 @@ def _get_anchoring_constraints( # noqa: C901
366375
for cont in [c for c in controls if c != "constant"]:
367376
ind_tups.append(("controls", period, meas, cont))
368377
if ind_tups:
378+
loc = tuple(ind_tups)
369379
constraints.append(
370380
FixedConstraintWithValue(
371-
selector=functools.partial(select_by_loc, loc=ind_tups),
372-
loc=ind_tups,
381+
selector=functools.partial(select_by_loc, loc=loc),
382+
loc=loc,
373383
value=0,
374384
),
375385
)
@@ -383,10 +393,11 @@ def _get_anchoring_constraints( # noqa: C901
383393
ind_tups.append(("loadings", period, meas, factor))
384394

385395
if ind_tups:
396+
loc = tuple(ind_tups)
386397
constraints.append(
387398
FixedConstraintWithValue(
388-
selector=functools.partial(select_by_loc, loc=ind_tups),
389-
loc=ind_tups,
399+
selector=functools.partial(select_by_loc, loc=loc),
400+
loc=loc,
390401
value=1,
391402
),
392403
)

src/skillmodels/transition_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from jax import Array
3838

3939

40-
def _sel(params: Any, loc: Any) -> Any: # noqa: ANN401
40+
def select_by_loc(params: Any, loc: Any) -> Any: # noqa: ANN401
4141
"""Select parameters by location."""
4242
return params.loc[loc]
4343

@@ -68,7 +68,7 @@ def identity_constraints_linear(
6868
loc = ("transition", aug_period, factor, regressor)
6969
constraints.append(
7070
FixedConstraintWithValue(
71-
selector=functools.partial(_sel, loc=loc),
71+
selector=functools.partial(select_by_loc, loc=loc),
7272
loc=loc,
7373
value=val,
7474
)
@@ -122,7 +122,7 @@ def identity_constraints_translog(
122122
loc = ("transition", aug_period, factor, regressor)
123123
constraints.append(
124124
FixedConstraintWithValue(
125-
selector=functools.partial(_sel, loc=loc),
125+
selector=functools.partial(select_by_loc, loc=loc),
126126
loc=loc,
127127
value=val,
128128
)
@@ -158,7 +158,7 @@ def constraints_log_ces(
158158
"""Constraints for log_ces production function."""
159159
names = params_log_ces(factors)
160160
loc = [("transition", aug_period, factor, name) for name in names[:-1]]
161-
return om.ProbabilityConstraint(selector=functools.partial(_sel, loc=loc))
161+
return om.ProbabilityConstraint(selector=functools.partial(select_by_loc, loc=loc))
162162

163163

164164
def identity_constraints_log_ces(
@@ -244,7 +244,7 @@ def identity_constraints_linear_and_squares(
244244
loc = ("transition", aug_period, factor, regressor)
245245
constraints.append(
246246
FixedConstraintWithValue(
247-
selector=functools.partial(_sel, loc=loc),
247+
selector=functools.partial(select_by_loc, loc=loc),
248248
loc=loc,
249249
value=val,
250250
)

tests/test_constraints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,12 +334,12 @@ def test_anchoring_constraints_for_constants(anch_uinfo) -> None:
334334

335335
expected = [
336336
{
337-
"loc": [
337+
"loc": (
338338
("controls", 0, "outcome_f1", "constant"),
339339
("controls", 0, "outcome_f2", "constant"),
340340
("controls", 1, "outcome_f1", "constant"),
341341
("controls", 1, "outcome_f2", "constant"),
342-
],
342+
),
343343
"type": "fixed",
344344
"value": 0,
345345
},
@@ -368,7 +368,7 @@ def test_anchoring_constraints_for_controls(anch_uinfo) -> None:
368368

369369
expected = [
370370
{
371-
"loc": [
371+
"loc": (
372372
("controls", 0, "outcome_f1", "c1"),
373373
("controls", 0, "outcome_f1", "c2"),
374374
("controls", 0, "outcome_f2", "c1"),
@@ -377,7 +377,7 @@ def test_anchoring_constraints_for_controls(anch_uinfo) -> None:
377377
("controls", 1, "outcome_f1", "c2"),
378378
("controls", 1, "outcome_f2", "c1"),
379379
("controls", 1, "outcome_f2", "c2"),
380-
],
380+
),
381381
"type": "fixed",
382382
"value": 0,
383383
},
@@ -401,12 +401,12 @@ def test_anchoring_constraints_for_loadings(anch_uinfo) -> None:
401401

402402
expected = [
403403
{
404-
"loc": [
404+
"loc": (
405405
("loadings", 0, "outcome_f1", "f1"),
406406
("loadings", 0, "outcome_f2", "f2"),
407407
("loadings", 1, "outcome_f1", "f1"),
408408
("loadings", 1, "outcome_f2", "f2"),
409-
],
409+
),
410410
"type": "fixed",
411411
"value": 1,
412412
},

0 commit comments

Comments
 (0)