11"""CNN model"""
22
3- from dataclasses import dataclass
3+ from dataclasses import dataclass , replace
44from pathlib import Path
5- from typing import Dict , List , Optional , List
5+ from typing import Dict , List , Optional
66
77import numpy as np
88import torch
2424from .ParamSpace import ParamSpace
2525from .base import BaseModel
2626
27- import numpy as np
28- from torch import tensor
2927
3028MODEL_PATH = Path (".cache/models/cnn_cifar.pth" )
3129
@@ -92,7 +90,6 @@ def __init__(
9290 nn .Linear (64 , num_classes ),
9391 )
9492
95- # Will be run by PyTorch under the hood
9693 def forward (self , x : torch .Tensor ) -> torch .Tensor :
9794 x = self .features (x )
9895 return self .classifier (x )
@@ -108,36 +105,24 @@ def __init__(self, num_classes: int = 10) -> None:
108105 self ._input_channels = 1 # grayscale CIFAR-10
109106
110107 def create_model (self , ** params ) -> None :
111- """Create the CNN model with given parameters."""
112- # Extract architecture-specific parameters for Backbone creation
113- kernel_size = params .get ("kernel_size" , 3 )
114- stride = params .get ("stride" , 1 )
115- learning_rate = params .get ("learning_rate" , 3e-4 )
116- batch_size = params .get ("batch_size" , 64 )
117- weight_decay = params .get ("weight_decay" , 1e-3 )
118- optimizer = params .get ("optimizer" , "AdamW" )
119-
120- # Also ensure default values are set if not provided
121- self .params .setdefault ("kernel_size" , kernel_size )
122- self .params .setdefault ("stride" , stride )
123- self .params .setdefault ("learning_rate" , learning_rate )
124- self .params .setdefault ("batch_size" , batch_size )
125- self .params .setdefault ("weight_decay" , weight_decay )
126- self .params .setdefault ("optimizer" , optimizer )
127-
108+ """Create the CNN model with given parameters."""
128109 # Store all parameters passed in
129110 self .params .update (params )
111+
112+ # Create Backbone with architecture-specific parameters
130113 self .network = Backbone (
131114 in_channels = self ._input_channels ,
132115 num_classes = self .num_classes ,
133- kernel_size = kernel_size ,
134- stride = stride ,
116+ kernel_size = self .params .get ("kernel_size" , 3 ),
117+ stride = self .params .get ("stride" , 1 ),
135118 )
136119
137120 def train (
138121 self ,
139- train_data , # Can be DataLoader or raw data
140- val_data , # Can be DataLoader or raw data
122+ X_train : List [np .ndarray ],
123+ y_train : np .ndarray ,
124+ X_val : List [np .ndarray ],
125+ y_val : np .ndarray ,
141126 config : Optional [TrainingConfig ] = None ,
142127 device : Optional [torch .device ] = None ,
143128 verbose : bool = True ,
@@ -147,28 +132,25 @@ def train(
147132 device = device or get_device ()
148133 self .network = self .network .to (device )
149134
150- config = config or TrainingConfig ()
135+ # Create config from params if not provided, using defaults where needed
136+ if config is None :
137+ default_config = TrainingConfig ()
138+ config = replace (
139+ default_config ,
140+ learning_rate = float (self .params .get ("learning_rate" , default_config .learning_rate )),
141+ weight_decay = float (self .params .get ("weight_decay" , default_config .weight_decay )),
142+ optimizer = self .params .get ("optimizer" , default_config .optimizer ),
143+ batch_size = int (self .params .get ("batch_size" , default_config .batch_size )),
144+ )
151145
152- # Handle different input types - convert to DataLoaders if needed
153- if isinstance (train_data , DataLoader ) and isinstance (val_data , DataLoader ):
154- train_loader = train_data
155- val_loader = val_data
156- else :
157- # Raw data provided - create DataLoaders
158- from framework .data_utils import create_dataloaders
159- if hasattr (train_data , '__len__' ) and hasattr (val_data , '__len__' ):
160- # Assume train_data is X_train, val_data is y_train for backwards compatibility
161- if len (train_data ) > 0 and not isinstance (train_data [0 ], (int , float )):
162- # This looks like X_train, y_train, X_test, y_test pattern
163- # Need to extract from the calling pattern
164- batch_size = getattr (config , 'batch_size' , 64 )
165- train_loader , val_loader = create_dataloaders (
166- train_data , val_data , train_data [:len (val_data )], val_data , batch_size = batch_size
167- )
168- else :
169- raise ValueError ("Invalid data format provided to CNN.train()" )
170- else :
171- raise ValueError ("Invalid data format provided to CNN.train()" )
146+ # Create DataLoaders from X_train/y_train/X_val/y_val
147+ train_loader , val_loader = create_dataloaders (
148+ X_train ,
149+ y_train ,
150+ X_val ,
151+ y_val ,
152+ batch_size = config .batch_size ,
153+ )
172154
173155 optimizer = self ._build_optimizer (self .network , config )
174156 scheduler = optim .lr_scheduler .OneCycleLR (
@@ -263,8 +245,7 @@ def train(
263245
264246 def predict (
265247 self ,
266- data : DataLoader | List | np .ndarray ,
267- labels : List | np .ndarray = None ,
248+ X_test : List [np .ndarray ],
268249 device : Optional [torch .device ] = None ,
269250 return_probabilities : bool = False ,
270251 ) -> torch .Tensor :
@@ -274,18 +255,16 @@ def predict(
274255 network = self .network .to (device )
275256 network .eval ()
276257
277- # Handle data input - create DataLoader if raw data is provided
278- if isinstance (data , DataLoader ):
279- data_loader = data
280- else :
281- # Raw data provided - create DataLoader
282- from framework .data_utils import create_dataloaders
283- if labels is None :
284- labels = [0 ] * len (data ) # dummy labels for prediction
285- batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
286- _ , data_loader = create_dataloaders (
287- data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
288- )
258+ # Create DataLoader from X_test (with dummy labels for prediction)
259+ batch_size = int (self .params .get ("batch_size" , 128 ))
260+ dummy_labels = np .zeros (len (X_test ), dtype = np .int64 )
261+ dataset = CIFAR10Dataset (X_test , dummy_labels )
262+ data_loader = DataLoader (
263+ dataset ,
264+ batch_size = batch_size ,
265+ shuffle = False ,
266+ pin_memory = torch .cuda .is_available (),
267+ )
289268
290269 outputs = []
291270 with torch .no_grad ():
@@ -300,30 +279,26 @@ def predict(
300279
301280 def evaluate (
302281 self ,
303- data , # Can be DataLoader or raw data
304- labels = None , # Required if data is not DataLoader
282+ X_test : List [ np . ndarray ],
283+ y_test : np . ndarray ,
305284 device : Optional [torch .device ] = None ,
306285 criterion : Optional [nn .Module ] = None ,
307- num_workers : int = 0 ,
308286 ) -> Dict [str , float ]:
309287 if self .network is None :
310288 raise RuntimeError ("Evaluate called before model is initialized" )
311289 device = device or get_device ()
312290 network = self .network .to (device )
313291 network .eval ()
314292
315- # Handle data input - create DataLoader if raw data is provided
316- if isinstance (data , DataLoader ):
317- data_loader = data
318- else :
319- # Raw data provided - create DataLoader
320- from framework .data_utils import create_dataloaders
321- if labels is None :
322- raise ValueError ("labels must be provided when data is not a DataLoader" )
323- batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
324- _ , data_loader = create_dataloaders (
325- data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
326- )
293+ # Create DataLoader from X_test/y_test
294+ batch_size = int (self .params .get ("batch_size" , 128 ))
295+ dataset = CIFAR10Dataset (X_test , y_test )
296+ data_loader = DataLoader (
297+ dataset ,
298+ batch_size = batch_size ,
299+ shuffle = False ,
300+ pin_memory = torch .cuda .is_available (),
301+ )
327302
328303 criterion = criterion or nn .CrossEntropyLoss ()
329304
@@ -359,28 +334,21 @@ def evaluate(
359334 precision_macro = report ["macro avg" ]["precision" ]
360335 recall_macro = report ["macro avg" ]["recall" ]
361336 f1_macro = report ["macro avg" ]["f1-score" ]
362- f1_micro = report . get ( "micro avg" , {}). get ( "f1-score" , f1_score (y_true , y_pred , average = "micro" , zero_division = 0 ) )
337+ f1_micro = f1_score (y_true , y_pred , average = "micro" , zero_division = 0 )
363338
364- # Initialize metrics without ROC AUC first
365- metrics = {
366- "loss" : total_loss / len (data_loader ), # average loss
339+ roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
340+
341+ avg_loss = total_loss / len (data_loader )
342+
343+ return {
344+ "loss" : avg_loss ,
367345 "accuracy" : accuracy ,
368346 "precision_macro" : precision_macro ,
369347 "recall_macro" : recall_macro ,
370348 "f1_macro" : f1_macro ,
371349 "f1_micro" : f1_micro ,
350+ "roc_auc" : roc_auc ,
372351 }
373-
374- # Try to calculate ROC AUC, but handle potential errors gracefully
375- try :
376- roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
377- metrics ["roc_auc" ] = roc_auc
378- except ValueError as e :
379- # ROC AUC calculation failed (likely due to insufficient samples per class)
380- # Continue without ROC AUC metric
381- print (f"Warning: Could not calculate ROC AUC: { e } " )
382-
383- return metrics
384352
385353 def get_param_space (self ) -> Dict [str , ParamSpace ]:
386354 return {
0 commit comments