Commit 5e9890c

Update _model_ops.py

Make prediction head trainable

1 parent: 2b0aea9

1 file changed: deeplc/_model_ops.py (76 additions, 7 deletions)
@@ -19,25 +19,80 @@
 # TODO: Implement Lightning?


+def promote_buffers_to_parameters(
+    model: torch.nn.Module,
+    buffer_indices: list[int] | None = None,
+) -> torch.nn.Module:
+    """
+    Promote ONNX initializer buffers to nn.Parameters so they become trainable.
+
+    ONNX-converted GraphModules (from onnx2torch) store dense/FC layer weights as
+    buffers on an ``initializers`` submodule, making them invisible to the optimizer.
+    This function converts selected buffers to nn.Parameters so they can be fine-tuned.
+
+    Parameters
+    ----------
+    model
+        The loaded GraphModule from onnx2torch.
+    buffer_indices
+        Indices of ``onnx_initializer_*`` buffers to promote. If None, promotes the
+        global feature branch (0-5) and the final dense head (34-45).
+
+    Returns
+    -------
+    torch.nn.Module
+        The same model with buffers promoted to parameters.
+
+    """
+    if buffer_indices is None:
+        # Dense head (34-45) + global feature branch (0-5)
+        buffer_indices = list(range(0, 6)) + list(range(34, 46))
+
+    init_mod = dict(model.named_modules()).get("initializers")
+    if init_mod is None:
+        logger.debug("No 'initializers' submodule found; skipping buffer promotion.")
+        return model
+
+    promoted = 0
+    for idx in buffer_indices:
+        name = f"onnx_initializer_{idx}"
+        if name in init_mod._buffers:
+            buf = init_mod._buffers.pop(name)
+            init_mod._parameters[name] = torch.nn.Parameter(buf)
+            promoted += 1
+
+    logger.info(
+        f"Promoted {promoted} buffers to parameters. "
+        f"Total trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
+    )
+    return model
+
+
 def load_model(
     model: torch.nn.Module | PathLike | str | None = None,
     device: str | None = None,
 ) -> torch.nn.Module:
     """Load a model from a file or return a randomly initialized model if none is provided."""
     # If device is not specified, use the default device (GPU if available, else CPU)
-    selected_device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    selected_device = device or torch.device(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    )

     # Load model from file if a path is provided
     if isinstance(model, str | Path):
-        loaded_model = torch.load(model, weights_only=False, map_location=selected_device)
+        loaded_model = torch.load(
+            model, weights_only=False, map_location=selected_device
+        )
     elif isinstance(model, torch.nn.Module):
         loaded_model = model
     elif model is None:
         # Initialize a new model with default architecture
         loaded_model = DeepLCModel()
         logger.debug("Initialized new DeepLCModel with default architecture")
     else:
-        raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
+        raise TypeError(
+            f"Expected a PyTorch Module or a file path, got {type(model)} instead."
+        )

     # Ensure the model is on the specified device
     loaded_model.to(selected_device)
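
Note: a minimal sketch of what the promotion does, using a toy module in place of the onnx2torch GraphModule. The ToyGraphModule and Initializers classes are illustrative stand-ins, and the import path is an assumption based on the file changed in this commit; none of this is part of the diff itself.

import torch

from deeplc._model_ops import promote_buffers_to_parameters  # assumed import path


class Initializers(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # onnx2torch registers weights as buffers: present in state_dict,
        # but invisible to model.parameters() and hence to the optimizer.
        self.register_buffer("onnx_initializer_0", torch.randn(4, 4))


class ToyGraphModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.initializers = Initializers()


model = ToyGraphModule()
assert sum(p.numel() for p in model.parameters()) == 0  # buffer is not trainable

model = promote_buffers_to_parameters(model, buffer_indices=[0])
assert sum(p.numel() for p in model.parameters()) == 16  # promoted, now trainable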
@@ -92,6 +147,11 @@ def train(
     """
     model = load_model(model, device)

+    # Promote ONNX initializer buffers (dense head) to trainable parameters
+    model = promote_buffers_to_parameters(model)
+
+    # Freeze layers if requested
+
     # Freeze layers if requested
     if trainable_layers is not None:
         _freeze_layers(model, trainable_layers)
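
Ordering note (editorial, not from the commit): the promotion runs before _get_optimizer constructs the Adam optimizer below, which matters because PyTorch optimizers capture their parameter list eagerly at construction time. A small sketch of the pitfall this ordering avoids:

import torch

m = torch.nn.Linear(2, 2)
opt = torch.optim.Adam(m.parameters(), lr=1e-3)

# A parameter registered after construction is never seen by the optimizer.
m.extra = torch.nn.Parameter(torch.zeros(2))
assert sum(len(g["params"]) for g in opt.param_groups) == 2  # weight + bias only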
@@ -102,7 +162,10 @@ def train(
         train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
     )
     val_loader = DataLoader(
-        validation_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
+        validation_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
     )

     optimizer = _get_optimizer(model, learning_rate)
@@ -145,7 +208,9 @@ def predict(
 ) -> torch.Tensor:
     """Predict using the model for the given dataset."""
     model = load_model(model, device)
-    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    data_loader = DataLoader(
+        data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+    )
     predictions = _predict_epoch(model, data_loader, device)
     return predictions.cpu().detach()

@@ -159,7 +224,9 @@ def evaluate(
 ) -> float:
     """Evaluate the model on the given dataset."""
     model = load_model(model, device)
-    data_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+    data_loader = DataLoader(
+        data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+    )
     loss_fn = torch.nn.L1Loss()
     avg_loss = _validate_epoch(model, data_loader, loss_fn, device)
     return avg_loss
@@ -171,7 +238,9 @@ def _freeze_layers(model: torch.nn.Module, unfreeze_keyword: str) -> None:
         param.requires_grad = unfreeze_keyword in name


-def _get_optimizer(model: torch.nn.Module, learning_rate: float) -> torch.optim.Optimizer:
+def _get_optimizer(
+    model: torch.nn.Module, learning_rate: float
+) -> torch.optim.Optimizer:
     return torch.optim.Adam(
         filter(lambda p: p.requires_grad, model.parameters()),
         lr=learning_rate,
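
The filter on p.requires_grad is what ties the two mechanisms together: promoted buffers enter the optimizer as trainable parameters, while anything _freeze_layers disabled stays out. A minimal sketch of that behavior (illustrative, not from the commit):

import torch

layer = torch.nn.Linear(8, 1)
layer.bias.requires_grad = False  # e.g. frozen by _freeze_layers

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, layer.parameters()),
    lr=1e-3,
)
assert len(optimizer.param_groups[0]["params"]) == 1  # weight only; bias skipped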
