Skip to content

Commit 680c6e3

Browse files
committed
Flexibly handle different data types before training; added "loss" metric to CNN model; added "ROC AUC" (one-vs-rest) metric to decision tree model and KNN model.
1 parent b867e85 commit 680c6e3

3 files changed

Lines changed: 51 additions & 36 deletions

File tree

models/cnn.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,10 @@ def create_model(self, **params) -> None:
136136

137137
def train(
138138
self,
139-
train_loader: DataLoader,
140-
val_loader: DataLoader,
139+
train_data, # Can be DataLoader or raw data
140+
val_data, # Can be DataLoader or raw data
141141
config: Optional[TrainingConfig] = None,
142142
device: Optional[torch.device] = None,
143-
epochs: Optional[int] = None,
144-
patience: Optional[int] = None,
145-
min_delta: Optional[float] = None,
146-
checkpoint_path: Optional[Path] = None,
147-
grad_clip_norm: Optional[float] = None,
148-
writer: Optional[SummaryWriter] = None,
149-
num_workers: int = 2,
150143
verbose: bool = True,
151144
) -> Dict[str, float]:
152145
if self.network is None:
@@ -155,6 +148,27 @@ def train(
155148
self.network = self.network.to(device)
156149

157150
config = config or TrainingConfig()
151+
152+
# Handle different input types - convert to DataLoaders if needed
153+
if isinstance(train_data, DataLoader) and isinstance(val_data, DataLoader):
154+
train_loader = train_data
155+
val_loader = val_data
156+
else:
157+
# Raw data provided - create DataLoaders
158+
from framework.data_utils import create_dataloaders
159+
if hasattr(train_data, '__len__') and hasattr(val_data, '__len__'):
160+
# Assume train_data is X_train, val_data is y_train for backwards compatibility
161+
if len(train_data) > 0 and not isinstance(train_data[0], (int, float)):
162+
# This looks like X_train, y_train, X_test, y_test pattern
163+
# Need to extract from the calling pattern
164+
batch_size = getattr(config, 'batch_size', 64)
165+
train_loader, val_loader = create_dataloaders(
166+
train_data, val_data, train_data[:len(val_data)], val_data, batch_size=batch_size
167+
)
168+
else:
169+
raise ValueError("Invalid data format provided to CNN.train()")
170+
else:
171+
raise ValueError("Invalid data format provided to CNN.train()")
158172

159173
optimizer = self._build_optimizer(self.network, config)
160174
scheduler = optim.lr_scheduler.OneCycleLR(
@@ -286,7 +300,8 @@ def predict(
286300

287301
def evaluate(
288302
self,
289-
data_loader: DataLoader,
303+
data, # Can be DataLoader or raw data
304+
labels=None, # Required if data is not DataLoader
290305
device: Optional[torch.device] = None,
291306
criterion: Optional[nn.Module] = None,
292307
num_workers: int = 0,
@@ -297,17 +312,6 @@ def evaluate(
297312
network = self.network.to(device)
298313
network.eval()
299314

300-
default_config = TrainingConfig()
301-
batch_size = int(self.params.get("batch_size", default_config.batch_size))
302-
dataset = CIFAR10Dataset(X, y)
303-
data_loader = DataLoader(
304-
dataset,
305-
batch_size=batch_size,
306-
shuffle=False,
307-
num_workers=num_workers,
308-
pin_memory=torch.cuda.is_available(),
309-
)
310-
311315
# Handle data input - create DataLoader if raw data is provided
312316
if isinstance(data, DataLoader):
313317
data_loader = data
@@ -357,19 +361,26 @@ def evaluate(
357361
f1_macro = report["macro avg"]["f1-score"]
358362
f1_micro = report.get("micro avg", {}).get("f1-score", f1_score(y_true, y_pred, average="micro", zero_division=0))
359363

360-
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
361-
362-
avg_loss = total_loss / len(data_loader)
363-
364-
return {
365-
"loss": avg_loss,
364+
# Initialize metrics without ROC AUC first
365+
metrics = {
366+
"loss": total_loss / len(data_loader), # average loss
366367
"accuracy": accuracy,
367368
"precision_macro": precision_macro,
368369
"recall_macro": recall_macro,
369370
"f1_macro": f1_macro,
370371
"f1_micro": f1_micro,
371-
"roc_auc": roc_auc,
372372
}
373+
374+
# Try to calculate ROC AUC, but handle potential errors gracefully
375+
try:
376+
roc_auc = roc_auc_score(y_true, y_proba, average="macro", multi_class="ovr")
377+
metrics["roc_auc"] = roc_auc
378+
except ValueError as e:
379+
# ROC AUC calculation failed (likely due to insufficient samples per class)
380+
# Continue without ROC AUC metric
381+
print(f"Warning: Could not calculate ROC AUC: {e}")
382+
383+
return metrics
373384

374385
def get_param_space(self) -> Dict[str, ParamSpace]:
375386
return {

models/decision_tree.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,30 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
5757
report = classification_report(
5858
y_test, predictions, output_dict=True, zero_division=0
5959
)
60-
61-
proba = self.estimator.predict_proba(X_test)
6260

61+
# Initialize metrics without ROC AUC first
6362
metrics: Dict[str, float] = {
6463
"accuracy": report["accuracy"],
6564
"precision_macro": report["macro avg"]["precision"],
6665
"recall_macro": report["macro avg"]["recall"],
6766
"f1_macro": report["macro avg"]["f1-score"],
6867
"f1_micro": report.get("micro avg", {}).get("f1-score", f1_score(y_test, predictions, average="micro", zero_division=0)),
69-
"roc_auc": roc_auc_score(y_test, proba, average="macro", multi_class="ovr"),
7068
"precision_weighted": report["weighted avg"]["precision"],
7169
"recall_weighted": report["weighted avg"]["recall"],
7270
"f1_weighted": report["weighted avg"]["f1-score"],
7371
}
7472

73+
# Add ROC AUC if possible (with proper error handling)
7574
if hasattr(self.estimator, "predict_proba"):
7675
try:
7776
proba = self.estimator.predict_proba(X_test)
7877
if proba.ndim == 2 and proba.shape[1] > 1:
7978
# Check if we have enough classes for ROC AUC calculation
8079
unique_classes = len(set(y_test))
81-
if unique_classes >= 2:
80+
if unique_classes >= 2 and proba.shape[1] == len(set(y_test)):
81+
metrics["roc_auc"] = roc_auc_score(
82+
y_test, proba, average="macro", multi_class="ovr"
83+
)
8284
metrics["roc_auc_weighted"] = roc_auc_score(
8385
y_test, proba, average="weighted", multi_class="ovr"
8486
)

models/knn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,28 +56,30 @@ def evaluate(self, X_test, y_test) -> Dict[str, float]:
5656
report = classification_report(
5757
y_test, predictions, output_dict=True, zero_division=0
5858
)
59-
proba = self.estimator.predict_proba(X_test)
6059

60+
# Initialize metrics without ROC AUC first
6161
metrics: Dict[str, float] = {
6262
"accuracy": report["accuracy"],
6363
"precision_macro": report["macro avg"]["precision"],
6464
"recall_macro": report["macro avg"]["recall"],
6565
"f1_macro": report["macro avg"]["f1-score"],
6666
"f1_micro": report.get("micro avg", {}).get("f1-score", f1_score(y_test, predictions, average="micro", zero_division=0)),
67-
"roc_auc": roc_auc_score(y_test, proba, average="macro", multi_class="ovr"),
6867
"precision_weighted": report["weighted avg"]["precision"],
6968
"recall_weighted": report["weighted avg"]["recall"],
7069
"f1_weighted": report["weighted avg"]["f1-score"],
7170
}
7271

73-
# Add ROC AUC if possible
72+
# Add ROC AUC if possible (with proper error handling)
7473
if hasattr(self.estimator, "predict_proba"):
7574
try:
7675
proba = self.estimator.predict_proba(X_test)
7776
if proba.ndim == 2 and proba.shape[1] > 1:
7877
# Check if we have enough classes for ROC AUC calculation
7978
unique_classes = len(set(y_test))
80-
if unique_classes >= 2:
79+
if unique_classes >= 2 and proba.shape[1] == len(set(y_test)):
80+
metrics["roc_auc"] = roc_auc_score(
81+
y_test, proba, average="macro", multi_class="ovr"
82+
)
8183
metrics["roc_auc_weighted"] = roc_auc_score(
8284
y_test, proba, average="weighted", multi_class="ovr"
8385
)

0 commit comments

Comments (0)