@@ -35,7 +35,7 @@ class FixedConstraintWithValue(om.FixedConstraint):
3535 DataFrame) and `value` (the value to set before optimization).
3636 """
3737
38- loc : pd .MultiIndex | list | tuple | str | None = None
38+ loc : pd .MultiIndex | tuple | str | None = None
3939 """Parameter location in the params DataFrame."""
4040 value : float | None = None
4141 """Value to enforce on the parameter."""
@@ -202,7 +202,9 @@ def _get_mixture_weights_constraints(
202202 value = 1.0 ,
203203 ),
204204 ]
205- return [om .ProbabilityConstraint (selector = functools .partial (select_by_loc , loc = loc ))]
205+ return [
206+ om .ProbabilityConstraint (selector = functools .partial (select_by_loc , loc = loc ))
207+ ]
206208
207209
208210def _get_stage_constraints (
@@ -231,12 +233,16 @@ def _get_stage_constraints(
231233 loc_q = [("shock_sds" , p ) for p in stage_periods ]
232234 constraints .append (
233235 om .PairwiseEqualityConstraint (
234- selectors = [functools .partial (select_by_loc , loc = loc ) for loc in loc_trans ],
236+ selectors = [
237+ functools .partial (select_by_loc , loc = loc ) for loc in loc_trans
238+ ],
235239 ),
236240 )
237241 constraints .append (
238242 om .PairwiseEqualityConstraint (
239- selectors = [functools .partial (select_by_loc , loc = loc ) for loc in loc_q ],
243+ selectors = [
244+ functools .partial (select_by_loc , loc = loc ) for loc in loc_q
245+ ],
240246 ),
241247 )
242248
@@ -292,7 +298,9 @@ def _get_initial_states_constraints(
292298 ("initial_states" , 0 , f"mixture_{ emf } " , factors [0 ])
293299 for emf in range (n_mixtures )
294300 ]
295- return [om .IncreasingConstraint (selector = functools .partial (select_by_loc , loc = locs ))]
301+ return [
302+ om .IncreasingConstraint (selector = functools .partial (select_by_loc , loc = locs ))
303+ ]
296304 return []
297305
298306
@@ -352,10 +360,11 @@ def _get_anchoring_constraints( # noqa: C901
352360 for period , meas in anchoring_updates :
353361 locs .append (("controls" , period , meas , "constant" ))
354362 if locs :
363+ loc = tuple (locs )
355364 constraints .append (
356365 FixedConstraintWithValue (
357- selector = functools .partial (select_by_loc , loc = locs ),
358- loc = locs ,
366+ selector = functools .partial (select_by_loc , loc = loc ),
367+ loc = loc ,
359368 value = 0 ,
360369 ),
361370 )
@@ -366,10 +375,11 @@ def _get_anchoring_constraints( # noqa: C901
366375 for cont in [c for c in controls if c != "constant" ]:
367376 ind_tups .append (("controls" , period , meas , cont ))
368377 if ind_tups :
378+ loc = tuple (ind_tups )
369379 constraints .append (
370380 FixedConstraintWithValue (
371- selector = functools .partial (select_by_loc , loc = ind_tups ),
372- loc = ind_tups ,
381+ selector = functools .partial (select_by_loc , loc = loc ),
382+ loc = loc ,
373383 value = 0 ,
374384 ),
375385 )
@@ -383,10 +393,11 @@ def _get_anchoring_constraints( # noqa: C901
383393 ind_tups .append (("loadings" , period , meas , factor ))
384394
385395 if ind_tups :
396+ loc = tuple (ind_tups )
386397 constraints .append (
387398 FixedConstraintWithValue (
388- selector = functools .partial (select_by_loc , loc = ind_tups ),
389- loc = ind_tups ,
399+ selector = functools .partial (select_by_loc , loc = loc ),
400+ loc = loc ,
390401 value = 1 ,
391402 ),
392403 )
0 commit comments