Skip to content

Commit 865ab1b

Browse files
committed
Merge remote-tracking branch 'origin/master' into pso-and-experiments
2 parents 80cd743 + 23bb979 commit 865ab1b

5 files changed

Lines changed: 1150 additions & 860 deletions

File tree

framework/data_utils.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,72 @@
66
from datasets import load_from_disk
77
from sklearn.model_selection import train_test_split
88
from torch.utils.data import DataLoader
9+
from PIL import Image
910

1011
from framework import utils
1112
from framework.datasets import CIFAR10Dataset
1213

1314

15+
def convert_to_grayscale(image: np.ndarray) -> np.ndarray:
    """Convert an RGB/RGBA image array to grayscale.

    Args:
        image: Image array with shape (H, W), (H, W, 1), (H, W, 3) (RGB),
            or (H, W, 4) (RGBA; the alpha channel is ignored).

    Returns:
        Grayscale image with shape (H, W). RGB(A) inputs are reduced with
        the luminance weights Y = 0.2125R + 0.7154G + 0.0721B, which
        produces a float array; 2-D and single-channel inputs are returned
        with their original dtype.

    Raises:
        ValueError: If the image shape is not one of the supported layouts.
    """
    if image.ndim == 2:
        # Already grayscale.
        return image
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # Single channel: just drop the trailing axis.
            return image.squeeze(axis=2)
        if channels in (3, 4):
            # RGB or RGBA (alpha ignored): weighted luminance sum.
            # The two cases were previously duplicated branches.
            return np.dot(image[..., :3], [0.2125, 0.7154, 0.0721])

    raise ValueError(f"Unsupported image shape: {image.shape}")
40+
41+
42+
def preprocess_images_to_grayscale(images: List[np.ndarray]) -> List[np.ndarray]:
    """Apply ``convert_to_grayscale`` to every image in a list.

    Args:
        images: List of image arrays in any layout accepted by
            ``convert_to_grayscale``.

    Returns:
        List of grayscale image arrays, in the same order as the input.
    """
    grayscale_images = []
    for original in images:
        grayscale_images.append(convert_to_grayscale(original))
    return grayscale_images
52+
53+
54+
def convert_dataset_to_grayscale(dataset):
    """Materialize a HuggingFace dataset as grayscale images plus labels.

    Args:
        dataset: Iterable of records, each providing an ``'image'`` entry
            (PIL image or array-like) and a ``'label'`` entry.

    Returns:
        Tuple of (list of grayscale image arrays, numpy array of labels).
    """
    grayscale_images = []
    label_values = []

    # Single pass so one-shot iterables are consumed only once.
    for record in dataset:
        pixel_array = np.array(record['image'])
        grayscale_images.append(convert_to_grayscale(pixel_array))
        label_values.append(record['label'])

    return grayscale_images, np.array(label_values)
73+
74+
1475
def load_cifar10_data():
1576
"""Load CIFAR-10 dataset (grayscale from processed datasets)."""
1677
repo_root = Path(__file__).resolve().parents[1]

models/cnn.py

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass
44
from pathlib import Path
5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Optional, List
66

77
import numpy as np
88
import torch
@@ -24,6 +24,9 @@
2424
from .ParamSpace import ParamSpace
2525
from .base import BaseModel
2626

27+
import numpy as np
28+
from torch import tensor
29+
2730
MODEL_PATH = Path(".cache/models/cnn_cifar.pth")
2831

2932

@@ -89,6 +92,7 @@ def __init__(
8992
nn.Linear(64, num_classes),
9093
)
9194

95+
# Will be run by PyTorch under the hood
9296
def forward(self, x: torch.Tensor) -> torch.Tensor:
9397
x = self.features(x)
9498
return self.classifier(x)
@@ -104,13 +108,25 @@ def __init__(self, num_classes: int = 10) -> None:
104108
self._input_channels = 1 # grayscale CIFAR-10
105109

106110
def create_model(self, **params) -> None:
111+
"""Create the CNN model with given parameters."""
112+
# Extract architecture-specific parameters for Backbone creation
113+
kernel_size = params.get("kernel_size", 3)
114+
stride = params.get("stride", 1)
115+
learning_rate = params.get("learning_rate", 3e-4)
116+
batch_size = params.get("batch_size", 64)
117+
weight_decay = params.get("weight_decay", 1e-3)
118+
optimizer = params.get("optimizer", "AdamW")
119+
120+
# Also ensure default values are set if not provided
121+
self.params.setdefault("kernel_size", kernel_size)
122+
self.params.setdefault("stride", stride)
123+
self.params.setdefault("learning_rate", learning_rate)
124+
self.params.setdefault("batch_size", batch_size)
125+
self.params.setdefault("weight_decay", weight_decay)
126+
self.params.setdefault("optimizer", optimizer)
127+
107128
# Store all parameters passed in
108129
self.params.update(params)
109-
110-
# Extract architecture-specific parameters for Backbone creation
111-
kernel_size = self.params.get("kernel_size", 3)
112-
stride = self.params.get("stride", 1)
113-
114130
self.network = Backbone(
115131
in_channels=self._input_channels,
116132
num_classes=self.num_classes,
@@ -120,48 +136,39 @@ def create_model(self, **params) -> None:
120136

121137
def train(
122138
self,
123-
X_train: List[np.ndarray],
124-
y_train: np.ndarray,
125-
X_val: List[np.ndarray],
126-
y_val: np.ndarray,
139+
train_data, # Can be DataLoader or raw data
140+
val_data, # Can be DataLoader or raw data
141+
config: Optional[TrainingConfig] = None,
127142
device: Optional[torch.device] = None,
128-
epochs: Optional[int] = None,
129-
patience: Optional[int] = None,
130-
min_delta: Optional[float] = None,
131-
checkpoint_path: Optional[Path] = None,
132-
grad_clip_norm: Optional[float] = None,
133-
writer: Optional[SummaryWriter] = None,
134-
num_workers: int = 2,
135143
verbose: bool = True,
136144
) -> Dict[str, float]:
137145
if self.network is None:
138146
raise RuntimeError("Train called before model is initialized")
139147
device = device or get_device()
140148
self.network = self.network.to(device)
141149

142-
default_config = TrainingConfig()
143-
config = TrainingConfig(
144-
learning_rate=float(self.params.get("learning_rate", default_config.learning_rate)),
145-
weight_decay=float(self.params.get("weight_decay", default_config.weight_decay)),
146-
optimizer=self.params.get("optimizer", default_config.optimizer),
147-
batch_size=int(self.params.get("batch_size", default_config.batch_size)),
148-
# Infrastructure params: use provided values or defaults
149-
epochs=epochs if epochs is not None else default_config.epochs,
150-
patience=patience if patience is not None else default_config.patience,
151-
min_delta=min_delta if min_delta is not None else default_config.min_delta,
152-
checkpoint_path=checkpoint_path if checkpoint_path is not None else default_config.checkpoint_path,
153-
grad_clip_norm=grad_clip_norm if grad_clip_norm is not None else default_config.grad_clip_norm,
154-
writer=writer if writer is not None else default_config.writer,
155-
)
156-
157-
train_loader, val_loader = create_dataloaders(
158-
X_train,
159-
y_train,
160-
X_val,
161-
y_val,
162-
batch_size=config.batch_size,
163-
num_workers=num_workers,
164-
)
150+
config = config or TrainingConfig()
151+
152+
# Handle different input types - convert to DataLoaders if needed
153+
if isinstance(train_data, DataLoader) and isinstance(val_data, DataLoader):
154+
train_loader = train_data
155+
val_loader = val_data
156+
else:
157+
# Raw data provided - create DataLoaders
158+
from framework.data_utils import create_dataloaders
159+
if hasattr(train_data, '__len__') and hasattr(val_data, '__len__'):
160+
# Assume train_data is X_train, val_data is y_train for backwards compatibility
161+
if len(train_data) > 0 and not isinstance(train_data[0], (int, float)):
162+
# This looks like X_train, y_train, X_test, y_test pattern
163+
# Need to extract from the calling pattern
164+
batch_size = getattr(config, 'batch_size', 64)
165+
train_loader, val_loader = create_dataloaders(
166+
train_data, val_data, train_data[:len(val_data)], val_data, batch_size=batch_size
167+
)
168+
else:
169+
raise ValueError("Invalid data format provided to CNN.train()")
170+
else:
171+
raise ValueError("Invalid data format provided to CNN.train()")
165172

166173
optimizer = self._build_optimizer(self.network, config)
167174
scheduler = optim.lr_scheduler.OneCycleLR(
@@ -256,7 +263,8 @@ def train(
256263

257264
def predict(
258265
self,
259-
data_loader: DataLoader,
266+
data: DataLoader | List | np.ndarray,
267+
labels: List | np.ndarray = None,
260268
device: Optional[torch.device] = None,
261269
return_probabilities: bool = False,
262270
) -> torch.Tensor:
@@ -266,6 +274,19 @@ def predict(
266274
network = self.network.to(device)
267275
network.eval()
268276

277+
# Handle data input - create DataLoader if raw data is provided
278+
if isinstance(data, DataLoader):
279+
data_loader = data
280+
else:
281+
# Raw data provided - create DataLoader
282+
from framework.data_utils import create_dataloaders
283+
if labels is None:
284+
labels = [0] * len(data) # dummy labels for prediction
285+
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
286+
_, data_loader = create_dataloaders(
287+
data[:1], [labels[0]], data, labels, batch_size=batch_size
288+
)
289+
269290
outputs = []
270291
with torch.no_grad():
271292
for images, _ in data_loader:
@@ -279,8 +300,8 @@ def predict(
279300

280301
def evaluate(
281302
self,
282-
X: List[np.ndarray],
283-
y: np.ndarray,
303+
data, # Can be DataLoader or raw data
304+
labels=None, # Required if data is not DataLoader
284305
device: Optional[torch.device] = None,
285306
criterion: Optional[nn.Module] = None,
286307
num_workers: int = 0,
@@ -291,16 +312,18 @@ def evaluate(
291312
network = self.network.to(device)
292313
network.eval()
293314

294-
default_config = TrainingConfig()
295-
batch_size = int(self.params.get("batch_size", default_config.batch_size))
296-
dataset = CIFAR10Dataset(X, y)
297-
data_loader = DataLoader(
298-
dataset,
299-
batch_size=batch_size,
300-
shuffle=False,
301-
num_workers=num_workers,
302-
pin_memory=torch.cuda.is_available(),
303-
)
315+
# Handle data input - create DataLoader if raw data is provided
316+
if isinstance(data, DataLoader):
317+
data_loader = data
318+
else:
319+
# Raw data provided - create DataLoader
320+
from framework.data_utils import create_dataloaders
321+
if labels is None:
322+
raise ValueError("labels must be provided when data is not a DataLoader")
323+
batch_size = getattr(self, 'params', {}).get('batch_size', 64)
324+
_, data_loader = create_dataloaders(
325+
data[:1], [labels[0]], data, labels, batch_size=batch_size
326+
)
304327

305328
criterion = criterion or nn.CrossEntropyLoss()
306329

@@ -338,19 +361,26 @@ def evaluate(
338361
f1_macro = report["macro avg"]["f1-score"]
339362
f1_micro = report.get("micro avg", {}).get("f1-score", f1_score(y_true, y_pred, average="micro", zero_division=0))
340363

341-
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
342-
343-
avg_loss = total_loss / len(data_loader)
344-
345-
return {
346-
"loss": avg_loss,
364+
# Initialize metrics without ROC AUC first
365+
metrics = {
366+
"loss": total_loss / len(data_loader), # average loss
347367
"accuracy": accuracy,
348368
"precision_macro": precision_macro,
349369
"recall_macro": recall_macro,
350370
"f1_macro": f1_macro,
351371
"f1_micro": f1_micro,
352-
"roc_auc": roc_auc,
353372
}
373+
374+
# Try to calculate ROC AUC, but handle potential errors gracefully
375+
try:
376+
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
377+
metrics["roc_auc"] = roc_auc
378+
except ValueError as e:
379+
# ROC AUC calculation failed (likely due to insufficient samples per class)
380+
# Continue without ROC AUC metric
381+
print(f"Warning: Could not calculate ROC AUC: {e}")
382+
383+
return metrics
354384

355385
def get_param_space(self) -> Dict[str, ParamSpace]:
356386
return {

models/decision_tree.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,22 +57,37 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
5757
report = classification_report(
5858
y_test, predictions, output_dict=True, zero_division=0
5959
)
60-
61-
proba = self.estimator.predict_proba(X_test)
6260

61+
# Initialize metrics without ROC AUC first
6362
metrics: Dict[str, float] = {
6463
"accuracy": report["accuracy"],
6564
"precision_macro": report["macro avg"]["precision"],
6665
"recall_macro": report["macro avg"]["recall"],
6766
"f1_macro": report["macro avg"]["f1-score"],
6867
"f1_micro": report.get("micro avg", {}).get("f1-score", f1_score(y_test, predictions, average="micro", zero_division=0)),
69-
"roc_auc": roc_auc_score(y_test, proba, average="macro", multi_class="ovr"),
7068
"precision_weighted": report["weighted avg"]["precision"],
7169
"recall_weighted": report["weighted avg"]["recall"],
7270
"f1_weighted": report["weighted avg"]["f1-score"],
73-
"roc_auc_weighted": roc_auc_score(y_test, proba, average="weighted", multi_class="ovr"),
7471
}
7572

73+
# Add ROC AUC if possible (with proper error handling)
74+
if hasattr(self.estimator, "predict_proba"):
75+
try:
76+
proba = self.estimator.predict_proba(X_test)
77+
if proba.ndim == 2 and proba.shape[1] > 1:
78+
# Check if we have enough classes for ROC AUC calculation
79+
unique_classes = len(set(y_test))
80+
if unique_classes >= 2 and proba.shape[1] == len(set(y_test)):
81+
metrics["roc_auc"] = roc_auc_score(
82+
y_test, proba, average="macro", multi_class="ovr"
83+
)
84+
metrics["roc_auc_weighted"] = roc_auc_score(
85+
y_test, proba, average="weighted", multi_class="ovr"
86+
)
87+
except (ValueError, Exception) as e:
88+
# ROC AUC calculation failed, skip it
89+
pass
90+
7691
return metrics
7792

7893
def get_param_space(self) -> Dict[str, ParamSpace]:

models/knn.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,37 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
5656
report = classification_report(
5757
y_test, predictions, output_dict=True, zero_division=0
5858
)
59-
proba = self.estimator.predict_proba(X_test)
6059

60+
# Initialize metrics without ROC AUC first
6161
metrics: Dict[str, float] = {
6262
"accuracy": report["accuracy"],
6363
"precision_macro": report["macro avg"]["precision"],
6464
"recall_macro": report["macro avg"]["recall"],
6565
"f1_macro": report["macro avg"]["f1-score"],
6666
"f1_micro": report.get("micro avg", {}).get("f1-score", f1_score(y_test, predictions, average="micro", zero_division=0)),
67-
"roc_auc": roc_auc_score(y_test, proba, average="macro", multi_class="ovr"),
6867
"precision_weighted": report["weighted avg"]["precision"],
6968
"recall_weighted": report["weighted avg"]["recall"],
7069
"f1_weighted": report["weighted avg"]["f1-score"],
7170
}
72-
71+
72+
# Add ROC AUC if possible (with proper error handling)
73+
if hasattr(self.estimator, "predict_proba"):
74+
try:
75+
proba = self.estimator.predict_proba(X_test)
76+
if proba.ndim == 2 and proba.shape[1] > 1:
77+
# Check if we have enough classes for ROC AUC calculation
78+
unique_classes = len(set(y_test))
79+
if unique_classes >= 2 and proba.shape[1] == len(set(y_test)):
80+
metrics["roc_auc"] = roc_auc_score(
81+
y_test, proba, average="macro", multi_class="ovr"
82+
)
83+
metrics["roc_auc_weighted"] = roc_auc_score(
84+
y_test, proba, average="weighted", multi_class="ovr"
85+
)
86+
except (ValueError, Exception) as e:
87+
# ROC AUC calculation failed, skip it
88+
pass
89+
7390
return metrics
7491

7592
def get_param_space(self) -> Dict[str, ParamSpace]:

0 commit comments

Comments
 (0)