Skip to content

Commit ae21301

Browse files
committed
Add loss plot script
1 parent 9cdaf4a commit ae21301

2 files changed

Lines changed: 250 additions & 4 deletions

File tree

.gitignore

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ profiling*
1616
study
1717
sindy
1818
lorenzo_data.ipynb
19-
optuna_runs/models
20-
optuna_runs/studies
21-
optuna_runs/plots
19+
tuned/models
20+
tuned/studies
21+
tuned/plots
2222
scripts
2323
build/
2424
dist/
@@ -36,4 +36,4 @@ docs/_build/
3636
.doctrees/
3737
docs/docs/
3838
.coverage
39-
optuna_runs/
39+
tuned/

codes/tune/evaluate_tuning.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import argparse
2+
import os
3+
import re
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import torch
8+
9+
10+
def load_loss_history(model_path: str) -> tuple[np.ndarray, np.ndarray, int]:
11+
"""
12+
Load loss histories from a saved model file.
13+
14+
The saved file is expected to be in the custom format, where the loss histories and other
15+
attributes are stored under the "attributes" key.
16+
17+
Args:
18+
model_path (str): Path to the .pth file.
19+
20+
Returns:
21+
tuple: (train_loss, test_loss, n_epochs)
22+
"""
23+
model_dict = torch.load(model_path, map_location="cpu", weights_only=False)
24+
attributes = model_dict.get("attributes", {})
25+
# Expect that train_loss, test_loss, and n_epochs have been saved.
26+
train_loss = (
27+
np.array(attributes.get("train_loss"))
28+
if attributes.get("train_loss") is not None
29+
else None
30+
)
31+
test_loss = (
32+
np.array(attributes.get("test_loss"))
33+
if attributes.get("test_loss") is not None
34+
else None
35+
)
36+
n_epochs = attributes.get(
37+
"n_epochs", len(train_loss) if train_loss is not None else 0
38+
)
39+
return train_loss, test_loss, n_epochs
40+
41+
42+
def plot_losses(
43+
loss_histories: tuple[np.ndarray, ...],
44+
epochs: int,
45+
labels: tuple[str, ...],
46+
title: str = "Losses",
47+
save: bool = False,
48+
conf: dict | None = None,
49+
surr_name: str | None = None,
50+
mode: str = "main",
51+
percentage: float = 2.0,
52+
show_title: bool = True,
53+
) -> None:
54+
"""
55+
Plot the loss trajectories for multiple models using their actual lengths.
56+
57+
Each loss trajectory is plotted over its own length (i.e. trial-specific number of epochs),
58+
rather than forcing all trajectories to the length of the shortest one. The global y-axis limits
59+
are determined from the valid (nonzero) portions of each trajectory after excluding the initial
60+
percentage of epochs.
61+
62+
Args:
63+
loss_histories (tuple[np.ndarray, ...]): Tuple of loss history arrays.
64+
epochs (int): Total number of training epochs (used for labeling only).
65+
labels (tuple[str, ...]): Labels for each loss history.
66+
title (str): Title for the plot.
67+
save (bool): Whether to save the plot as an image file.
68+
conf (dict | None): Configuration dictionary (used for naming output files).
69+
surr_name (str | None): Surrogate model name.
70+
mode (str): Mode for labeling (e.g., "main" or surrogate name).
71+
percentage (float): Percentage of initial epochs to exclude from min/max y-value calculation.
72+
show_title (bool): Whether to display the title.
73+
"""
74+
# Filter out loss arrays that are None or empty.
75+
valid_losses = [
76+
loss for loss in loss_histories if loss is not None and loss.size > 0
77+
]
78+
if not valid_losses:
79+
print("No valid loss arrays found; skipping plot.")
80+
return
81+
82+
# Determine global maximum length (for x-axis limit).
83+
lengths = [len(loss) for loss in valid_losses]
84+
max_length = max(lengths)
85+
86+
# Compute global min and max values across all valid losses.
87+
valid_mins = []
88+
valid_maxes = []
89+
for loss in valid_losses:
90+
start_idx = int(len(loss) * (percentage / 100))
91+
slice_vals = loss[start_idx:]
92+
valid_vals = slice_vals[slice_vals > 0]
93+
if valid_vals.size > 0:
94+
valid_mins.append(valid_vals.min())
95+
valid_maxes.append(valid_vals.max())
96+
if valid_mins:
97+
global_min = min(valid_mins)
98+
global_max = max(valid_maxes)
99+
else:
100+
global_min, global_max = 1e-8, 1.0
101+
102+
# Create color map for plotting.
103+
colors = plt.cm.magma(np.linspace(0.15, 0.85, len(loss_histories)))
104+
105+
plt.figure(figsize=(6, 4))
106+
loss_plotted = False
107+
for loss, label in zip(loss_histories, labels):
108+
if loss is not None and loss.size > 0:
109+
# Generate x-axis based on the actual length of this loss history.
110+
x_epochs = np.arange(len(loss))
111+
plt.plot(x_epochs, loss, label=label, color=colors[labels.index(label)])
112+
loss_plotted = True
113+
114+
plt.xlabel("Epoch")
115+
plt.xlim(0, max_length)
116+
plt.ylabel("Loss")
117+
plt.yscale("log")
118+
plt.ylim(global_min, global_max)
119+
if show_title:
120+
plt.title(title)
121+
plt.legend()
122+
123+
if not loss_plotted:
124+
plt.text(
125+
0.5,
126+
0.5,
127+
"No losses available",
128+
horizontalalignment="center",
129+
verticalalignment="center",
130+
)
131+
132+
# Save the plot if requested.
133+
if save and conf and surr_name:
134+
out_dir = os.path.join("tuned", conf.get("study_name", "study"))
135+
os.makedirs(out_dir, exist_ok=True)
136+
save_path = os.path.join(out_dir, f"losses_{mode.lower()}.png")
137+
plt.savefig(save_path, dpi=300)
138+
print(f"Plot saved to {save_path}")
139+
140+
plt.close()
141+
142+
143+
def evaluate_tuning(study_name: str) -> None:
144+
"""
145+
Evaluate the tuning step by generating loss plots for each surrogate model.
146+
147+
This function looks for folders in "tuned/<study_name>/models". Each folder should
148+
correspond to a surrogate model (e.g., "FullyConnected", "LatentPoly", etc.). It then
149+
loads all .pth files within each folder, extracts the loss trajectories (test_loss),
150+
extracts the trial number from the filename, and generates a loss plot.
151+
152+
Args:
153+
study_name (str): Name of the study (e.g., "primordialtest").
154+
"""
155+
models_dir = os.path.join("tuned", study_name, "models")
156+
output_dir = os.path.join("tuned", study_name)
157+
os.makedirs(output_dir, exist_ok=True)
158+
159+
# Get a list of surrogate folders.
160+
surrogate_folders = [
161+
d for d in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, d))
162+
]
163+
for surr_folder in surrogate_folders:
164+
surr_path = os.path.join(models_dir, surr_folder)
165+
print(f"Processing surrogate model folder: {surr_folder}")
166+
167+
# Find all model files (*.pth) in this folder.
168+
model_files = [f for f in os.listdir(surr_path) if f.endswith(".pth")]
169+
if not model_files:
170+
print(f"No model files found in {surr_path}. Skipping.")
171+
continue
172+
173+
trial_numbers = []
174+
test_loss_histories = []
175+
n_epochs = None
176+
177+
for file_name in model_files:
178+
# Extract trial number from filename (e.g., "latentpoly_0.pth")
179+
match = re.search(r"_(\d+)\.pth$", file_name)
180+
if match:
181+
trial_num = int(match.group(1))
182+
else:
183+
trial_num = -1 # Default if extraction fails.
184+
trial_numbers.append(trial_num)
185+
186+
file_path = os.path.join(surr_path, file_name)
187+
_, test_loss, epochs = load_loss_history(file_path)
188+
test_loss_histories.append(test_loss)
189+
if n_epochs is None:
190+
n_epochs = epochs
191+
192+
# Sort trials by trial number for consistent labeling.
193+
sorted_trials = sorted(
194+
zip(trial_numbers, test_loss_histories), key=lambda x: x[0]
195+
)
196+
trial_numbers, test_loss_histories = zip(*sorted_trials)
197+
labels = tuple(f"Trial {num}" for num in trial_numbers)
198+
199+
# Create the plot using the provided plot_losses function.
200+
plot_losses(
201+
loss_histories=test_loss_histories,
202+
epochs=n_epochs,
203+
labels=labels,
204+
title=f"{surr_folder} Test Losses",
205+
save=True,
206+
conf={"study_name": study_name},
207+
surr_name=surr_folder,
208+
mode=surr_folder,
209+
show_title=True,
210+
)
211+
print(f"Loss plot created for surrogate: {surr_folder}.")
212+
213+
214+
def parse_args() -> argparse.Namespace:
215+
"""
216+
Parse command-line arguments.
217+
218+
Returns:
219+
argparse.Namespace: Parsed arguments containing study_name.
220+
"""
221+
parser = argparse.ArgumentParser(
222+
description="Evaluate tuning loss trajectories and generate plots."
223+
)
224+
parser.add_argument(
225+
"--study_name",
226+
type=str,
227+
required=True,
228+
help="Name of the study (e.g., primordialtest)",
229+
)
230+
return parser.parse_args()
231+
232+
233+
def main():
234+
"""
235+
Main function to evaluate tuning.
236+
237+
Reads the study name from command-line arguments, processes each surrogate folder in
238+
tuned/<study_name>/models, and generates loss plots saved to tuned/<study_name>/.
239+
"""
240+
args = parse_args()
241+
study_name = args.study_name
242+
evaluate_tuning(study_name)
243+
244+
245+
if __name__ == "__main__":
246+
main()

0 commit comments

Comments
 (0)