@@ -126,19 +126,32 @@ def forward(self, inputs: tuple) -> torch.Tensor:
126126 def prepare_data (
127127 self ,
128128 dataset_train : np .ndarray ,
129- dataset_test : np .ndarray ,
129+ dataset_test : np .ndarray | None ,
130130 dataset_val : np .ndarray | None ,
131131 timesteps : np .ndarray ,
132132 batch_size : int ,
133133 shuffle : bool = True ,
134+ dummy_timesteps : bool = True ,
134135 ) -> tuple [DataLoader , DataLoader , DataLoader | None ]:
135136 """
136137 Prepare the data for the predict or fit methods.
137- All datasets: shape (n_samples, n_timesteps, n_quantities)
138138
139- Returns: train_loader, test_loader, val_loader
139+ Args:
140+ dataset_train (np.ndarray): Training data.
141+ dataset_test (np.ndarray | None): Test data (optional).
142+ dataset_val (np.ndarray | None): Validation data (optional).
143+ timesteps (np.ndarray): Timesteps.
144+ batch_size (int): Batch size.
145+ shuffle (bool, optional): Whether to shuffle the data. Defaults to True.
146+ dummy_timesteps (bool, optional): Whether to use dummy timesteps. Defaults to True.
147+
148+ Returns:
149+ tuple[DataLoader, DataLoader | None, DataLoader | None]:
150+ DataLoader for training, test, and validation data.
140151 """
141152 dataloaders = []
153+ if dummy_timesteps :
154+ timesteps = np .linspace (0 , 1 , dataset_train .shape [1 ])
142155 loader = self .create_dataloader (dataset_train , timesteps , batch_size , shuffle )
143156 dataloaders .append (loader )
144157 for dataset in [dataset_test , dataset_val ]:
@@ -207,7 +220,7 @@ def fit(
207220 # Update progress bar postfix
208221 postfix = {
209222 "train_loss" : f"{ train_losses [index ]:.2e} " ,
210- "test_loss" : f"{ test_losses [index ]:.2e} "
223+ "test_loss" : f"{ test_losses [index ]:.2e} " ,
211224 }
212225 progress_bar .set_postfix (postfix )
213226
0 commit comments