Skip to content

Commit c2ada92

Browse files
committed
Streamline CNN data handling
1 parent 11194ba commit c2ada92

3 files changed

Lines changed: 70 additions & 60 deletions

File tree

hparam_search.py

Lines changed: 3 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -6,19 +6,14 @@
66

77
import numpy as np
88
import torch
9-
from torch.utils.data import DataLoader
109
from torch.utils.tensorboard import SummaryWriter
1110

1211
from framework.data_utils import (
13-
create_dataloaders,
1412
load_cifar10_data,
1513
prepare_data,
1614
split_train_val,
1715
)
18-
from framework.datasets import CIFAR10Dataset
19-
from framework.utils import get_device
2016
from models.base import get_model_by_name
21-
from models.cnn import TrainingConfig
2217
from search import RandomSearch
2318

2419
RANDOM_SEED = 321
@@ -81,42 +76,14 @@ def evaluate_model(
8176
return model.evaluate(data["val_flat"], data["val_labels"])
8277

8378
if model_key == "cnn":
84-
# Architecture specific parameters
85-
architecture = {k: params[k] for k in ("kernel_size", "stride")}
86-
model.create_model(**architecture)
87-
88-
# Training specific parameters
89-
batch_size = int(params["batch_size"])
90-
config = TrainingConfig(
91-
epochs=DEFAULT_EPOCHS,
92-
learning_rate=float(params["learning_rate"]),
93-
weight_decay=float(params["weight_decay"]),
94-
optimizer=params["optimizer"],
95-
patience=DEFAULT_PATIENCE,
96-
batch_size=batch_size,
97-
)
98-
train_loader, val_loader = create_dataloaders(
79+
model.create_model(**params)
80+
model.train(
9981
data["train_images"],
10082
data["train_labels"],
10183
data["val_images"],
10284
data["val_labels"],
103-
batch_size=batch_size,
104-
)
105-
106-
device = get_device()
107-
model.train(
108-
train_loader, val_loader, config=config, device=device
109-
)
110-
111-
eval_loader = DataLoader(
112-
CIFAR10Dataset(data["val_images"], data["val_labels"]),
113-
batch_size=batch_size,
114-
shuffle=False,
115-
num_workers=0,
116-
pin_memory=torch.cuda.is_available(),
11785
)
118-
eval_metrics = model.evaluate(eval_loader, device=device)
119-
return eval_metrics
86+
return model.evaluate(data["val_images"], data["val_labels"])
12087

12188
raise ValueError(f"Unsupported model key: {model_key}")
12289

models/cnn.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
from dataclasses import dataclass
44
from pathlib import Path
5-
from typing import Dict, Optional
5+
from typing import Dict, List, Optional
66

7+
import numpy as np
78
import torch
89
import torch.nn as nn
910
import torch.optim as optim
1011
from torch.utils.data import DataLoader
1112
from torch.utils.tensorboard import SummaryWriter
1213

14+
from framework.data_utils import create_dataloaders
15+
from framework.datasets import CIFAR10Dataset
1316
from framework.training import Checkpoint, EarlyStopping, train_epoch, validate
1417
from framework.utils import count_parameters, get_device
1518
from .ParamSpace import ParamSpace
@@ -106,17 +109,47 @@ def create_model(self, **params) -> None:
106109

107110
def train(
108111
self,
109-
train_loader: DataLoader,
110-
val_loader: DataLoader,
111-
config: Optional[TrainingConfig] = None,
112+
X_train: List[np.ndarray],
113+
y_train: np.ndarray,
114+
X_val: List[np.ndarray],
115+
y_val: np.ndarray,
112116
device: Optional[torch.device] = None,
117+
epochs: Optional[int] = None,
118+
patience: Optional[int] = None,
119+
min_delta: Optional[float] = None,
120+
checkpoint_path: Optional[Path] = None,
121+
grad_clip_norm: Optional[float] = None,
122+
writer: Optional[SummaryWriter] = None,
123+
num_workers: int = 2,
113124
) -> Dict[str, float]:
114125
if self.network is None:
115126
raise RuntimeError("Train called before model is initialized")
116127
device = device or get_device()
117128
self.network = self.network.to(device)
118129

119-
config = config or TrainingConfig()
130+
default_config = TrainingConfig()
131+
config = TrainingConfig(
132+
learning_rate=float(self.params.get("learning_rate", default_config.learning_rate)),
133+
weight_decay=float(self.params.get("weight_decay", default_config.weight_decay)),
134+
optimizer=self.params.get("optimizer", default_config.optimizer),
135+
batch_size=int(self.params.get("batch_size", default_config.batch_size)),
136+
# Infrastructure params: use provided values or defaults
137+
epochs=epochs if epochs is not None else default_config.epochs,
138+
patience=patience if patience is not None else default_config.patience,
139+
min_delta=min_delta if min_delta is not None else default_config.min_delta,
140+
checkpoint_path=checkpoint_path if checkpoint_path is not None else default_config.checkpoint_path,
141+
grad_clip_norm=grad_clip_norm if grad_clip_norm is not None else default_config.grad_clip_norm,
142+
writer=writer if writer is not None else default_config.writer,
143+
)
144+
145+
train_loader, val_loader = create_dataloaders(
146+
X_train,
147+
y_train,
148+
X_val,
149+
y_val,
150+
batch_size=config.batch_size,
151+
num_workers=num_workers,
152+
)
120153

121154
optimizer = self._build_optimizer(self.network, config)
122155
scheduler = optim.lr_scheduler.OneCycleLR(
@@ -226,16 +259,29 @@ def predict(
226259

227260
def evaluate(
228261
self,
229-
data_loader: DataLoader,
262+
X: List[np.ndarray],
263+
y: np.ndarray,
230264
device: Optional[torch.device] = None,
231265
criterion: Optional[nn.Module] = None,
266+
num_workers: int = 0,
232267
) -> Dict[str, float]:
233268
if self.network is None:
234269
raise RuntimeError("Evaluate called before model is initialized")
235270
device = device or get_device()
236271
network = self.network.to(device)
237272
network.eval()
238273

274+
default_config = TrainingConfig()
275+
batch_size = int(self.params.get("batch_size", default_config.batch_size))
276+
dataset = CIFAR10Dataset(X, y)
277+
data_loader = DataLoader(
278+
dataset,
279+
batch_size=batch_size,
280+
shuffle=False,
281+
num_workers=num_workers,
282+
pin_memory=torch.cuda.is_available(),
283+
)
284+
239285
criterion = criterion or nn.CrossEntropyLoss()
240286

241287
total_loss = 0.0

scripts/train_cnn.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88

99
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
1010
from framework.data_utils import (
11-
create_dataloaders,
1211
load_cifar10_data,
1312
prepare_data,
1413
split_train_val,
1514
)
1615
from framework.utils import get_device, test_pytorch_setup
17-
from models.cnn import CNNModel, TrainingConfig
16+
from models.cnn import CNNModel
1817

1918

2019
def parse_args():
@@ -114,34 +113,32 @@ def train_model(args, writer: SummaryWriter):
114113
)
115114
print(f"Train samples: {len(X_train)}, Val samples: {len(X_val)}")
116115

117-
train_loader, val_loader = create_dataloaders(
116+
model = CNNModel(num_classes=num_classes)
117+
# Pass hyperparameters via create_model (stored in model.params)
118+
model.create_model(
119+
learning_rate=args.lr,
120+
weight_decay=args.weight_decay,
121+
optimizer=args.optimizer,
122+
batch_size=args.batch_size,
123+
)
124+
125+
test_pytorch_setup()
126+
# train() creates DataLoaders internally using batch_size from model.params
127+
results = model.train(
118128
X_train,
119129
y_train,
120130
X_val,
121131
y_val,
122-
batch_size=args.batch_size,
123-
num_workers=args.num_workers,
124-
)
125-
126-
model = CNNModel(num_classes=num_classes)
127-
model.create_model()
128-
129-
config = TrainingConfig(
132+
device=device,
130133
epochs=args.epochs,
131-
learning_rate=args.lr,
132-
weight_decay=args.weight_decay,
133-
optimizer=args.optimizer,
134134
patience=args.patience,
135135
min_delta=args.min_delta,
136136
checkpoint_path=Path(args.checkpoint_path),
137137
grad_clip_norm=args.grad_clip,
138-
batch_size=args.batch_size,
139138
writer=writer,
139+
num_workers=args.num_workers,
140140
)
141141

142-
test_pytorch_setup()
143-
results = model.train(train_loader, val_loader, config=config, device=device)
144-
145142
print("\nTraining complete!")
146143
print(
147144
f"Best val acc: {results['best_val_acc']:.4f} ({results['best_val_acc'] * 100:.2f}%)"

0 commit comments

Comments
 (0)