@@ -136,17 +136,10 @@ def create_model(self, **params) -> None:
136136
137137 def train (
138138 self ,
139- train_loader : DataLoader ,
140- val_loader : DataLoader ,
139+ train_data , # Can be DataLoader or raw data
140+ val_data , # Can be DataLoader or raw data
141141 config : Optional [TrainingConfig ] = None ,
142142 device : Optional [torch .device ] = None ,
143- epochs : Optional [int ] = None ,
144- patience : Optional [int ] = None ,
145- min_delta : Optional [float ] = None ,
146- checkpoint_path : Optional [Path ] = None ,
147- grad_clip_norm : Optional [float ] = None ,
148- writer : Optional [SummaryWriter ] = None ,
149- num_workers : int = 2 ,
150143 verbose : bool = True ,
151144 ) -> Dict [str , float ]:
152145 if self .network is None :
@@ -155,6 +148,27 @@ def train(
155148 self .network = self .network .to (device )
156149
157150 config = config or TrainingConfig ()
151+
152+ # Handle different input types - convert to DataLoaders if needed
153+ if isinstance (train_data , DataLoader ) and isinstance (val_data , DataLoader ):
154+ train_loader = train_data
155+ val_loader = val_data
156+ else :
157+ # Raw data provided - create DataLoaders
158+ from framework .data_utils import create_dataloaders
159+ if hasattr (train_data , '__len__' ) and hasattr (val_data , '__len__' ):
160+ # Backwards compatibility: treat train_data as X_train (features) and val_data as y_train (labels)
161+ if len (train_data ) > 0 and not isinstance (train_data [0 ], (int , float )):
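162+ # Inputs look like raw arrays (first element is not a scalar), so build DataLoaders from them.
163+ # NOTE(review): the call below passes train_data[:len(val_data)] and val_data as the 3rd/4th arguments, presumably the validation split - confirm that reusing training data there is intended.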
164+ batch_size = getattr (config , 'batch_size' , 64 )
165+ train_loader , val_loader = create_dataloaders (
166+ train_data , val_data , train_data [:len (val_data )], val_data , batch_size = batch_size
167+ )
168+ else :
169+ raise ValueError ("Invalid data format provided to CNN.train()" )
170+ else :
171+ raise ValueError ("Invalid data format provided to CNN.train()" )
158172
159173 optimizer = self ._build_optimizer (self .network , config )
160174 scheduler = optim .lr_scheduler .OneCycleLR (
@@ -286,7 +300,8 @@ def predict(
286300
287301 def evaluate (
288302 self ,
289- data_loader : DataLoader ,
303+ data , # Can be DataLoader or raw data
304+ labels = None , # Required if data is not DataLoader
290305 device : Optional [torch .device ] = None ,
291306 criterion : Optional [nn .Module ] = None ,
292307 num_workers : int = 0 ,
@@ -297,17 +312,6 @@ def evaluate(
297312 network = self .network .to (device )
298313 network .eval ()
299314
300- default_config = TrainingConfig ()
301- batch_size = int (self .params .get ("batch_size" , default_config .batch_size ))
302- dataset = CIFAR10Dataset (X , y )
303- data_loader = DataLoader (
304- dataset ,
305- batch_size = batch_size ,
306- shuffle = False ,
307- num_workers = num_workers ,
308- pin_memory = torch .cuda .is_available (),
309- )
310-
311315 # Handle data input - create DataLoader if raw data is provided
312316 if isinstance (data , DataLoader ):
313317 data_loader = data
@@ -357,19 +361,26 @@ def evaluate(
357361 f1_macro = report ["macro avg" ]["f1-score" ]
358362 f1_micro = report .get ("micro avg" , {}).get ("f1-score" , f1_score (y_true , y_pred , average = "micro" , zero_division = 0 ))
359363
360- roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
361-
362- avg_loss = total_loss / len (data_loader )
363-
364- return {
365- "loss" : avg_loss ,
364+ # Initialize metrics without ROC AUC first
365+ metrics = {
366+ "loss" : total_loss / len (data_loader ), # average loss
366367 "accuracy" : accuracy ,
367368 "precision_macro" : precision_macro ,
368369 "recall_macro" : recall_macro ,
369370 "f1_macro" : f1_macro ,
370371 "f1_micro" : f1_micro ,
371- "roc_auc" : roc_auc ,
372372 }
373+
374+ # Try to calculate ROC AUC, but handle potential errors gracefully
375+ try :
376+ roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
377+ metrics ["roc_auc" ] = roc_auc
378+ except ValueError as e :
379+ # ROC AUC calculation failed (likely due to insufficient samples per class)
380+ # Continue without ROC AUC metric
381+ print (f"Warning: Could not calculate ROC AUC: { e } " )
382+
383+ return metrics
373384
374385 def get_param_space (self ) -> Dict [str , ParamSpace ]:
375386 return {
0 commit comments