|
2 | 2 |
|
3 | 3 | from dataclasses import dataclass |
4 | 4 | from pathlib import Path |
5 | | -from typing import Dict, Optional |
| 5 | +from typing import Dict, List, Optional |
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | import torch |
8 | 9 | import torch.nn as nn |
9 | 10 | import torch.optim as optim |
10 | 11 | from torch.utils.data import DataLoader |
11 | 12 | from torch.utils.tensorboard import SummaryWriter |
12 | 13 |
|
| 14 | +from framework.data_utils import create_dataloaders |
| 15 | +from framework.datasets import CIFAR10Dataset |
13 | 16 | from framework.training import Checkpoint, EarlyStopping, train_epoch, validate |
14 | 17 | from framework.utils import count_parameters, get_device |
15 | 18 | from .ParamSpace import ParamSpace |
@@ -106,17 +109,47 @@ def create_model(self, **params) -> None: |
106 | 109 |
|
107 | 110 | def train( |
108 | 111 | self, |
109 | | - train_loader: DataLoader, |
110 | | - val_loader: DataLoader, |
111 | | - config: Optional[TrainingConfig] = None, |
| 112 | + X_train: List[np.ndarray], |
| 113 | + y_train: np.ndarray, |
| 114 | + X_val: List[np.ndarray], |
| 115 | + y_val: np.ndarray, |
112 | 116 | device: Optional[torch.device] = None, |
| 117 | + epochs: Optional[int] = None, |
| 118 | + patience: Optional[int] = None, |
| 119 | + min_delta: Optional[float] = None, |
| 120 | + checkpoint_path: Optional[Path] = None, |
| 121 | + grad_clip_norm: Optional[float] = None, |
| 122 | + writer: Optional[SummaryWriter] = None, |
| 123 | + num_workers: int = 2, |
113 | 124 | ) -> Dict[str, float]: |
114 | 125 | if self.network is None: |
115 | 126 | raise RuntimeError("Train called before model is initialized") |
116 | 127 | device = device or get_device() |
117 | 128 | self.network = self.network.to(device) |
118 | 129 |
|
119 | | - config = config or TrainingConfig() |
| 130 | + default_config = TrainingConfig() |
| 131 | + config = TrainingConfig( |
| 132 | + learning_rate=float(self.params.get("learning_rate", default_config.learning_rate)), |
| 133 | + weight_decay=float(self.params.get("weight_decay", default_config.weight_decay)), |
| 134 | + optimizer=self.params.get("optimizer", default_config.optimizer), |
| 135 | + batch_size=int(self.params.get("batch_size", default_config.batch_size)), |
| 136 | + # Infrastructure params: use provided values or defaults |
| 137 | + epochs=epochs if epochs is not None else default_config.epochs, |
| 138 | + patience=patience if patience is not None else default_config.patience, |
| 139 | + min_delta=min_delta if min_delta is not None else default_config.min_delta, |
| 140 | + checkpoint_path=checkpoint_path if checkpoint_path is not None else default_config.checkpoint_path, |
| 141 | + grad_clip_norm=grad_clip_norm if grad_clip_norm is not None else default_config.grad_clip_norm, |
| 142 | + writer=writer if writer is not None else default_config.writer, |
| 143 | + ) |
| 144 | + |
| 145 | + train_loader, val_loader = create_dataloaders( |
| 146 | + X_train, |
| 147 | + y_train, |
| 148 | + X_val, |
| 149 | + y_val, |
| 150 | + batch_size=config.batch_size, |
| 151 | + num_workers=num_workers, |
| 152 | + ) |
120 | 153 |
|
121 | 154 | optimizer = self._build_optimizer(self.network, config) |
122 | 155 | scheduler = optim.lr_scheduler.OneCycleLR( |
@@ -226,16 +259,29 @@ def predict( |
226 | 259 |
|
227 | 260 | def evaluate( |
228 | 261 | self, |
229 | | - data_loader: DataLoader, |
| 262 | + X: List[np.ndarray], |
| 263 | + y: np.ndarray, |
230 | 264 | device: Optional[torch.device] = None, |
231 | 265 | criterion: Optional[nn.Module] = None, |
| 266 | + num_workers: int = 0, |
232 | 267 | ) -> Dict[str, float]: |
233 | 268 | if self.network is None: |
234 | 269 | raise RuntimeError("Evaluate called before model is initialized") |
235 | 270 | device = device or get_device() |
236 | 271 | network = self.network.to(device) |
237 | 272 | network.eval() |
238 | 273 |
|
| 274 | + default_config = TrainingConfig() |
| 275 | + batch_size = int(self.params.get("batch_size", default_config.batch_size)) |
| 276 | + dataset = CIFAR10Dataset(X, y) |
| 277 | + data_loader = DataLoader( |
| 278 | + dataset, |
| 279 | + batch_size=batch_size, |
| 280 | + shuffle=False, |
| 281 | + num_workers=num_workers, |
| 282 | + pin_memory=torch.cuda.is_available(), |
| 283 | + ) |
| 284 | + |
239 | 285 | criterion = criterion or nn.CrossEntropyLoss() |
240 | 286 |
|
241 | 287 | total_loss = 0.0 |
|
0 commit comments