22
33from dataclasses import dataclass
44from pathlib import Path
5- from typing import Dict , List , Optional
5+ from typing import Dict , List , Optional , List
66
77import numpy as np
88import torch
2424from .ParamSpace import ParamSpace
2525from .base import BaseModel
2626
27+ import numpy as np
28+ from torch import tensor
29+
2730MODEL_PATH = Path (".cache/models/cnn_cifar.pth" )
2831
2932
@@ -89,6 +92,7 @@ def __init__(
8992 nn .Linear (64 , num_classes ),
9093 )
9194
95+ # PyTorch calls this through nn.Module.__call__ whenever the model instance is invoked
9296 def forward (self , x : torch .Tensor ) -> torch .Tensor :
9397 x = self .features (x )
9498 return self .classifier (x )
@@ -104,13 +108,25 @@ def __init__(self, num_classes: int = 10) -> None:
104108 self ._input_channels = 1 # grayscale CIFAR-10
105109
106110 def create_model (self , ** params ) -> None :
111+ """Create the CNN model with given parameters."""
112+ # Extract architecture-specific parameters for Backbone creation
113+ kernel_size = params .get ("kernel_size" , 3 )
114+ stride = params .get ("stride" , 1 )
115+ learning_rate = params .get ("learning_rate" , 3e-4 )
116+ batch_size = params .get ("batch_size" , 64 )
117+ weight_decay = params .get ("weight_decay" , 1e-3 )
118+ optimizer = params .get ("optimizer" , "AdamW" )
119+
120+ # Also ensure default values are set if not provided
121+ self .params .setdefault ("kernel_size" , kernel_size )
122+ self .params .setdefault ("stride" , stride )
123+ self .params .setdefault ("learning_rate" , learning_rate )
124+ self .params .setdefault ("batch_size" , batch_size )
125+ self .params .setdefault ("weight_decay" , weight_decay )
126+ self .params .setdefault ("optimizer" , optimizer )
127+
107128 # Store all parameters passed in
108129 self .params .update (params )
109-
110- # Extract architecture-specific parameters for Backbone creation
111- kernel_size = self .params .get ("kernel_size" , 3 )
112- stride = self .params .get ("stride" , 1 )
113-
114130 self .network = Backbone (
115131 in_channels = self ._input_channels ,
116132 num_classes = self .num_classes ,
@@ -120,10 +136,9 @@ def create_model(self, **params) -> None:
120136
121137 def train (
122138 self ,
123- X_train : List [np .ndarray ],
124- y_train : np .ndarray ,
125- X_val : List [np .ndarray ],
126- y_val : np .ndarray ,
139+ train_loader : DataLoader ,
140+ val_loader : DataLoader ,
141+ config : Optional [TrainingConfig ] = None ,
127142 device : Optional [torch .device ] = None ,
128143 epochs : Optional [int ] = None ,
129144 patience : Optional [int ] = None ,
@@ -139,29 +154,7 @@ def train(
139154 device = device or get_device ()
140155 self .network = self .network .to (device )
141156
142- default_config = TrainingConfig ()
143- config = TrainingConfig (
144- learning_rate = float (self .params .get ("learning_rate" , default_config .learning_rate )),
145- weight_decay = float (self .params .get ("weight_decay" , default_config .weight_decay )),
146- optimizer = self .params .get ("optimizer" , default_config .optimizer ),
147- batch_size = int (self .params .get ("batch_size" , default_config .batch_size )),
148- # Infrastructure params: use provided values or defaults
149- epochs = epochs if epochs is not None else default_config .epochs ,
150- patience = patience if patience is not None else default_config .patience ,
151- min_delta = min_delta if min_delta is not None else default_config .min_delta ,
152- checkpoint_path = checkpoint_path if checkpoint_path is not None else default_config .checkpoint_path ,
153- grad_clip_norm = grad_clip_norm if grad_clip_norm is not None else default_config .grad_clip_norm ,
154- writer = writer if writer is not None else default_config .writer ,
155- )
156-
157- train_loader , val_loader = create_dataloaders (
158- X_train ,
159- y_train ,
160- X_val ,
161- y_val ,
162- batch_size = config .batch_size ,
163- num_workers = num_workers ,
164- )
157+ config = config or TrainingConfig ()
165158
166159 optimizer = self ._build_optimizer (self .network , config )
167160 scheduler = optim .lr_scheduler .OneCycleLR (
@@ -256,7 +249,8 @@ def train(
256249
257250 def predict (
258251 self ,
259- data_loader : DataLoader ,
252+ data : DataLoader | List | np .ndarray ,
253+ labels : List | np .ndarray = None ,
260254 device : Optional [torch .device ] = None ,
261255 return_probabilities : bool = False ,
262256 ) -> torch .Tensor :
@@ -266,6 +260,19 @@ def predict(
266260 network = self .network .to (device )
267261 network .eval ()
268262
263+ # Handle data input - create DataLoader if raw data is provided
264+ if isinstance (data , DataLoader ):
265+ data_loader = data
266+ else :
267+ # Raw data provided - create DataLoader
268+ from framework .data_utils import create_dataloaders
269+ if labels is None :
270+ labels = [0 ] * len (data ) # dummy labels for prediction
271+ batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
272+ _ , data_loader = create_dataloaders (
273+ data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
274+ )
275+
269276 outputs = []
270277 with torch .no_grad ():
271278 for images , _ in data_loader :
@@ -279,8 +286,7 @@ def predict(
279286
280287 def evaluate (
281288 self ,
282- X : List [np .ndarray ],
283- y : np .ndarray ,
289+ data_loader : DataLoader ,
284290 device : Optional [torch .device ] = None ,
285291 criterion : Optional [nn .Module ] = None ,
286292 num_workers : int = 0 ,
@@ -302,6 +308,19 @@ def evaluate(
302308 pin_memory = torch .cuda .is_available (),
303309 )
304310
311+ # Handle data input - create DataLoader if raw data is provided
312+ if isinstance (data , DataLoader ):
313+ data_loader = data
314+ else :
315+ # Raw data provided - create DataLoader
316+ from framework .data_utils import create_dataloaders
317+ if labels is None :
318+ raise ValueError ("labels must be provided when data is not a DataLoader" )
319+ batch_size = getattr (self , 'params' , {}).get ('batch_size' , 64 )
320+ _ , data_loader = create_dataloaders (
321+ data [:1 ], [labels [0 ]], data , labels , batch_size = batch_size
322+ )
323+
305324 criterion = criterion or nn .CrossEntropyLoss ()
306325
307326 total_loss = 0.0
0 commit comments