Skip to content

Commit 96ac076

Browse files
committed
test: tests for the new feature
1 parent 474b32e commit 96ac076

1 file changed

Lines changed: 118 additions & 14 deletions

File tree

tests/test_fitrecipe.py

Lines changed: 118 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -527,11 +527,115 @@ def test_initialize_recipe_from_recipe_bad(build_recipe_two_contributions):
527527
recipe2.initialize_recipe_with_recipe(recipe_bad)
528528

529529

530-
# def test_initialize_recipe_from_results(build_recipe_one_contribution):
531-
# # Case: User initializes a FitRecipe from a FitResults object or
532-
# # results file
533-
# # expected: recipe is initialized with variables from previous fit
534-
# assert False
530+
def test_initialize_recipe_from_results_object(build_recipe_one_contribution):
531+
# Case: User initializes a FitRecipe from a FitResults object
532+
# expected: recipe is initialized with variables from previous fit
533+
recipe1 = build_recipe_one_contribution()
534+
optimize_recipe(recipe1)
535+
results1 = FitResults(recipe1)
536+
expected_values = np.round(results1.varvals, 5)
537+
expected_names = results1.varnames
538+
539+
recipe2 = build_recipe_one_contribution()
540+
recipe2.create_new_variable(
541+
"extra_var", 5
542+
) # should be included in the initialized recipe
543+
actual_values_before_init = [val for val in recipe2.get_values()]
544+
actual_names_before_init = recipe2.get_names()
545+
expected_names_before_init = [
546+
"amplitude",
547+
"extra_var",
548+
"phase_shift",
549+
"wave_number",
550+
]
551+
expected_values_before_init = [
552+
4,
553+
3,
554+
2,
555+
5,
556+
] # the three variables + the extra_var
557+
558+
assert actual_values_before_init == expected_values_before_init
559+
assert sorted(actual_names_before_init) == sorted(
560+
expected_names_before_init
561+
)
562+
563+
recipe2.initialize_recipe_with_results(results1)
564+
optimize_recipe(recipe2)
565+
results2 = FitResults(recipe2)
566+
actual_values = np.round(results2.varvals, 5)
567+
actual_names = results2.varnames
568+
569+
expected_names = expected_names + [
570+
"extra_var"
571+
] # add the new variable name to expected names
572+
expected_values = list(expected_values) + [
573+
5
574+
] # add the value of the new variable to expected values
575+
assert sorted(expected_names) == sorted(actual_names)
576+
assert sorted(expected_values) == sorted(list(actual_values))
577+
578+
579+
def test_initialize_recipe_from_results_file(
580+
build_recipe_one_contribution, temp_data_files
581+
):
582+
# Case: User initializes a FitRecipe from a FitResults file
583+
# expected: recipe is initialized with variables from previous fit
584+
results_file = temp_data_files / "fit_results.res"
585+
expected_names = ["amplitude", "phase_shift", "wave_number"]
586+
expected_values = [1, 1, 0]
587+
588+
recipe = build_recipe_one_contribution()
589+
recipe.initialize_recipe_with_results(results_file)
590+
results = FitResults(recipe)
591+
actual_values = np.round(results.varvals, 5)
592+
actual_names = results.varnames
593+
594+
assert sorted(expected_names) == sorted(actual_names)
595+
assert list(expected_values) == list(actual_values)
596+
597+
598+
def test_initialize_recipe_from_results_file_bad(
599+
build_recipe_one_contribution,
600+
):
601+
# Case: User tries to initialize a recipe with something that
602+
# isn't a path, str, or FitResults object
603+
# Expected: raised ValueError with message
604+
recipe = build_recipe_one_contribution()
605+
bad_input = 12345 # not a valid input type
606+
msg = (
607+
"The input results must be a FitResults object or a path to a "
608+
"results file, but got <class 'int'>."
609+
)
610+
with pytest.raises(ValueError, match=msg):
611+
recipe.initialize_recipe_with_results(bad_input)
612+
613+
614+
def test_initialize_recipe_from_results_file_wrong(
615+
build_recipe_two_contributions, temp_data_files, capsys
616+
):
617+
# Case: User tries to initialize a FitRecipe from a results file
618+
# that does not match params in the recipe
619+
# expected: Warning message is printed and things proceed as
620+
# usual with the variables in the recipe
621+
622+
results_file_from_single_contrib = temp_data_files / "fit_results.res"
623+
recipe = build_recipe_two_contributions
624+
recipe.initialize_recipe_with_results(results_file_from_single_contrib)
625+
captured = capsys.readouterr()
626+
actual_print_msg = captured.out # .strip()
627+
628+
results_file_param_names = ["amplitude", "phase_shift", "wave_number"]
629+
expected_print_messages = []
630+
for param_name in results_file_param_names:
631+
msg = (
632+
f"Warning: Parameter '{param_name}' from results not found "
633+
"in FitRecipe and will be ignored."
634+
)
635+
expected_print_messages.append(msg)
636+
637+
for expected_print_msg in expected_print_messages:
638+
assert expected_print_msg in actual_print_msg
535639

536640

537641
def get_labels_and_linecount(ax):
@@ -591,7 +695,7 @@ def build_recipe_from_datafile_deprecated(datafile):
591695

592696

593697
def test_plot_recipe_bad_display(build_recipe_one_contribution):
594-
recipe = build_recipe_one_contribution
698+
recipe = build_recipe_one_contribution()
595699
# Case: All plots are disabled
596700
# expected: raised ValueError with message
597701
plt.close("all")
@@ -621,7 +725,7 @@ def test_plot_recipe_before_refinement(capsys, build_recipe_one_contribution):
621725
# Case: User tries to plot recipe before refinement
622726
# expected: Data plotted without fit line or difference curve
623727
# and warning message printed
624-
recipe = build_recipe_one_contribution
728+
recipe = build_recipe_one_contribution()
625729
plt.close("all")
626730
before = set(plt.get_fignums())
627731
# include fit_label="nothing" to make sure fit line is not plotted
@@ -649,7 +753,7 @@ def test_plot_recipe_before_refinement(capsys, build_recipe_one_contribution):
649753
def test_plot_recipe_after_refinement(build_recipe_one_contribution):
650754
# Case: User refines recipe and then plots
651755
# expected: Plot generates with no problem
652-
recipe = build_recipe_one_contribution
756+
recipe = build_recipe_one_contribution()
653757
optimize_recipe(recipe)
654758
plt.close("all")
655759
before = set(plt.get_fignums())
@@ -686,7 +790,7 @@ def test_plot_recipe_two_contributions(build_recipe_two_contributions):
686790
def test_plot_recipe_on_existing_plot(build_recipe_one_contribution):
687791
# Case: User passes axes to plot_recipe to plot on existing figure
688792
# expected: User modifications are present in the final figure
689-
recipe = build_recipe_one_contribution
793+
recipe = build_recipe_one_contribution()
690794
optimize_recipe(recipe)
691795
plt.close("all")
692796
fig, ax = plt.subplots()
@@ -706,7 +810,7 @@ def test_plot_recipe_on_existing_plot(build_recipe_one_contribution):
706810
def test_plot_recipe_add_new_data(build_recipe_one_contribution):
707811
# Case: User wants to add data to figure generated by plot_recipe
708812
# Expected: New data is added to existing figure (check with labels)
709-
recipe = build_recipe_one_contribution
813+
recipe = build_recipe_one_contribution()
710814
optimize_recipe(recipe)
711815
plt.close("all")
712816
before = set(plt.get_fignums())
@@ -750,7 +854,7 @@ def test_plot_recipe_add_new_data_two_figs(build_recipe_two_contributions):
750854
def test_plot_recipe_set_title(build_recipe_one_contribution):
751855
# Case: User sets title via plot_recipe
752856
# Expected: Title is set correctly
753-
recipe = build_recipe_one_contribution
857+
recipe = build_recipe_one_contribution()
754858
optimize_recipe(recipe)
755859
plt.close("all")
756860
expected_title = "Custom Recipe Title"
@@ -764,7 +868,7 @@ def test_plot_recipe_set_title(build_recipe_one_contribution):
764868
def test_plot_recipe_set_defaults(build_recipe_one_contribution):
765869
# Case: user sets default plot options with set_plot_defaults
766870
# Expected: plot_recipe uses the default options for all calls
767-
recipe = build_recipe_one_contribution
871+
recipe = build_recipe_one_contribution()
768872
optimize_recipe(recipe)
769873
plt.close("all")
770874
# set new defaults
@@ -792,7 +896,7 @@ def test_plot_recipe_set_defaults(build_recipe_one_contribution):
792896
def test_plot_recipe_set_defaults_bad(capsys, build_recipe_one_contribution):
793897
# Case: user tries to set kwargs that are not valid plot_recipe options
794898
# Expected: Plot is shown and warning is printed
795-
recipe = build_recipe_one_contribution
899+
recipe = build_recipe_one_contribution()
796900
optimize_recipe(recipe)
797901
plt.close("all")
798902
recipe.set_plot_defaults(
@@ -902,7 +1006,7 @@ def test_plot_recipe_reset_all_defaults(build_recipe_one_contribution):
9021006
"show": True,
9031007
}
9041008

905-
recipe = build_recipe_one_contribution
1009+
recipe = build_recipe_one_contribution()
9061010
optimize_recipe(recipe)
9071011
plt.close("all")
9081012

0 commit comments

Comments
 (0)