88import torch
99import torch .nn as nn
1010import torch .optim as optim
11+ from sklearn .metrics import (
12+ accuracy_score ,
13+ classification_report ,
14+ f1_score ,
15+ roc_auc_score ,
16+ )
1117from torch .utils .data import DataLoader
1218from torch .utils .tensorboard import SummaryWriter
1319
@@ -34,7 +40,7 @@ class TrainingConfig:
3440 checkpoint_path : Path = MODEL_PATH
3541 grad_clip_norm : float = 1.0
3642 writer : Optional [SummaryWriter ] = None
37- batch_size : int = 64
43+ batch_size : int = 128
3844
3945
4046class Backbone (nn .Module ):
@@ -48,34 +54,39 @@ def __init__(
4854 stride : int = 1 ,
4955 ) -> None :
5056 super ().__init__ ()
51-
52- self .features = nn .Sequential (
57+ padding = kernel_size // 2
58+ self .block1 = nn .Sequential (
5359 nn .Conv2d (
54- in_channels , 32 , kernel_size = kernel_size , stride = stride , padding = 1
60+ in_channels , 16 , kernel_size = kernel_size , stride = stride , padding = padding
5561 ),
56- nn .BatchNorm2d (32 ),
62+ nn .BatchNorm2d (16 ),
5763 nn .ReLU (inplace = True ),
64+ nn .MaxPool2d (kernel_size = 2 , stride = 2 ),
65+ )
66+
67+ self .block2 = nn .Sequential (
5868 nn .Conv2d (
59- 32 , 64 , kernel_size = kernel_size , stride = stride , padding = 1
69+ 16 , 32 , kernel_size = kernel_size , stride = stride , padding = padding
6070 ),
61- nn .BatchNorm2d (64 ),
71+ nn .BatchNorm2d (32 ),
6272 nn .ReLU (inplace = True ),
73+ nn .MaxPool2d (kernel_size = 2 , stride = 2 ),
74+ )
75+
76+ self .block3 = nn .Sequential (
6377 nn .Conv2d (
64- 64 ,
65- 128 ,
66- kernel_size = kernel_size ,
67- stride = stride ,
68- padding = 1 ,
78+ 32 , 64 , kernel_size = kernel_size , stride = stride , padding = padding
6979 ),
70- nn .BatchNorm2d (128 ),
80+ nn .BatchNorm2d (64 ),
7181 nn .ReLU (inplace = True ),
7282 nn .AdaptiveAvgPool2d (1 ),
7383 )
7484
85+ self .features = nn .Sequential (self .block1 , self .block2 , self .block3 )
86+
7587 self .classifier = nn .Sequential (
7688 nn .Flatten (),
77- nn .Dropout (),
78- nn .Linear (128 , num_classes ),
89+ nn .Linear (64 , num_classes ),
7990 )
8091
8192 def forward (self , x : torch .Tensor ) -> torch .Tensor :
@@ -121,6 +132,7 @@ def train(
121132 grad_clip_norm : Optional [float ] = None ,
122133 writer : Optional [SummaryWriter ] = None ,
123134 num_workers : int = 2 ,
135+ verbose : bool = True ,
124136 ) -> Dict [str , float ]:
125137 if self .network is None :
126138 raise RuntimeError ("Train called before model is initialized" )
@@ -166,8 +178,9 @@ def train(
166178 checkpoint = Checkpoint (str (config .checkpoint_path ))
167179
168180 total_params , trainable_params = count_parameters (self .network )
169- print (f"Total parameters: { total_params :,} " )
170- print (f"Trainable parameters: { trainable_params :,} " )
181+ if verbose :
182+ print (f"Total parameters: { total_params :,} " )
183+ print (f"Trainable parameters: { trainable_params :,} " )
171184
172185 history = {
173186 "train_loss" : [],
@@ -177,7 +190,8 @@ def train(
177190 }
178191
179192 for epoch in range (1 , config .epochs + 1 ):
180- print (f"\n Epoch { epoch } /{ config .epochs } " )
193+ if verbose :
194+ print (f"\n Epoch { epoch } /{ config .epochs } " )
181195 train_loss , train_acc = train_epoch (
182196 self .network ,
183197 train_loader ,
@@ -188,6 +202,7 @@ def train(
188202 epoch = epoch ,
189203 grad_clip_norm = config .grad_clip_norm ,
190204 writer = config .writer ,
205+ verbose = verbose ,
191206 )
192207 val_loss , val_acc = validate (
193208 self .network ,
@@ -196,35 +211,40 @@ def train(
196211 device ,
197212 epoch = epoch ,
198213 writer = config .writer ,
214+ verbose = verbose ,
199215 )
200216
201217 history ["train_loss" ].append (train_loss )
202218 history ["train_acc" ].append (train_acc )
203219 history ["val_loss" ].append (val_loss )
204220 history ["val_acc" ].append (val_acc )
205221
206- print (
207- f"Train Loss: { train_loss :.4f} , Train Acc: { train_acc :.4f} ({ train_acc * 100 :.2f} %)"
208- )
209- print (
210- f"Val Loss: { val_loss :.4f} , Val Acc: { val_acc :.4f} ({ val_acc * 100 :.2f} %)"
211- )
222+ if verbose :
223+ print (
224+ f"Train Loss: { train_loss :.4f} , Train Acc: { train_acc :.4f} ({ train_acc * 100 :.2f} %)"
225+ )
226+ print (
227+ f"Val Loss: { val_loss :.4f} , Val Acc: { val_acc :.4f} ({ val_acc * 100 :.2f} %)"
228+ )
212229
213230 if checkpoint .save_if_better (
214231 self .network , optimizer , epoch , val_acc , train_acc
215232 ):
216- print (
217- f"Saved best model (val_acc={ val_acc :.4f} ) to { config .checkpoint_path } "
218- )
233+ if verbose :
234+ print (
235+ f"Saved best model (val_acc={ val_acc :.4f} ) to { config .checkpoint_path } "
236+ )
219237
220238 if early_stopper (val_loss , val_acc ):
221- print (f"\n Early stopping at epoch { epoch } " )
239+ if verbose :
240+ print (f"\n Early stopping at epoch { epoch } " )
222241 break
223242
224- print ("\n Training complete!" )
225- print (
226- f"Best val acc: { checkpoint .best_val_acc :.4f} ({ checkpoint .best_val_acc * 100 :.2f} %)"
227- )
243+ if verbose :
244+ print ("\n Training complete!" )
245+ print (
246+ f"Best val acc: { checkpoint .best_val_acc :.4f} ({ checkpoint .best_val_acc * 100 :.2f} %)"
247+ )
228248
229249 return {
230250 "best_val_acc" : checkpoint .best_val_acc ,
@@ -285,8 +305,9 @@ def evaluate(
285305 criterion = criterion or nn .CrossEntropyLoss ()
286306
287307 total_loss = 0.0
288- total_correct = 0
289- total_samples = 0
308+ all_predictions = []
309+ all_labels = []
310+ all_probas = []
290311
291312 with torch .no_grad ():
292313 for images , labels in data_loader :
@@ -297,19 +318,45 @@ def evaluate(
297318
298319 total_loss += loss .item ()
299320 preds = torch .argmax (logits , dim = 1 )
300- total_correct += (preds == labels ).sum ().item ()
301- total_samples += labels .size (0 )
321+ probas = torch .softmax (logits , dim = 1 )
322+
323+ all_predictions .extend (preds .cpu ().numpy ())
324+ all_labels .extend (labels .cpu ().numpy ())
325+ all_probas .extend (probas .cpu ().numpy ())
326+
327+ y_true = np .array (all_labels )
328+ y_pred = np .array (all_predictions )
329+ y_proba = np .array (all_probas )
330+
331+ accuracy = accuracy_score (y_true , y_pred )
332+ report = classification_report (
333+ y_true , y_pred , output_dict = True , zero_division = 0
334+ )
335+
336+ precision_macro = report ["macro avg" ]["precision" ]
337+ recall_macro = report ["macro avg" ]["recall" ]
338+ f1_macro = report ["macro avg" ]["f1-score" ]
339+ f1_micro = report .get ("micro avg" , {}).get ("f1-score" , f1_score (y_true , y_pred , average = "micro" , zero_division = 0 ))
340+
341+ roc_auc = roc_auc_score (y_true , y_proba , average = "macro" , multi_class = "ovr" )
302342
303343 avg_loss = total_loss / len (data_loader )
304- accuracy = total_correct / total_samples if total_samples else 0.0
305344
306- return {"loss" : avg_loss , "accuracy" : accuracy }
345+ return {
346+ "loss" : avg_loss ,
347+ "accuracy" : accuracy ,
348+ "precision_macro" : precision_macro ,
349+ "recall_macro" : recall_macro ,
350+ "f1_macro" : f1_macro ,
351+ "f1_micro" : f1_micro ,
352+ "roc_auc" : roc_auc ,
353+ }
307354
308355 def get_param_space (self ) -> Dict [str , ParamSpace ]:
309356 return {
310357 "kernel_size" : ParamSpace .integer (min_val = 3 , max_val = 5 , default = 3 ),
311358 "stride" : ParamSpace .integer (min_val = 1 , max_val = 3 , default = 1 ),
312- "learning_rate" : ParamSpace .float_range (
359+ "learning_rate" : ParamSpace .float_log_range (
313360 min_val = 1e-5 , max_val = 1e-2 , default = 3e-4
314361 ),
315362 "batch_size" : ParamSpace .categorical (choices = [16 , 32 , 64 , 128 ], default = 64 ),