Skip to content

Commit ae467c3

Browse files
committed
Merge commit from origin/experiment-runner
1 parent 5844c3c commit ae467c3

3 files changed

Lines changed: 85 additions & 38 deletions

File tree

models/cnn.py

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from pathlib import Path
5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Optional, List
66

77
import numpy as np
88
import torch
@@ -24,6 +24,9 @@
2424
from .ParamSpace import ParamSpace
2525
from .base import BaseModel
2626

27+
import numpy as np
28+
from torch import tensor
29+
2730
MODEL_PATH = Path(".cache/models/cnn_cifar.pth")
2831

2932

@@ -89,6 +92,7 @@ def __init__(
8992
nn.Linear(64, num_classes),
9093
)
9194

95+
# Will be run by PyTorch under the hood
9296
def forward(self, x: torch.Tensor) -> torch.Tensor:
9397
x = self.features(x)
9498
return self.classifier(x)
@@ -104,13 +108,25 @@ def __init__(self, num_classes: int = 10) -> None:
104108
self._input_channels = 1 # grayscale CIFAR-10
105109

106110
def create_model(self, **params) -> None:
111+
"""Create the CNN model with given parameters."""
112+
# Extract architecture-specific parameters for Backbone creation
113+
kernel_size = params.get("kernel_size", 3)
114+
stride = params.get("stride", 1)
115+
learning_rate = params.get("learning_rate", 3e-4)
116+
batch_size = params.get("batch_size", 64)
117+
weight_decay = params.get("weight_decay", 1e-3)
118+
optimizer = params.get("optimizer", "AdamW")
119+
120+
# Also ensure default values are set if not provided
121+
self.params.setdefault("kernel_size", kernel_size)
122+
self.params.setdefault("stride", stride)
123+
self.params.setdefault("learning_rate", learning_rate)
124+
self.params.setdefault("batch_size", batch_size)
125+
self.params.setdefault("weight_decay", weight_decay)
126+
self.params.setdefault("optimizer", optimizer)
127+
107128
# Store all parameters passed in
108129
self.params.update(params)
109-
110-
# Extract architecture-specific parameters for Backbone creation
111-
kernel_size = self.params.get("kernel_size", 3)
112-
stride = self.params.get("stride", 1)
113-
114130
self.network = Backbone(
115131
in_channels=self._input_channels,
116132
num_classes=self.num_classes,
@@ -120,10 +136,9 @@ def create_model(self, **params) -> None:
120136

121137
def train(
122138
self,
123-
X_train: List[np.ndarray],
124-
y_train: np.ndarray,
125-
X_val: List[np.ndarray],
126-
y_val: np.ndarray,
139+
train_loader: DataLoader,
140+
val_loader: DataLoader,
141+
config: Optional[TrainingConfig] = None,
127142
device: Optional[torch.device] = None,
128143
epochs: Optional[int] = None,
129144
patience: Optional[int] = None,
@@ -139,29 +154,7 @@ def train(
139154
device = device or get_device()
140155
self.network = self.network.to(device)
141156

142-
default_config = TrainingConfig()
143-
config = TrainingConfig(
144-
learning_rate=float(self.params.get("learning_rate", default_config.learning_rate)),
145-
weight_decay=float(self.params.get("weight_decay", default_config.weight_decay)),
146-
optimizer=self.params.get("optimizer", default_config.optimizer),
147-
batch_size=int(self.params.get("batch_size", default_config.batch_size)),
148-
# Infrastructure params: use provided values or defaults
149-
epochs=epochs if epochs is not None else default_config.epochs,
150-
patience=patience if patience is not None else default_config.patience,
151-
min_delta=min_delta if min_delta is not None else default_config.min_delta,
152-
checkpoint_path=checkpoint_path if checkpoint_path is not None else default_config.checkpoint_path,
153-
grad_clip_norm=grad_clip_norm if grad_clip_norm is not None else default_config.grad_clip_norm,
154-
writer=writer if writer is not None else default_config.writer,
155-
)
156-
157-
train_loader, val_loader = create_dataloaders(
158-
X_train,
159-
y_train,
160-
X_val,
161-
y_val,
162-
batch_size=config.batch_size,
163-
num_workers=num_workers,
164-
)
157+
config = config or TrainingConfig()
165158

166159
optimizer = self._build_optimizer(self.network, config)
167160
scheduler = optim.lr_scheduler.OneCycleLR(
@@ -256,7 +249,8 @@ def train(
256249

257250
def predict(
258251
self,
259-
data_loader: DataLoader,
252+
data: DataLoader | List | np.ndarray,
253+
labels: List | np.ndarray = None,
260254
device: Optional[torch.device] = None,
261255
return_probabilities: bool = False,
262256
) -> torch.Tensor:
@@ -266,6 +260,19 @@ def predict(
266260
network = self.network.to(device)
267261
network.eval()
268262

263+
# Handle data input - create DataLoader if raw data is provided
264+
if isinstance(data, DataLoader):
265+
data_loader = data
266+
else:
267+
# Raw data provided - create DataLoader
268+
from framework.data_utils import create_dataloaders
269+
if labels is None:
270+
labels = [0] * len(data) # dummy labels for prediction
271+
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
272+
_, data_loader = create_dataloaders(
273+
data[:1], [labels[0]], data, labels, batch_size=batch_size
274+
)
275+
269276
outputs = []
270277
with torch.no_grad():
271278
for images, _ in data_loader:
@@ -279,8 +286,7 @@ def predict(
279286

280287
def evaluate(
281288
self,
282-
X: List[np.ndarray],
283-
y: np.ndarray,
289+
data_loader: DataLoader,
284290
device: Optional[torch.device] = None,
285291
criterion: Optional[nn.Module] = None,
286292
num_workers: int = 0,
@@ -302,6 +308,19 @@ def evaluate(
302308
pin_memory=torch.cuda.is_available(),
303309
)
304310

311+
# Handle data input - create DataLoader if raw data is provided
312+
if isinstance(data, DataLoader):
313+
data_loader = data
314+
else:
315+
# Raw data provided - create DataLoader
316+
from framework.data_utils import create_dataloaders
317+
if labels is None:
318+
raise ValueError("labels must be provided when data is not a DataLoader")
319+
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
320+
_, data_loader = create_dataloaders(
321+
data[:1], [labels[0]], data, labels, batch_size=batch_size
322+
)
323+
305324
criterion = criterion or nn.CrossEntropyLoss()
306325

307326
total_loss = 0.0

models/decision_tree.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,22 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
7070
"precision_weighted": report["weighted avg"]["precision"],
7171
"recall_weighted": report["weighted avg"]["recall"],
7272
"f1_weighted": report["weighted avg"]["f1-score"],
73-
"roc_auc_weighted": roc_auc_score(y_test, proba, average="weighted", multi_class="ovr"),
7473
}
7574

75+
if hasattr(self.estimator, "predict_proba"):
76+
try:
77+
proba = self.estimator.predict_proba(X_test)
78+
if proba.ndim == 2 and proba.shape[1] > 1:
79+
# Check if we have enough classes for ROC AUC calculation
80+
unique_classes = len(set(y_test))
81+
if unique_classes >= 2:
82+
metrics["roc_auc_weighted"] = roc_auc_score(
83+
y_test, proba, average="weighted", multi_class="ovr"
84+
)
85+
except (ValueError, Exception) as e:
86+
# ROC AUC calculation failed, skip it
87+
pass
88+
7689
return metrics
7790

7891
def get_param_space(self) -> Dict[str, ParamSpace]:

models/knn.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,22 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
6969
"recall_weighted": report["weighted avg"]["recall"],
7070
"f1_weighted": report["weighted avg"]["f1-score"],
7171
}
72-
72+
73+
# Add ROC AUC if possible
74+
if hasattr(self.estimator, "predict_proba"):
75+
try:
76+
proba = self.estimator.predict_proba(X_test)
77+
if proba.ndim == 2 and proba.shape[1] > 1:
78+
# Check if we have enough classes for ROC AUC calculation
79+
unique_classes = len(set(y_test))
80+
if unique_classes >= 2:
81+
metrics["roc_auc_weighted"] = roc_auc_score(
82+
y_test, proba, average="weighted", multi_class="ovr"
83+
)
84+
except (ValueError, Exception) as e:
85+
# ROC AUC calculation failed, skip it
86+
pass
87+
7388
return metrics
7489

7590
def get_param_space(self) -> Dict[str, ParamSpace]:

0 commit comments

Comments
 (0)