11import torch
2+ import torch .nn as nn
23from tqdm import tqdm
34from typing import Tuple , Optional
45from pathlib import Path
@@ -13,64 +14,56 @@ class Checkpoint:
1314 def __init__ (self , model_path : str ):
1415 self .model_path = Path (model_path )
1516 self .best_val_acc = 0.0
16-
17+
1718 def save_if_better (
1819 self ,
1920 model : Module ,
2021 optimizer : Optimizer ,
2122 epoch : int ,
2223 val_acc : float ,
2324 train_acc : float ,
24- ** kwargs
25+ ** kwargs ,
2526 ) -> bool :
2627 """Save checkpoint if validation accuracy improved."""
2728 if val_acc > self .best_val_acc :
2829 self .best_val_acc = val_acc
2930 self .model_path .parent .mkdir (parents = True , exist_ok = True )
30-
31- checkpoint = {
32- 'epoch' : epoch ,
33- 'model_state_dict' : model .state_dict (),
34- 'optimizer_state_dict' : optimizer . state_dict () ,
35- ' val_acc' : val_acc ,
36- ' train_acc' : train_acc ,
37- ** kwargs
31+
32+ checkpoint_data = {
33+ "model_state_dict" : model . state_dict () ,
34+ "optimizer_state_dict" : optimizer .state_dict (),
35+ "epoch" : epoch ,
36+ " val_acc" : val_acc ,
37+ " train_acc" : train_acc ,
38+ ** kwargs ,
3839 }
39-
40- torch .save (checkpoint , str (self .model_path ))
40+ torch .save (checkpoint_data , self .model_path )
4141 return True
4242 return False
43-
44- def load (self , model : Module , optimizer : Optimizer ) -> dict :
45- """Load checkpoint from disk."""
46- checkpoint = torch .load (str (self .model_path ), map_location = 'cpu' )
47- model .load_state_dict (checkpoint ['model_state_dict' ])
48- optimizer .load_state_dict (checkpoint ['optimizer_state_dict' ])
49- self .best_val_acc = checkpoint .get ('val_acc' , 0.0 )
50- return checkpoint
5143
5244
class EarlyStopping:
    """Stop training once validation loss has stopped improving.

    An improvement means the loss drops below the best loss by more than
    ``min_delta``; ``patience`` consecutive non-improving calls trigger a
    stop. The best validation accuracy seen so far is tracked as a side
    statistic regardless of whether the loss improved.
    """

    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # None until the first call, so the first loss always "improves".
        self.best_loss: Optional[float] = None
        self.best_acc: float = 0.0

    def __call__(self, val_loss: float, val_acc: Optional[float] = None) -> bool:
        """Returns True if training should stop."""
        # Accuracy bookkeeping happens on every call, improved or not.
        if val_acc is not None:
            self.best_acc = max(self.best_acc, val_acc)

        improved = (
            self.best_loss is None
            or val_loss < self.best_loss - self.min_delta
        )
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience
7467
7568
def train_epoch(
    # NOTE(review): the first two parameters fall in a diff gap and are
    # reconstructed from usage below (model(images), len(train_loader)) —
    # confirm names/annotations against the full file.
    model: Module,
    train_loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: Optimizer,
    device: torch.device,
    scheduler: Optional[LRScheduler] = None,
    epoch: int = 0,
    grad_clip_norm: float = 1.0,
    writer: Optional[SummaryWriter] = None,
) -> Tuple[float, float]:
    """Trains the model for one epoch and returns the epoch loss and accuracy.

    Args:
        model: Model to train (switched to train mode here).
        train_loader: Yields (images, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        optimizer: Optimizer stepped once per batch.
        device: Device batches are moved to.
        scheduler: Optional per-batch LR scheduler; stepped after each
            optimizer step when provided.
        epoch: Epoch index, used for TensorBoard step numbering.
        grad_clip_norm: Max gradient norm for clipping.
        writer: Optional TensorBoard writer for batch/epoch metrics.

    Returns:
        (epoch_loss, epoch_acc): mean batch loss and accuracy in [0, 1].
    """
    # model.train() is the idiomatic spelling of
    # nn.Module.train(model, mode=True).
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(
        tqdm(train_loader, desc="Training", leave=False)
    ):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_norm)
        optimizer.step()
        # Per-batch scheduler stepping; skipped when no scheduler is used.
        if scheduler is not None:
            scheduler.step()

        loss_value = loss.item()  # single GPU->CPU sync per batch
        running_loss += loss_value

        _, predicted = torch.max(outputs.data, 1)
        batch_correct = (predicted == labels).sum().item()
        total += labels.size(0)
        correct += batch_correct

        # Log every 10th batch to keep TensorBoard traffic low.
        # Tag strings must not carry stray leading spaces, or the scalars
        # land under different tags than the epoch-level ones below.
        if batch_idx % 10 == 0 and writer is not None:
            batch_total = labels.size(0)
            batch_acc = 100 * batch_correct / batch_total
            current_lr = optimizer.param_groups[0]["lr"]
            step = epoch * len(train_loader) + batch_idx
            writer.add_scalar("train/batch_loss", loss_value, step)
            writer.add_scalar("train/batch_accuracy", batch_acc, step)
            writer.add_scalar("train/learning_rate", current_lr, step)

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = correct / total

    if writer is not None:
        writer.add_scalar("train/epoch_loss", epoch_loss, epoch)
        # Accuracy is logged as a percentage for readability.
        writer.add_scalar("train/epoch_accuracy", epoch_acc * 100, epoch)

    return epoch_loss, epoch_acc
138125
def validate(
    # NOTE(review): the signature head and the batch-loop header fall in
    # diff gaps and are reconstructed from the visible body — confirm
    # against the full file.
    model: Module,
    val_loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int = 0,
    writer: Optional[SummaryWriter] = None,
) -> Tuple[float, float]:
    """Validates the model and returns the epoch loss and accuracy.

    Args:
        model: Model to evaluate (switched to eval mode here).
        val_loader: Yields (images, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        device: Device batches are moved to.
        epoch: Epoch index used for TensorBoard step numbering.
        writer: Optional TensorBoard writer for epoch metrics.

    Returns:
        (epoch_loss, epoch_acc): mean batch loss and accuracy in [0, 1].
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(val_loader)
    epoch_acc = correct / total

    if writer is not None:
        # Tags must match the "val/..." namespace exactly (no stray
        # leading spaces) so they group with the train/ scalars in TB.
        writer.add_scalar("val/epoch_loss", epoch_loss, epoch)
        writer.add_scalar("val/epoch_accuracy", epoch_acc * 100, epoch)

    return epoch_loss, epoch_acc
174-