Skip to content

Commit e7e46b0

Browse files
committed
small fixes
1 parent ff075ab commit e7e46b0

3 files changed

Lines changed: 4 additions & 2 deletions

File tree

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def total_loss(
399399
traj_loss = criterion(x_pred, x_true)
400400

401401
# identity loss (reconstruct x0)
402-
identity = self.identity_loss(x_true, params)
402+
identity = self.identity_loss(x_true, params.to(self.device))
403403

404404
# derivative losses: compute once
405405
d_pred = self.first_derivative(x_pred)

codes/utils/data_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def check_and_load_data(
159159
elif normalisation_mode == "standardise":
160160
data_info["mean_params"] = params_info["mean"]
161161
data_info["std_params"] = params_info["std"]
162+
if verbose:
163+
print("Parameters normalized.")
164+
print(f"Parameters info: {params_info}")
162165

163166
data_info["log10_transform"] = True if log else False
164167

run_tuning.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from pathlib import Path
66

77
import optuna
8-
from optuna.study import MaxTrialsCallback
98
from optuna.trial import TrialState
109
from tqdm import tqdm
1110

0 commit comments

Comments
 (0)