Skip to content

Commit 7208b96

Browse files
committed
Small adjustments
1 parent ae21301 commit 7208b96

7 files changed

Lines changed: 76 additions & 10 deletions

File tree

codes/surrogates/LatentNeuralODE/latent_neural_ode.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ class ModelWrapper(torch.nn.Module):
372372
Wraps the encoder, decoder, and neural ODE into a single model.
373373
Chooses architecture based on the config.model_version flag.
374374
"""
375+
375376
def __init__(self, config, n_quantities: int):
376377
super().__init__()
377378
self.config = config
@@ -383,7 +384,9 @@ def __init__(self, config, n_quantities: int):
383384
self.encoder = OldEncoder(
384385
in_features=n_quantities,
385386
latent_features=config.latent_features,
386-
layers_factor=getattr(config, "layers_factor", 8), # for backward compatibility
387+
layers_factor=getattr(
388+
config, "layers_factor", 8
389+
), # for backward compatibility
387390
activation=config.activation,
388391
)
389392
self.decoder = OldDecoder(
@@ -640,7 +643,8 @@ def __init__(
640643

641644
def forward(self, x: torch.Tensor) -> torch.Tensor:
642645
"""Forward pass to encode the input into the latent space."""
643-
return self.mlp(x)
646+
out = self.mlp(x)
647+
return out
644648

645649

646650
class Decoder(torch.nn.Module):
@@ -683,7 +687,8 @@ def __init__(
683687

684688
def forward(self, x: torch.Tensor) -> torch.Tensor:
685689
"""Forward pass to decode the latent representation into output features."""
686-
return self.mlp(x)
690+
out = self.mlp(x)
691+
return out
687692

688693

689694
class OldEncoder(torch.nn.Module):

codes/surrogates/LatentNeuralODE/latent_neural_ode_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ class LatentNeuralODEBaseConfig:
3939
ode_layers: int = 4
4040
ode_width: int = 64
4141
ode_tanh_reg: bool = True
42-
rtol: float = 1e-5
43-
atol: float = 1e-5
42+
rtol: float = 1e-6
43+
atol: float = 1e-6
4444
learning_rate: float = 1e-3

codes/train/train_fcts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ def train_and_save_model(
8282

8383
_, n_timesteps, n_quantities = train_data.shape
8484

85+
# # Replace timesteps with dummy timesteps between 0 and 1
86+
# timesteps = np.linspace(0, 1, n_timesteps)
87+
8588
# Get the surrogate class
8689
surrogate_class = get_surrogate(surr_name)
8790
model_config = get_model_config(surr_name, config)

config.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Global settings for the benchmark
2-
training_id: "srtest3"
3-
surrogates: ["MultiONet", "FullyConnected", "LatentPoly", "LatentNeuralODE", ]
4-
batch_size: [4096, 4096, 256, 256]
5-
epochs: [10,10,10,10] # [12000, 10000, 10000, 7000]
2+
training_id: "primordialtest"
3+
surrogates: ["LatentNeuralODE"]
4+
batch_size: [128]
5+
epochs: [20,100] # [12000, 10000, 10000, 7000]
66
dataset:
7-
name: "simple_reaction"
7+
name: "primordial"
88
log10_transform: True
99
normalise: "minmax" # "minmax" # "standardise", "minmax", "disable"
1010
use_optimal_params: True

datasets/_data_analysis/analyse_all_datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def main():
2727
Main function to analyse the dataset. It checks the dataset and loads the data.
2828
"""
2929
datasets = [
30+
"primordial",
3031
"simple_reaction",
3132
"osutest2",
3233
"coupled_oscillators",

datasets/_data_analysis/dataset_dict.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,9 @@
5151
"qpp": 5,
5252
"tol": 1e-30,
5353
},
54+
"primordial": {
55+
"log": True,
56+
"qpp": 5,
57+
"tol": 1e-30,
58+
},
5459
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from dataclasses import dataclass
2+
3+
from torch import nn
4+
5+
6+
@dataclass
7+
class MultiONetConfig:
8+
"""Model config for MultiONet for the simple_ode dataset"""
9+
10+
branch_hidden_layers: int = 6
11+
trunk_hidden_layers: int = 4
12+
hidden_size: int = 347
13+
output_factor: int = 90
14+
learning_rate: float = 1.2e-5
15+
activation: nn.Module = nn.ReLU()
16+
17+
18+
@dataclass
19+
class LatentNeuralODEConfig:
20+
"""Model config for LatentNeuralODE for the simple_ode dataset"""
21+
22+
latent_features: int = 9
23+
coder_layers: int = 4
24+
coder_width: int = 230
25+
learning_rate: float = 0.0004
26+
ode_layers: int = 5
27+
ode_width: int = 150
28+
ode_tanh_reg: bool = True
29+
activation: nn.Module = nn.LeakyReLU()
30+
model_version: str = "v2"
31+
32+
33+
@dataclass
34+
class FullyConnectedConfig:
35+
"""Model config for FullyConnected for the simple_ode dataset"""
36+
37+
hidden_size: int = 800
38+
num_hidden_layers: int = 8
39+
learning_rate: float = 3e-5
40+
activation: nn.Module = nn.GELU()
41+
42+
43+
@dataclass
44+
class LatentPolyConfig:
45+
"""Model config for LatentPoly for the simple_ode dataset"""
46+
47+
latent_features: int = 6
48+
degree: int = 2
49+
learning_rate: float = 0.002
50+
coder_layers: int = 4
51+
coder_width: int = 230
52+
activation: nn.Module = nn.LeakyReLU()

0 commit comments

Comments
 (0)