Skip to content

Commit f66cdd8

Browse files
committed
Update model signatures
1 parent 865ab1b commit f66cdd8

9 files changed

Lines changed: 224 additions & 427 deletions

File tree

framework/data_utils.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from datasets import load_from_disk
77
from sklearn.model_selection import train_test_split
88
from torch.utils.data import DataLoader
9-
from PIL import Image
109

1110
from framework import utils
1211
from framework.datasets import CIFAR10Dataset
@@ -116,14 +115,12 @@ def create_dataloaders(
116115
X_val: List[np.ndarray],
117116
y_val: np.ndarray,
118117
batch_size: int,
119-
num_workers: int = 2,
120118
) -> Tuple[DataLoader, DataLoader]:
121119
train_dataset = CIFAR10Dataset(X_train, y_train)
122120
train_loader = DataLoader(
123121
train_dataset,
124122
batch_size=batch_size,
125123
shuffle=True,
126-
num_workers=num_workers,
127124
pin_memory=utils.is_cuda_available(),
128125
)
129126

@@ -132,7 +129,6 @@ def create_dataloaders(
132129
val_dataset,
133130
batch_size=batch_size,
134131
shuffle=False,
135-
num_workers=num_workers,
136132
pin_memory=utils.is_cuda_available(),
137133
)
138134

framework/fitness.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
def calculate_composite_fitness(metrics: dict[str, float]) -> float:
22
"""Calculate composite fitness score from evaluation metrics."""
33
# Extract metrics
4-
f1_macro = metrics.get("f1_macro", 0.0)
5-
recall_macro = metrics.get("recall_macro", 0.0)
6-
roc_auc = metrics.get("roc_auc", 0.0)
7-
precision_macro = metrics.get("precision_macro", 0.0)
8-
accuracy = metrics.get("accuracy", 0.0)
9-
f1_micro = metrics.get("f1_micro", 0.0)
4+
f1_macro = metrics["f1_macro"]
5+
recall_macro = metrics["recall_macro"]
6+
roc_auc = metrics["roc_auc"]
7+
precision_macro = metrics["precision_macro"]
8+
accuracy = metrics["accuracy"]
9+
f1_micro = metrics["f1_micro"]
1010

1111
# Composite fitness
1212
composite_fitness = (

models/cnn.py

Lines changed: 59 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""CNN model"""
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, replace
44
from pathlib import Path
5-
from typing import Dict, List, Optional, List
5+
from typing import Dict, List, Optional
66

77
import numpy as np
88
import torch
@@ -24,8 +24,6 @@
2424
from .ParamSpace import ParamSpace
2525
from .base import BaseModel
2626

27-
import numpy as np
28-
from torch import tensor
2927

3028
MODEL_PATH = Path(".cache/models/cnn_cifar.pth")
3129

@@ -92,7 +90,6 @@ def __init__(
9290
nn.Linear(64, num_classes),
9391
)
9492

95-
# Will be run by PyTorch under the hood
9693
def forward(self, x: torch.Tensor) -> torch.Tensor:
9794
x = self.features(x)
9895
return self.classifier(x)
@@ -108,36 +105,24 @@ def __init__(self, num_classes: int = 10) -> None:
108105
self._input_channels = 1 # grayscale CIFAR-10
109106

110107
def create_model(self, **params) -> None:
111-
"""Create the CNN model with given parameters."""
112-
# Extract architecture-specific parameters for Backbone creation
113-
kernel_size = params.get("kernel_size", 3)
114-
stride = params.get("stride", 1)
115-
learning_rate = params.get("learning_rate", 3e-4)
116-
batch_size = params.get("batch_size", 64)
117-
weight_decay = params.get("weight_decay", 1e-3)
118-
optimizer = params.get("optimizer", "AdamW")
119-
120-
# Also ensure default values are set if not provided
121-
self.params.setdefault("kernel_size", kernel_size)
122-
self.params.setdefault("stride", stride)
123-
self.params.setdefault("learning_rate", learning_rate)
124-
self.params.setdefault("batch_size", batch_size)
125-
self.params.setdefault("weight_decay", weight_decay)
126-
self.params.setdefault("optimizer", optimizer)
127-
108+
"""Create the CNN model with given parameters."""
128109
# Store all parameters passed in
129110
self.params.update(params)
111+
112+
# Create Backbone with architecture-specific parameters
130113
self.network = Backbone(
131114
in_channels=self._input_channels,
132115
num_classes=self.num_classes,
133-
kernel_size=kernel_size,
134-
stride=stride,
116+
kernel_size=self.params.get("kernel_size", 3),
117+
stride=self.params.get("stride", 1),
135118
)
136119

137120
def train(
138121
self,
139-
train_data, # Can be DataLoader or raw data
140-
val_data, # Can be DataLoader or raw data
122+
X_train: List[np.ndarray],
123+
y_train: np.ndarray,
124+
X_val: List[np.ndarray],
125+
y_val: np.ndarray,
141126
config: Optional[TrainingConfig] = None,
142127
device: Optional[torch.device] = None,
143128
verbose: bool = True,
@@ -147,28 +132,25 @@ def train(
147132
device = device or get_device()
148133
self.network = self.network.to(device)
149134

150-
config = config or TrainingConfig()
135+
# Create config from params if not provided, using defaults where needed
136+
if config is None:
137+
default_config = TrainingConfig()
138+
config = replace(
139+
default_config,
140+
learning_rate=float(self.params.get("learning_rate", default_config.learning_rate)),
141+
weight_decay=float(self.params.get("weight_decay", default_config.weight_decay)),
142+
optimizer=self.params.get("optimizer", default_config.optimizer),
143+
batch_size=int(self.params.get("batch_size", default_config.batch_size)),
144+
)
151145

152-
# Handle different input types - convert to DataLoaders if needed
153-
if isinstance(train_data, DataLoader) and isinstance(val_data, DataLoader):
154-
train_loader = train_data
155-
val_loader = val_data
156-
else:
157-
# Raw data provided - create DataLoaders
158-
from framework.data_utils import create_dataloaders
159-
if hasattr(train_data, '__len__') and hasattr(val_data, '__len__'):
160-
# Assume train_data is X_train, val_data is y_train for backwards compatibility
161-
if len(train_data) > 0 and not isinstance(train_data[0], (int, float)):
162-
# This looks like X_train, y_train, X_test, y_test pattern
163-
# Need to extract from the calling pattern
164-
batch_size = getattr(config, 'batch_size', 64)
165-
train_loader, val_loader = create_dataloaders(
166-
train_data, val_data, train_data[:len(val_data)], val_data, batch_size=batch_size
167-
)
168-
else:
169-
raise ValueError("Invalid data format provided to CNN.train()")
170-
else:
171-
raise ValueError("Invalid data format provided to CNN.train()")
146+
# Create DataLoaders from X_train/y_train/X_val/y_val
147+
train_loader, val_loader = create_dataloaders(
148+
X_train,
149+
y_train,
150+
X_val,
151+
y_val,
152+
batch_size=config.batch_size,
153+
)
172154

173155
optimizer = self._build_optimizer(self.network, config)
174156
scheduler = optim.lr_scheduler.OneCycleLR(
@@ -263,8 +245,7 @@ def train(
263245

264246
def predict(
265247
self,
266-
data: DataLoader | List | np.ndarray,
267-
labels: List | np.ndarray = None,
248+
X_test: List[np.ndarray],
268249
device: Optional[torch.device] = None,
269250
return_probabilities: bool = False,
270251
) -> torch.Tensor:
@@ -274,18 +255,16 @@ def predict(
274255
network = self.network.to(device)
275256
network.eval()
276257

277-
# Handle data input - create DataLoader if raw data is provided
278-
if isinstance(data, DataLoader):
279-
data_loader = data
280-
else:
281-
# Raw data provided - create DataLoader
282-
from framework.data_utils import create_dataloaders
283-
if labels is None:
284-
labels = [0] * len(data) # dummy labels for prediction
285-
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
286-
_, data_loader = create_dataloaders(
287-
data[:1], [labels[0]], data, labels, batch_size=batch_size
288-
)
258+
# Create DataLoader from X_test (with dummy labels for prediction)
259+
batch_size = int(self.params.get("batch_size", 128))
260+
dummy_labels = np.zeros(len(X_test), dtype=np.int64)
261+
dataset = CIFAR10Dataset(X_test, dummy_labels)
262+
data_loader = DataLoader(
263+
dataset,
264+
batch_size=batch_size,
265+
shuffle=False,
266+
pin_memory=torch.cuda.is_available(),
267+
)
289268

290269
outputs = []
291270
with torch.no_grad():
@@ -300,30 +279,26 @@ def predict(
300279

301280
def evaluate(
302281
self,
303-
data, # Can be DataLoader or raw data
304-
labels=None, # Required if data is not DataLoader
282+
X_test: List[np.ndarray],
283+
y_test: np.ndarray,
305284
device: Optional[torch.device] = None,
306285
criterion: Optional[nn.Module] = None,
307-
num_workers: int = 0,
308286
) -> Dict[str, float]:
309287
if self.network is None:
310288
raise RuntimeError("Evaluate called before model is initialized")
311289
device = device or get_device()
312290
network = self.network.to(device)
313291
network.eval()
314292

315-
# Handle data input - create DataLoader if raw data is provided
316-
if isinstance(data, DataLoader):
317-
data_loader = data
318-
else:
319-
# Raw data provided - create DataLoader
320-
from framework.data_utils import create_dataloaders
321-
if labels is None:
322-
raise ValueError("labels must be provided when data is not a DataLoader")
323-
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
324-
_, data_loader = create_dataloaders(
325-
data[:1], [labels[0]], data, labels, batch_size=batch_size
326-
)
293+
# Create DataLoader from X_test/y_test
294+
batch_size = int(self.params.get("batch_size", 128))
295+
dataset = CIFAR10Dataset(X_test, y_test)
296+
data_loader = DataLoader(
297+
dataset,
298+
batch_size=batch_size,
299+
shuffle=False,
300+
pin_memory=torch.cuda.is_available(),
301+
)
327302

328303
criterion = criterion or nn.CrossEntropyLoss()
329304

@@ -359,28 +334,21 @@ def evaluate(
359334
precision_macro = report["macro avg"]["precision"]
360335
recall_macro = report["macro avg"]["recall"]
361336
f1_macro = report["macro avg"]["f1-score"]
362-
f1_micro = report.get("micro avg", {}).get("f1-score", f1_score(y_true, y_pred, average="micro", zero_division=0))
337+
f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
363338

364-
# Initialize metrics without ROC AUC first
365-
metrics = {
366-
"loss": total_loss / len(data_loader), # average loss
339+
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
340+
341+
avg_loss = total_loss / len(data_loader)
342+
343+
return {
344+
"loss": avg_loss,
367345
"accuracy": accuracy,
368346
"precision_macro": precision_macro,
369347
"recall_macro": recall_macro,
370348
"f1_macro": f1_macro,
371349
"f1_micro": f1_micro,
350+
"roc_auc": roc_auc,
372351
}
373-
374-
# Try to calculate ROC AUC, but handle potential errors gracefully
375-
try:
376-
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
377-
metrics["roc_auc"] = roc_auc
378-
except ValueError as e:
379-
# ROC AUC calculation failed (likely due to insufficient samples per class)
380-
# Continue without ROC AUC metric
381-
print(f"Warning: Could not calculate ROC AUC: {e}")
382-
383-
return metrics
384352

385353
def get_param_space(self) -> Dict[str, ParamSpace]:
386354
return {

models/decision_tree.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, List
22

3+
import numpy as np
34
from sklearn.metrics import classification_report, f1_score, roc_auc_score
45
from sklearn.tree import DecisionTreeClassifier
56
from sklearn.utils.validation import check_is_fitted
@@ -20,23 +21,23 @@ def create_model(self, **params: Any) -> None:
2021
self.params.update(params)
2122
self.estimator = DecisionTreeClassifier(**self.params)
2223

23-
def train(self, X_train, y_train) -> DecisionTreeClassifier:
24+
def train(self, X_train: List[np.ndarray], y_train: np.ndarray) -> DecisionTreeClassifier:
2425
if self.estimator is None:
2526
self.create_model()
2627
estimator = self.estimator
2728
assert estimator is not None
2829
estimator.fit(X_train, y_train)
2930
return estimator
3031

31-
def predict(self, X):
32+
def predict(self, X: List[np.ndarray]):
3233
if self.estimator is None:
3334
raise RuntimeError(
3435
"Estimator has not been created. Call create_model() first."
3536
)
3637
check_is_fitted(self.estimator)
3738
return self.estimator.predict(X)
3839

39-
def predict_proba(self, X):
40+
def predict_proba(self, X: List[np.ndarray]):
4041
if self.estimator is None:
4142
raise RuntimeError(
4243
"Estimator has not been created. Call create_model() first."
@@ -48,7 +49,7 @@ def predict_proba(self, X):
4849
check_is_fitted(self.estimator)
4950
return self.estimator.predict_proba(X)
5051

51-
def evaluate(self, X_test, y_test) -> Dict[str, float]:
52+
def evaluate(self, X_test: List[np.ndarray], y_test: np.ndarray) -> Dict[str, float]:
5253
if self.estimator is None:
5354
raise RuntimeError(
5455
"Estimator has not been created. Call create_model() first."
@@ -57,37 +58,21 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
5758
report = classification_report(
5859
y_test, predictions, output_dict=True, zero_division=0
5960
)
61+
62+
proba = self.estimator.predict_proba(X_test)
6063

61-
# Initialize metrics without ROC AUC first
6264
metrics: Dict[str, float] = {
6365
"accuracy": report["accuracy"],
6466
"precision_macro": report["macro avg"]["precision"],
6567
"recall_macro": report["macro avg"]["recall"],
6668
"f1_macro": report["macro avg"]["f1-score"],
67-
"f1_micro": report.get("micro avg", {}).get("f1-score", f1_score(y_test, predictions, average="micro", zero_division=0)),
69+
"f1_micro": f1_score(y_test, predictions, average="micro", zero_division=0),
6870
"precision_weighted": report["weighted avg"]["precision"],
6971
"recall_weighted": report["weighted avg"]["recall"],
7072
"f1_weighted": report["weighted avg"]["f1-score"],
73+
"roc_auc": roc_auc_score(y_test, proba, average="macro", multi_class="ovr"),
7174
}
7275

73-
# Add ROC AUC if possible (with proper error handling)
74-
if hasattr(self.estimator, "predict_proba"):
75-
try:
76-
proba = self.estimator.predict_proba(X_test)
77-
if proba.ndim == 2 and proba.shape[1] > 1:
78-
# Check if we have enough classes for ROC AUC calculation
79-
unique_classes = len(set(y_test))
80-
if unique_classes >= 2 and proba.shape[1] == len(set(y_test)):
81-
metrics["roc_auc"] = roc_auc_score(
82-
y_test, proba, average="macro", multi_class="ovr"
83-
)
84-
metrics["roc_auc_weighted"] = roc_auc_score(
85-
y_test, proba, average="weighted", multi_class="ovr"
86-
)
87-
except (ValueError, Exception) as e:
88-
# ROC AUC calculation failed, skip it
89-
pass
90-
9176
return metrics
9277

9378
def get_param_space(self) -> Dict[str, ParamSpace]:

0 commit comments

Comments
 (0)