22
33from dataclasses import dataclass
44from pathlib import Path
5- from typing import Dict , List , Optional
5+ from typing import Dict , List , Optional
66
77import numpy as np
88import torch
2424from .ParamSpace import ParamSpace
2525from .base import BaseModel
2626
27+
28+ from torch import tensor
29+
2730MODEL_PATH = Path (".cache/models/cnn_cifar.pth" )
2831
2932
@@ -89,6 +92,7 @@ def __init__(
8992 nn .Linear (64 , num_classes ),
9093 )
9194
95+ # Will be run by PyTorch under the hood
9296 def forward (self , x : torch .Tensor ) -> torch .Tensor :
9397 x = self .features (x )
9498 return self .classifier (x )
@@ -104,13 +108,25 @@ def __init__(self, num_classes: int = 10) -> None:
104108 self ._input_channels = 1 # grayscale CIFAR-10
105109
106110 def create_model (self , ** params ) -> None :
111+ """Create the CNN model with given parameters."""
112+ # Extract architecture-specific parameters for Backbone creation
113+ kernel_size = params .get ("kernel_size" , 3 )
114+ stride = params .get ("stride" , 1 )
115+ learning_rate = params .get ("learning_rate" , 3e-4 )
116+ batch_size = params .get ("batch_size" , 64 )
117+ weight_decay = params .get ("weight_decay" , 1e-3 )
118+ optimizer = params .get ("optimizer" , "AdamW" )
119+
120+ # Also ensure default values are set if not provided
121+ self .params .setdefault ("kernel_size" , kernel_size )
122+ self .params .setdefault ("stride" , stride )
123+ self .params .setdefault ("learning_rate" , learning_rate )
124+ self .params .setdefault ("batch_size" , batch_size )
125+ self .params .setdefault ("weight_decay" , weight_decay )
126+ self .params .setdefault ("optimizer" , optimizer )
127+
107128 # Store all parameters passed in
108129 self .params .update (params )
109-
110- # Extract architecture-specific parameters for Backbone creation
111- kernel_size = self .params .get ("kernel_size" , 3 )
112- stride = self .params .get ("stride" , 1 )
113-
114130 self .network = Backbone (
115131 in_channels = self ._input_channels ,
116132 num_classes = self .num_classes ,
@@ -120,48 +136,39 @@ def create_model(self, **params) -> None:
120136
121137 def train (
122138 self ,
123- X_train : List [np .ndarray ],
124- y_train : np .ndarray ,
125- X_val : List [np .ndarray ],
126- y_val : np .ndarray ,
139+ train_data , # Can be DataLoader or raw data
140+ val_data , # Can be DataLoader or raw data
141+ config : Optional [TrainingConfig ] = None ,
127142 device : Optional [torch .device ] = None ,
128- epochs : Optional [int ] = None ,
129- patience : Optional [int ] = None ,
130- min_delta : Optional [float ] = None ,
131- checkpoint_path : Optional [Path ] = None ,
132- grad_clip_norm : Optional [float ] = None ,
133- writer : Optional [SummaryWriter ] = None ,
134- num_workers : int = 2 ,
135143 verbose : bool = True ,
136144 ) -> Dict [str , float ]:
137145 if self .network is None :
138146 raise RuntimeError ("Train called before model is initialized" )
139147 device = device or get_device ()
140148 self .network = self .network .to (device )
141149
142- default_config = TrainingConfig ()
143- config = TrainingConfig (
144- learning_rate = float (self .params .get ("learning_rate" , default_config .learning_rate )),
145- weight_decay = float (self .params .get ("weight_decay" , default_config .weight_decay )),
146- optimizer = self .params .get ("optimizer" , default_config .optimizer ),
147- batch_size = int (self .params .get ("batch_size" , default_config .batch_size )),
148- # Infrastructure params: use provided values or defaults
149- epochs = epochs if epochs is not None else default_config .epochs ,
150- patience = patience if patience is not None else default_config .patience ,
151- min_delta = min_delta if min_delta is not None else default_config .min_delta ,
152- checkpoint_path = checkpoint_path if checkpoint_path is not None else default_config .checkpoint_path ,
153- grad_clip_norm = grad_clip_norm if grad_clip_norm is not None else default_config .grad_clip_norm ,
154- writer = writer if writer is not None else default_config .writer ,
155- )
156-
157- train_loader , val_loader = create_dataloaders (
158- X_train ,
159- y_train ,
160- X_val ,
161- y_val ,
162- batch_size = config .batch_size ,
163- num_workers = num_workers ,
164- )
150+ config = config or TrainingConfig ()
151+
152+ # Handle different input types - convert to DataLoaders if needed
153+ if isinstance (train_data , DataLoader ) and isinstance (val_data , DataLoader ):
154+ train_loader = train_data
155+ val_loader = val_data
156+ else :
157+ # Raw data provided - create DataLoaders
158+ from framework .data_utils import create_dataloaders
159+ if hasattr (train_data , '__len__' ) and hasattr (val_data , '__len__' ):
160+ # Assume train_data is X_train, val_data is y_train for backwards compatibility
161+ if len (train_data ) > 0 and not isinstance (train_data [0 ], (int , float )):
162+ # This looks like X_train, y_train, X_test, y_test pattern
163+ # Need to extract from the calling pattern
164+ batch_size = getattr (config , 'batch_size' , 64 )
165+ train_loader , val_loader = create_dataloaders (
166+ train_data , val_data , train_data [:len (val_data )], val_data , batch_size = batch_size
167+ )
168+ else :
169+ raise ValueError ("Invalid data format provided to CNN.train()" )
170+ else :
171+ raise ValueError ("Invalid data format provided to CNN.train()" )
165172
166173 optimizer = self ._build_optimizer (self .network , config )
167174 scheduler = optim .lr_scheduler .OneCycleLR (
@@ -256,7 +263,8 @@ def train(
256263
257264 def predict (
258265 self ,
259- data_loader : DataLoader ,
266+ data : DataLoader | List | np .ndarray ,
267+ labels : List | np .ndarray = None ,
260268 device : Optional [torch .device ] = None ,
261269 return_probabilities : bool = False ,
262270 ) -> torch .Tensor :
@@ -266,6 +274,19 @@ def predict(
266274 network = self .network .to (device )
267275 network .eval ()
268276
277+ # Handle data input - create DataLoader if raw data is provided
278+ if isinstance (data , DataLoader ):
279+ data_loader = data
280+ else :
281+ # Raw data provided - create DataLoader
282+ from framework .data_utils import create_dataloaders
283+ if labels is None :
284+ labels = [0 ] * len (data ) # dummy labels for prediction
285+ batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
286+ _ , data_loader = create_dataloaders (
287+ data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
288+ )
289+
269290 outputs = []
270291 with torch .no_grad ():
271292 for images , _ in data_loader :
@@ -279,8 +300,8 @@ def predict(
279300
280301 def evaluate (
281302 self ,
282- X : List [ np . ndarray ],
283- y : np . ndarray ,
303+ data , # Can be DataLoader or raw data
304+ labels = None , # Required if data is not DataLoader
284305 device : Optional [torch .device ] = None ,
285306 criterion : Optional [nn .Module ] = None ,
286307 num_workers : int = 0 ,
@@ -291,16 +312,18 @@ def evaluate(
291312 network = self .network .to (device )
292313 network .eval ()
293314
294- default_config = TrainingConfig ()
295- batch_size = int (self .params .get ("batch_size" , default_config .batch_size ))
296- dataset = CIFAR10Dataset (X , y )
297- data_loader = DataLoader (
298- dataset ,
299- batch_size = batch_size ,
300- shuffle = False ,
301- num_workers = num_workers ,
302- pin_memory = torch .cuda .is_available (),
303- )
315+ # Handle data input - create DataLoader if raw data is provided
316+ if isinstance (data , DataLoader ):
317+ data_loader = data
318+ else :
319+ # Raw data provided - create DataLoader
320+ from framework .data_utils import create_dataloaders
321+ if labels is None :
322+ raise ValueError ("labels must be provided when data is not a DataLoader" )
323+ batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
324+ _ , data_loader = create_dataloaders (
325+ data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
326+ )
304327
305328 criterion = criterion or nn .CrossEntropyLoss ()
306329
@@ -338,19 +361,26 @@ def evaluate(
338361 f1_macro = report ["macro avg" ]["f1-score" ]
339362 f1_micro = report .get ("micro avg" , {}).get ("f1-score" , f1_score (y_true , y_pred , average = "micro" , zero_division = 0 ))
340363
341- roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
342-
343- avg_loss = total_loss / len (data_loader )
344-
345- return {
346- "loss" : avg_loss ,
364+ # Initialize metrics without ROC AUC first
365+ metrics = {
366+ "loss" : total_loss / len (data_loader ), # average loss
347367 "accuracy" : accuracy ,
348368 "precision_macro" : precision_macro ,
349369 "recall_macro" : recall_macro ,
350370 "f1_macro" : f1_macro ,
351371 "f1_micro" : f1_micro ,
352- "roc_auc" : roc_auc ,
353372 }
373+
374+ # Try to calculate ROC AUC, but handle potential errors gracefully
375+ try :
376+ roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
377+ metrics ["roc_auc" ] = roc_auc
378+ except ValueError as e :
379+ # ROC AUC calculation failed (likely due to insufficient samples per class)
380+ # Continue without ROC AUC metric
381+ print (f"Warning: Could not calculate ROC AUC: { e } " )
382+
383+ return metrics
354384
355385 def get_param_space (self ) -> Dict [str , ParamSpace ]:
356386 return {
0 commit comments