Skip to content

Commit 893432d

Browse files
committed
Add dummy timestep mechanism
1 parent b5d8ab9 commit 893432d

8 files changed

Lines changed: 40 additions & 11 deletions

File tree

codes/surrogates/AbstractSurrogate/surrogates.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def prepare_data(
158158
timesteps: np.ndarray,
159159
batch_size: int,
160160
shuffle: bool,
161+
dummy_timesteps: bool = True,
161162
) -> tuple[DataLoader, DataLoader | None, DataLoader | None]:
162163
"""
163164
Prepare the data for training, testing, and validation. This method should
@@ -170,6 +171,7 @@ def prepare_data(
170171
timesteps (np.ndarray): The timesteps.
171172
batch_size (int): The batch size.
172173
shuffle (bool): Whether to shuffle the data.
174+
dummy_timesteps (bool): Whether to use dummy timesteps. Defaults to True.
173175
174176
Returns:
175177
tuple[DataLoader, DataLoader, DataLoader]: The DataLoader objects for the

codes/surrogates/DeepONet/deeponet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def prepare_data(
237237
timesteps: np.ndarray,
238238
batch_size: int,
239239
shuffle: bool = True,
240+
dummy_timesteps: bool = True,
240241
) -> tuple[DataLoader, DataLoader, DataLoader | None]:
241242
"""
242243
Prepare the data for the predict or fit methods.
@@ -249,11 +250,17 @@ def prepare_data(
249250
timesteps (np.ndarray): The timesteps.
250251
batch_size (int, optional): The batch size.
251252
shuffle (bool, optional): Whether to shuffle the data.
253+
dummy_timesteps (bool, optional): Whether to create a dummy timestep array.
252254
253255
Returns:
254256
tuple: The training, test, and validation DataLoaders.
255257
"""
256258
dataloaders = []
259+
260+
# Create dummy timesteps
261+
if dummy_timesteps:
262+
timesteps = np.linspace(0, 1, dataset_train.shape[1])
263+
257264
# Create the train dataloader
258265
dataloader_train = self.create_dataloader(
259266
dataset_train,
@@ -297,7 +304,7 @@ def fit(
297304
epochs (int, optional): The number of epochs to train the model.
298305
position (int): The position of the progress bar.
299306
description (str): The description for the progress bar.
300-
multi_objective (bool): Whether multi-objective optimization is used.
307+
multi_objective (bool): Whether multi-objective optimization is used.
301308
If True, trial.report is not used (not supported by Optuna).
302309
303310
Returns:

codes/surrogates/FCNN/fcnn.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,32 @@ def forward(self, inputs: tuple) -> torch.Tensor:
126126
def prepare_data(
127127
self,
128128
dataset_train: np.ndarray,
129-
dataset_test: np.ndarray,
129+
dataset_test: np.ndarray | None,
130130
dataset_val: np.ndarray | None,
131131
timesteps: np.ndarray,
132132
batch_size: int,
133133
shuffle: bool = True,
134+
dummy_timesteps: bool = True,
134135
) -> tuple[DataLoader, DataLoader, DataLoader | None]:
135136
"""
136137
Prepare the data for the predict or fit methods.
137-
All datasets: shape (n_samples, n_timesteps, n_quantities)
138138
139-
Returns: train_loader, test_loader, val_loader
139+
Args:
140+
dataset_train (np.ndarray): Training data.
141+
dataset_test (np.ndarray | None): Test data (optional).
142+
dataset_val (np.ndarray | None): Validation data (optional).
143+
timesteps (np.ndarray): Timesteps.
144+
batch_size (int): Batch size.
145+
shuffle (bool, optional): Whether to shuffle the data. Defaults to True.
146+
dummy_timesteps (bool, optional): Whether to use dummy timesteps. Defaults to True.
147+
148+
Returns:
149+
tuple[DataLoader, DataLoader | None, DataLoader | None]:
150+
DataLoader for training, test, and validation data.
140151
"""
141152
dataloaders = []
153+
if dummy_timesteps:
154+
timesteps = np.linspace(0, 1, dataset_train.shape[1])
142155
loader = self.create_dataloader(dataset_train, timesteps, batch_size, shuffle)
143156
dataloaders.append(loader)
144157
for dataset in [dataset_test, dataset_val]:
@@ -207,7 +220,7 @@ def fit(
207220
# Update progress bar postfix
208221
postfix = {
209222
"train_loss": f"{train_losses[index]:.2e}",
210-
"test_loss": f"{test_losses[index]:.2e}"
223+
"test_loss": f"{test_losses[index]:.2e}",
211224
}
212225
progress_bar.set_postfix(postfix)
213226

codes/surrogates/LatentNeuralODE/latent_neural_ode.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def prepare_data(
7878
timesteps: np.ndarray,
7979
batch_size: int = 128,
8080
shuffle: bool = True,
81+
dummy_timesteps: bool = True,
8182
) -> tuple[DataLoader, DataLoader | None, DataLoader | None]:
8283
"""
8384
Prepares the data for training by creating DataLoader objects.
@@ -101,6 +102,9 @@ def prepare_data(
101102
shuffled_indices = np.random.permutation(len(dataset_train))
102103
dataset_train = dataset_train[shuffled_indices]
103104

105+
if dummy_timesteps:
106+
timesteps = np.linspace(0, 1, dataset_train.shape[1])
107+
104108
# Create training DataLoader
105109
dset_train = ChemDataset(dataset_train, timesteps, device=self.device)
106110
dataloader_train = DataLoader(

codes/surrogates/LatentPolynomial/latent_poly.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def prepare_data(
8484
timesteps: np.ndarray,
8585
batch_size: int = 128,
8686
shuffle: bool = True,
87+
dummy_timesteps: bool = True,
8788
) -> tuple[DataLoader, DataLoader | None, DataLoader | None]:
8889
"""
8990
Prepare DataLoaders for training, testing, and validation.
@@ -95,10 +96,13 @@ def prepare_data(
9596
timesteps (np.ndarray): Array of timesteps.
9697
batch_size (int): Batch size.
9798
shuffle (bool): Whether to shuffle training data.
99+
dummy_timesteps (bool): Whether to use dummy timesteps.
98100
99101
Returns:
100102
tuple: DataLoaders for training, test, and validation datasets.
101103
"""
104+
if dummy_timesteps:
105+
timesteps = np.linspace(0, 1, dataset_train.shape[1])
102106
if shuffle:
103107
shuffled_indices = np.random.permutation(len(dataset_train))
104108
dataset_train = dataset_train[shuffled_indices]

codes/train/train_fcts.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ def train_and_save_model(
8282

8383
_, n_timesteps, n_quantities = train_data.shape
8484

85-
# # Replace timesteps with dummy timesteps between 0 and 1
86-
# timesteps = np.linspace(0, 1, n_timesteps)
87-
8885
# Get the surrogate class
8986
surrogate_class = get_surrogate(surr_name)
9087
model_config = get_model_config(surr_name, config)
@@ -107,6 +104,7 @@ def train_and_save_model(
107104
timesteps=timesteps,
108105
batch_size=batch_size,
109106
shuffle=True,
107+
dummy_timesteps=True,
110108
)
111109

112110
description = make_description(mode, device, str(metric), surr_name)

codes/tune/optuna_fcts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
measure_inference_time,
1515
)
1616
from codes.utils import check_and_load_data, make_description, set_random_seeds
17-
from codes.utils.data_utils import get_data_subset
17+
from codes.utils.data_utils import download_data, get_data_subset
1818

1919

2020
def load_yaml_config(config_path: str) -> dict:

config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Global settings for the benchmark
2-
training_id: "primordialtest"
2+
training_id: "primordialtest3"
33
surrogates: ["LatentNeuralODE"]
44
batch_size: [128]
55
epochs: [20,100] # [12000, 10000, 10000, 7000]
@@ -10,7 +10,8 @@ dataset:
1010
use_optimal_params: True
1111
tolerance: 1e-30
1212
subset_factor: 1
13-
devices: ["cuda:1", "cuda:6"] # ["cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "cuda:8", "cuda:9"]
13+
log_timesteps: True
14+
devices: ["cuda:0", "cuda:1"] # ["cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "cuda:8", "cuda:9"]
1415
seed: 42
1516
verbose: False
1617

0 commit comments

Comments
 (0)