Skip to content

Commit 9e0ca37

Browse files
committed
bugfixes
1 parent 46f65ef commit 9e0ca37

3 files changed

Lines changed: 23 additions & 21 deletions

File tree

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def total_loss(
399399
traj_loss = criterion(x_pred, x_true)
400400

401401
# identity loss (reconstruct x0)
402-
identity = self.identity_loss(x_true, params.to(self.device))
402+
identity = self.identity_loss(x_true, params)
403403

404404
# derivative losses: compute once
405405
d_pred = self.first_derivative(x_pred)
@@ -455,6 +455,7 @@ def identity_loss(self, x_true: Tensor, params: Tensor = None):
455455
# only reconstruct the initial state
456456
x0 = x_true[:, 0, :]
457457
if not self.config.coeff_network and params is not None:
458+
params = params.to(self.device)
458459
enc_input = torch.cat([x0, params], dim=1)
459460
else:
460461
enc_input = x0

codes/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def create_model_dir(
6262

6363
# Check if the directory exists, and create it if it doesn't
6464
if not os.path.exists(full_path):
65-
os.makedirs(full_path)
65+
os.makedirs(full_path, exist_ok=True)
6666

6767
return full_path
6868

datasets/primordial/surrogates_config.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,38 @@ class LatentNeuralODEConfig:
4242

4343

4444
@dataclass
45-
class FullyConnectedConfig:
46-
"""Model config for FullyConnected for the primordial dataset"""
45+
class LatentPolyConfig:
46+
"""Model config for LatentPoly for the primordial dataset"""
4747

48-
# primordial_final, trial 174
49-
scheduler: str = "poly"
50-
optimizer: str = "adam"
48+
# primordial_final_latentpoly, trial 31
49+
scheduler: str = "schedulefree"
50+
optimizer: str = "SGD"
5151
loss_function: nn.Module = nn.MSELoss()
52-
momentum: float = 0.0132
53-
degree: int = 4
54-
latent_features: int = 8
52+
activation: nn.Module = nn.ReLU()
5553
coder_layers: int = 2
5654
coder_width: int = 470
57-
learning_rate: float = 1.77e-04
58-
regularization_factor: float = 9.20e-03
59-
activation: nn.Module = nn.ReLU()
55+
degree: int = 4
56+
latent_features: int = 8
57+
learning_rate: float = 0.000177
58+
momentum: float = 0.0132
59+
regularization_factor: float = 0.0092
6060

6161

6262
@dataclass
63-
class LatentPolyConfig:
64-
"""Model config for LatentPoly for the primordial dataset"""
63+
class FullyConnectedConfig:
64+
"""Model config for FullyConnected for the primordial dataset"""
6565

66-
# primordial_final, trial 31
67-
scheduler: str = "schedulefree"
68-
optimizer: str = "SGD"
66+
# primordial_final_fullyconnected, trial 174
67+
scheduler: str = "poly"
68+
optimizer: str = "AdamW"
6969
loss_function: nn.Module = nn.SmoothL1Loss()
70+
activation: nn.Module = nn.ELU()
7071
beta: float = 3.73
7172
hidden_size: int = 470
73+
learning_rate: float = 0.00127
7274
num_hidden_layers: int = 5
73-
learning_rate: float = 1.27e-03
74-
regularization_factor: float = 3.23e-05
75-
activation: nn.Module = nn.ELU()
75+
poly_power: float = 1.48
76+
regularization_factor: float = 3.3e-05
7677

7778

7879
# @dataclass

0 commit comments

Comments
 (0)