
Commit fe9bb73

Hard fix: promote buffers to parameters in the model files, avoiding the fix on every model load

1 parent ac7e53d · commit fe9bb73

4 files changed


deeplc/_model_ops.py

Lines changed: 0 additions & 55 deletions
@@ -23,58 +23,6 @@
 logger = logging.getLogger(__name__)
 
 
-# TODO: Implement Lightning?
-
-
-def promote_buffers_to_parameters(
-    model: torch.nn.Module,
-    buffer_indices: list[int] | None = None,
-) -> torch.nn.Module:
-    """
-    Promote ONNX initializer buffers to nn.Parameters so they become trainable.
-
-    ONNX-converted GraphModules (from onnx2torch) store dense/FC layer weights as
-    buffers on an ``initializers`` submodule, making them invisible to the optimizer.
-    This function converts selected buffers to nn.Parameters so they can be fine-tuned.
-
-    Parameters
-    ----------
-    model
-        The loaded GraphModule from onnx2torch.
-    buffer_indices
-        Indices of ``onnx_initializer_*`` buffers to promote. If None, promotes the
-        global feature branch (0-5) and the final dense head (34-45).
-
-    Returns
-    -------
-    torch.nn.Module
-        The same model with buffers promoted to parameters.
-
-    """
-    if buffer_indices is None:
-        # Dense head (34-45) + global feature branch (0-5)
-        buffer_indices = list(range(0, 6)) + list(range(34, 46))
-
-    init_mod = dict(model.named_modules()).get("initializers")
-    if init_mod is None:
-        logger.debug("No 'initializers' submodule found; skipping buffer promotion.")
-        return model
-
-    promoted = 0
-    for idx in buffer_indices:
-        name = f"onnx_initializer_{idx}"
-        if name in init_mod._buffers:
-            buf = init_mod._buffers.pop(name)
-            init_mod._parameters[name] = torch.nn.Parameter(buf)
-            promoted += 1
-
-    logger.debug(
-        f"Promoted {promoted} buffers to parameters. "
-        f"Total trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-    )
-    return model
-
-
 def load_model(
     model: torch.nn.Module | PathLike | str | None = None,
     device: str | None = None,
@@ -148,9 +96,6 @@ def train(
     """
     model = load_model(model, device)
 
-    # Promote ONNX initializer buffers (dense head) to trainable parameters
-    model = promote_buffers_to_parameters(model)
-
     # Parse datasets; setup loaders
     train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
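
Why the promotion mattered in the first place: PyTorch optimizers iterate over Module.parameters(), which yields nn.Parameter objects only; registered buffers are excluded, so weights stored as ONNX initializer buffers never receive gradient updates. A minimal self-contained illustration of that behavior (a generic example, not DeepLC's actual graph):

import torch

m = torch.nn.Module()
m.register_buffer("w_buf", torch.randn(3))      # stored as a buffer
m.w_param = torch.nn.Parameter(torch.randn(3))  # registered as a parameter

print([name for name, _ in m.named_parameters()])  # ['w_param'] -- the buffer is invisible
print([name for name, _ in m.named_buffers()])     # ['w_buf']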
Binary file not shown.
Binary file not shown.
Binary file not shown.
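
The three binary files above are presumably the bundled model files, re-saved with the promotion already baked in so that loading no longer needs a runtime fix-up. A sketch of the kind of one-off script that could produce them, reusing the index ranges from the removed helper; the onnx2torch usage follows the docstring above, while promote_once and the file paths are hypothetical:

import torch
from onnx2torch import convert  # converts an ONNX graph to a torch.fx GraphModule

def promote_once(onnx_path: str, out_path: str) -> None:
    model = convert(onnx_path)  # dense/FC weights land as buffers on 'initializers'
    init_mod = dict(model.named_modules()).get("initializers")
    if init_mod is not None:
        # Same ranges as the removed helper: global feature branch (0-5)
        # and the final dense head (34-45).
        for idx in list(range(0, 6)) + list(range(34, 46)):
            name = f"onnx_initializer_{idx}"
            if name in init_mod._buffers:
                init_mod._parameters[name] = torch.nn.Parameter(init_mod._buffers.pop(name))
    torch.save(model, out_path)  # persist the full module; later loads need no promotion

promote_once("deeplc_model.onnx", "deeplc_model.pt")  # hypothetical paths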
