 # TODO: Implement Lightning?


+def promote_buffers_to_parameters(
+    model: torch.nn.Module,
+    buffer_indices: list[int] | None = None,
+) -> torch.nn.Module:
+    """
+    Promote ONNX initializer buffers to nn.Parameters so they become trainable.
+
+    ONNX-converted GraphModules (from onnx2torch) store dense/FC layer weights as
+    buffers on an ``initializers`` submodule, making them invisible to the optimizer.
+    This function converts selected buffers to nn.Parameters so they can be fine-tuned.
+
+    Parameters
+    ----------
+    model
+        The loaded GraphModule from onnx2torch.
+    buffer_indices
+        Indices of ``onnx_initializer_*`` buffers to promote. If None, promotes the
+        global feature branch (0-5) and the final dense head (34-45).
+
+    Returns
+    -------
+    torch.nn.Module
+        The same model with buffers promoted to parameters.
+
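+    Examples
+    --------
+    A minimal sketch; the ONNX file name is a placeholder and the default
+    index ranges are assumed to match the model's dense head:
+
+    >>> from onnx2torch import convert
+    >>> gm = convert("model.onnx")
+    >>> gm = promote_buffers_to_parameters(gm)
+    >>> any(p.requires_grad for p in gm.parameters())
+    True
+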
+    """
+    if buffer_indices is None:
+        # Global feature branch (0-5) + final dense head (34-45)
+        buffer_indices = list(range(0, 6)) + list(range(34, 46))
+
+    init_mod = dict(model.named_modules()).get("initializers")
+    if init_mod is None:
+        logger.debug("No 'initializers' submodule found; skipping buffer promotion.")
+        return model
+
+    promoted = 0
+    for idx in buffer_indices:
+        name = f"onnx_initializer_{idx}"
+        if name in init_mod._buffers:
+            # Move the tensor out of _buffers and re-register it as a Parameter
+            buf = init_mod._buffers.pop(name)
+            init_mod._parameters[name] = torch.nn.Parameter(buf)
+            promoted += 1
+
+    logger.info(
+        f"Promoted {promoted} buffers to parameters. "
+        f"Total trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
+    )
+    return model
+
+
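+# Usage sketch for the helper above ("model.pt" is a placeholder path; only the
+# trainable, requires_grad parameters reach the optimizer):
+#
+#     model = load_model("model.pt")
+#     model = promote_buffers_to_parameters(model)
+#     optimizer = torch.optim.Adam(
+#         (p for p in model.parameters() if p.requires_grad), lr=1e-4
+#     )
+
+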
 def load_model(
     model: torch.nn.Module | PathLike | str | None = None,
     device: str | None = None,
 ) -> torch.nn.Module:
     """Load a model from a file or return a randomly initialized model if none is provided."""
     # If device is not specified, use the default device (GPU if available, else CPU)
-    selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    selected_device = device or torch.device(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )

     # Load model from file if a path is provided
     if isinstance(model, str | Path):
-        loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
+        loaded_model = torch.load(
+            model, weights_only=False, map_location=selected_device
+        )
     elif isinstance(model, torch.nn.Module):
         loaded_model = model
     elif model is None:
         # Initialize a new model with default architecture
         loaded_model = DeepLCModel()
         logger.debug("Initialized new DeepLCModel with default architecture")
     else:
-        raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
+        raise TypeError(
+            f"Expected a PyTorch Module or a file path, got {type(model)} instead."
+        )

     # Ensure the model is on the specified device
     loaded_model.to(selected_device)
@@ -92,6 +147,11 @@ def train(
     """
     model = load_model(model, device)

+    # Promote ONNX initializer buffers (dense head) to trainable parameters
+    model = promote_buffers_to_parameters(model)
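+    # NOTE: promotion runs before any freezing so that the promoted Parameters
+    # are visible to _freeze_layers and to the optimizer created below.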
+
     # Freeze layers if requested
     if trainable_layers is not None:
         _freeze_layers(model, trainable_layers)
@@ -102,7 +162,10 @@ def train(
         train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
     )
     val_loader = DataLoader(
-        validation_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        validation_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
     )

     optimizer = _get_optimizer(model, learning_rate)
@@ -145,7 +208,9 @@ def predict(
 ) -> torch.Tensor:
     """Predict using the model for the given dataset."""
     model = load_model(model, device)
-    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    data_loader = DataLoader(
+        data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+    )
     predictions = _predict_epoch(model, data_loader, device)
     return predictions.cpu().detach()

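
 # Usage sketch ("model.pt" and `dataset` stand in for a saved model path and a
 # torch Dataset; per the body above, predictions come back as a detached CPU tensor):
 #
 #     preds = predict(model="model.pt", data=dataset, batch_size=256)
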
@@ -159,7 +224,9 @@ def evaluate(
 ) -> float:
     """Evaluate the model on the given dataset."""
     model = load_model(model, device)
-    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    data_loader = DataLoader(
+        data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+    )
     loss_fn = torch.nn.L1Loss()
     avg_loss = _validate_epoch(model, data_loader, loss_fn, device)
     return avg_loss
@@ -171,7 +238,9 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
         # Keep a parameter trainable only if its qualified name contains the
         # keyword, e.g. unfreeze_keyword="initializers" leaves just the promoted
         # initializer Parameters trainable
         param.requires_grad = unfreeze_keyword in name


-def _get_optimizer(model: torch.nn.Module, learning_rate: float) -> torch.optim.Optimizer:
+def _get_optimizer(
+    model: torch.nn.Module, learning_rate: float
+) -> torch.optim.Optimizer:
     return torch.optim.Adam(
         # Optimize only the parameters still marked trainable after promotion/freezing
         filter(lambda p: p.requires_grad, model.parameters()),
         lr=learning_rate,