Skip to content

Commit 474b32e

Browse files
committed
feat: initialize FitRecipe from a results file or object
1 parent 0d6e1be commit 474b32e

1 file changed

Lines changed: 70 additions & 0 deletions

File tree

src/diffpy/srfit/fitbase/fitrecipe.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@
3535
__all__ = ["FitRecipe"]
3636

3737
from collections import OrderedDict
38+
from pathlib import Path
3839

3940
import matplotlib.pyplot as plt
4041
from bg_mpl_stylesheets.styles import all_styles
4142
from numpy import array, concatenate, dot, sqrt
4243

44+
import diffpy.srfit.util.inpututils as utils
4345
from diffpy.srfit.fitbase.fithook import PrintFitHook
4446
from diffpy.srfit.fitbase.parameter import ParameterProxy
4547
from diffpy.srfit.fitbase.recipeorganizer import RecipeOrganizer
@@ -1184,6 +1186,74 @@ def initialize_recipe_with_recipe(self, recipe_object):
11841186
if restraint not in self._restraints:
11851187
self._restraints.add(restraint)
11861188

1189+
def _pretty_print_results_dict(self, params_dict):
1190+
"""Pretty print a dictionary of parameter names and values."""
1191+
sorted_params = sorted(params_dict.items())
1192+
width = max(len(name) for name, _ in sorted_params)
1193+
for name, value in sorted_params:
1194+
if isinstance(value, float):
1195+
value_str = f"{value:.6g}"
1196+
else:
1197+
value_str = str(value)
1198+
print(f" {name:<{width}} = {value_str}")
1199+
1200+
def _set_parameters_from_dict(self, params_dict):
1201+
"""Set the parameters of the FitRecipe from a dictionary of
1202+
parameter names and values."""
1203+
for param_name, param_value in params_dict.items():
1204+
if param_name in self._parameters:
1205+
self._parameters[param_name].setValue(param_value)
1206+
else:
1207+
print(
1208+
f"Warning: Parameter '{param_name}' from results "
1209+
"not found in FitRecipe and will be ignored."
1210+
)
1211+
1212+
def initialize_recipe_with_results(self, results, verbose=True):
1213+
"""Initialize a FitRecipe with a FitResults object or a results
1214+
file.
1215+
1216+
Note that at least one FitContribution must already exist in
1217+
the FitRecipe.
1218+
1219+
Parameters
1220+
----------
1221+
results : FitResults, pathlib.Path, or str
1222+
The FitResults object or path to results file to initialize with.
1223+
verbose : bool, optional
1224+
If True, print warnings for any parameters in the results that are
1225+
not in the FitRecipe. Default is True.
1226+
1227+
Raises
1228+
------
1229+
ValueError
1230+
If the input results is not a FitResults object or a path to a
1231+
results file.
1232+
"""
1233+
if hasattr(results, "print_results"):
1234+
params_dict = utils.get_dict_from_results_object(results)
1235+
elif isinstance(results, (str, Path)):
1236+
params_dict = utils.get_dict_from_results_file(results)
1237+
else:
1238+
raise ValueError(
1239+
"The input results must be a FitResults object or a path to a "
1240+
f"results file, but got {type(results)}."
1241+
)
1242+
self._set_parameters_from_dict(params_dict)
1243+
if verbose:
1244+
print()
1245+
print("Parameters found in Results:")
1246+
print("=" * 30)
1247+
self._pretty_print_results_dict(params_dict)
1248+
print()
1249+
print("Parameters set in FitRecipe:")
1250+
print("=" * 30)
1251+
set_parameters_dict = {
1252+
param.name: param.getValue()
1253+
for param in self._parameters.values()
1254+
}
1255+
self._pretty_print_results_dict(set_parameters_dict)
1256+
11871257
def set_plot_defaults(self, **kwargs):
11881258
"""Set default plotting options for all future plots.
11891259

0 commit comments

Comments
 (0)