Skip to content

Commit 6d6673c

Browse files
committed
clean up and fix variable names in restrain and constrain methods
1 parent 29671ad commit 6d6673c

3 files changed

Lines changed: 85 additions & 65 deletions

File tree

src/diffpy/srfit/fitbase/recipeorganizer.py

Lines changed: 78 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -883,59 +883,63 @@ def evaluateEquation(self, eqstr, ns={}):
883883
"""
884884
return self.evaluate_equation(eqstr, func_params=ns)
885885

886-
def constrain(self, parameter, con, ns={}):
886+
def constrain(self, parameter, constraint_eq, params={}):
887887
"""Constrain a parameter to an equation.
888888
889889
Note that only one constraint can exist on a Parameter at a time.
890890
891-
Attributes
891+
Parameters
892892
----------
893-
parameter
893+
parameter : str or Parameter
894894
The name of a Parameter or a Parameter to constrain.
895-
con
895+
constraint_eq : str or Equation
896896
A string representation of the constraint equation or a
897-
Parameter to constrain to. A constraint equation must
897+
Parameter to constrain to. A constraint equation must
898898
consist of numpy operators and "known" Parameters.
899-
Parameters are known if they are in the ns argument, or if
900-
they are managed by this object.
901-
ns
899+
Parameters are known if they are in the `params`
900+
argument, or if they are managed by this object.
901+
params : dict, optional
902902
A dictionary of Parameters, indexed by name, that are used
903-
in the parameter, but not part of this object (default {}).
904-
903+
in `parameter`, but not part of this object (default {}).
905904
906-
Raises ValueError if ns uses a name that is already used for a
907-
variable.
908-
Raises ValueError if parameter is a string but not part of this
909-
object or in ns.
910-
Raises ValueError if parameter is marked as constant.
905+
Raises
906+
------
907+
ValueError
908+
If `params` uses a name that is already used for a
909+
variable.
910+
ValueError
911+
If `parameter` is a string but not part of this object or
912+
in `params`.
913+
ValueError
914+
If `parameter` is marked as constant.
911915
"""
912916
if isinstance(parameter, str):
913917
name = parameter
914918
parameter = self.get(name)
915919
if parameter is None:
916-
parameter = ns.get(name)
920+
parameter = params.get(name)
917921

918922
if parameter is None:
919923
raise ValueError("The parameter cannot be found")
920924

921925
if parameter.const:
922926
raise ValueError("The parameter '%s' is constant" % parameter)
923927

924-
if isinstance(con, str):
925-
eqstr = con
926-
eq = equationFromString(con, self._eqfactory, ns)
928+
if isinstance(constraint_eq, str):
929+
eqstr = constraint_eq
930+
eq = equationFromString(constraint_eq, self._eqfactory, params)
927931
else:
928-
eq = Equation(root=con)
929-
eqstr = con.name
932+
eq = Equation(root=constraint_eq)
933+
eqstr = constraint_eq.name
930934

931935
eq.name = "_constraint_%s" % parameter.name
932936

933937
# Make and store the constraint
934-
con = Constraint()
935-
con.constrain(parameter, eq)
938+
constraint_eq = Constraint()
939+
constraint_eq.constrain(parameter, eq)
936940
# Store the equation string so it can be shown later.
937-
con.eqstr = eqstr
938-
self._constraints[parameter] = con
941+
constraint_eq.eqstr = eqstr
942+
self._constraints[parameter] = constraint_eq
939943

940944
# Our configuration changed
941945
self._update_configuration()
@@ -1073,55 +1077,67 @@ def clearConstraints(self, recurse=False):
10731077
"""
10741078
return self.clear_all_constraints(recurse=recurse)
10751079

1076-
def restrain(self, res, lb=-inf, ub=inf, sig=1, scaled=False, ns={}):
1080+
def restrain(
1081+
self, param_or_eq, lb=-inf, ub=inf, sig=1, scaled=False, params={}
1082+
):
10771083
"""Restrain an expression to specified bounds.
10781084
1079-
Attributes
1085+
Parameters
10801086
----------
1081-
res
1082-
An equation string or Parameter to restrain.
1083-
lb
1084-
The lower bound on the restraint evaluation (default -inf).
1085-
ub
1086-
The lower bound on the restraint evaluation (default inf).
1087-
sig
1088-
The uncertainty on the bounds (default 1).
1089-
scaled
1090-
A flag indicating if the restraint is scaled (multiplied)
1091-
by the unrestrained point-average chi^2 (chi^2/numpoints)
1092-
(default False).
1093-
ns
1094-
A dictionary of Parameters, indexed by name, that are used
1095-
in the equation string, but not part of the RecipeOrganizer
1096-
(default {}).
1087+
param_or_eq : str or Parameter
1088+
The equation string or a Parameter object to restrain.
1089+
lb : float, optional
1090+
The lower bound for the restraint evaluation (default is -inf).
1091+
ub : float, optional
1092+
The upper bound for the restraint evaluation (default is inf).
1093+
sig : float, optional
1094+
The uncertainty associated with the bounds (default is 1).
1095+
scaled : bool, optional
1096+
If True, the restraint penalty is scaled by the unrestrained
1097+
point-average chi^2 (chi^2/numpoints) (default is False).
1098+
params : dict, optional
1099+
The dictionary of Parameters, indexed by name, that are used in the
1100+
equation string but are not part of the RecipeOrganizer
1101+
(default is {}).
10971102
1103+
Returns
1104+
-------
1105+
Restraint
1106+
The created Restraint object, which can be used with the
1107+
'unrestrain' method.
10981108
1099-
The penalty is calculated as
1100-
(max(0, lb - val, val - ub)/sig)**2
1101-
and val is the value of the calculated equation. This is multiplied by
1102-
the average chi^2 if scaled is True.
1109+
Notes
1110+
-----
1111+
The penalty is calculated as:
11031112
1113+
..
1114+
(max(0, lb - val, val - ub) / sig) ** 2
11041115
1105-
Raises ValueError if ns uses a name that is already used for a
1106-
Parameter.
1107-
Raises ValueError if res depends on a Parameter that is not part of
1108-
the RecipeOrganizer and that is not defined in ns.
1116+
where `val` is the value of the evaluated equation.
1117+
If `scaled` is True, this penalty is multiplied by
1118+
the average chi^2.
11091119
1110-
Returns the Restraint object for use with the 'unrestrain' method.
1120+
Raises
1121+
------
1122+
ValueError
1123+
If `func_params` contains a name that is already used
1124+
for a Parameter.
1125+
ValueError
1126+
If `param_or_eq` depends on a Parameter that is not part of the
1127+
RecipeOrganizer and is not defined in `func_params`.
11111128
"""
1112-
1113-
if isinstance(res, str):
1114-
eqstr = res
1115-
eq = equationFromString(res, self._eqfactory, ns)
1129+
if isinstance(param_or_eq, str):
1130+
eqstr = param_or_eq
1131+
eq = equationFromString(param_or_eq, self._eqfactory, params)
11161132
else:
1117-
eq = Equation(root=res)
1118-
eqstr = res.name
1133+
eq = Equation(root=param_or_eq)
1134+
eqstr = param_or_eq.name
11191135

11201136
# Make and store the restraint
1121-
res = Restraint(eq, lb, ub, sig, scaled)
1122-
res.eqstr = eqstr
1123-
self.addRestraint(res)
1124-
return res
1137+
param_or_eq = Restraint(eq, lb, ub, sig, scaled)
1138+
param_or_eq.eqstr = eqstr
1139+
self.addRestraint(param_or_eq)
1140+
return param_or_eq
11251141

11261142
def addRestraint(self, res):
11271143
"""Add a Restraint instance to the RecipeOrganizer.

src/diffpy/srfit/structure/sgconstraints.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,9 @@ def _constrain_adps(self, positions):
596596
continue
597597
isoidx.append(j)
598598
scatterer = scatterers[j]
599-
scatterer.constrain(isosymbol, isoname, ns=self._parameters)
599+
scatterer.constrain(
600+
isosymbol, isoname, params=self._parameters
601+
)
600602

601603
fadp = g.UFormulas(adpnames)
602604

@@ -809,7 +811,7 @@ def _makeconstraint(parname, formula, scatterer, idx, ns={}):
809811
# If we got here, then we have a constraint equation
810812
# Fix any division issues
811813
formula = formula.replace("/", "*1.0/")
812-
scatterer.constrain(par, formula, ns=ns)
814+
scatterer.constrain(par, formula, params=ns)
813815
return
814816

815817

tests/test_recipeorganizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def testRestrain(self):
352352

353353
# Check errors on unregistered parameters
354354
self.assertRaises(ValueError, self.m.restrain, "2*p3")
355-
self.assertRaises(ValueError, self.m.restrain, "2*p2", ns={"p2": p3})
355+
self.assertRaises(
356+
ValueError, self.m.restrain, "2*p2", params={"p2": p3}
357+
)
356358
return
357359

358360
def testGetConstraints(self):

0 commit comments

Comments
 (0)