 logger = logging.getLogger(__name__)
 
 
-# TODO: Implement Lightning?
-
-
-def promote_buffers_to_parameters(
-    model: torch.nn.Module,
-    buffer_indices: list[int] | None = None,
-) -> torch.nn.Module:
-    """
-    Promote ONNX initializer buffers to nn.Parameters so they become trainable.
-
-    ONNX-converted GraphModules (from onnx2torch) store dense/FC layer weights as
-    buffers on an ``initializers`` submodule, making them invisible to the optimizer.
-    This function converts selected buffers to nn.Parameters so they can be fine-tuned.
-
-    Parameters
-    ----------
-    model
-        The loaded GraphModule from onnx2torch.
-    buffer_indices
-        Indices of ``onnx_initializer_*`` buffers to promote. If None, promotes the
-        global feature branch (0-5) and the final dense head (34-45).
-
-    Returns
-    -------
-    torch.nn.Module
-        The same model with buffers promoted to parameters.
-
-    """
-    if buffer_indices is None:
-        # Dense head (34-45) + global feature branch (0-5)
-        buffer_indices = list(range(0, 6)) + list(range(34, 46))
-
-    init_mod = dict(model.named_modules()).get("initializers")
-    if init_mod is None:
-        logger.debug("No 'initializers' submodule found; skipping buffer promotion.")
-        return model
-
-    promoted = 0
-    for idx in buffer_indices:
-        name = f"onnx_initializer_{idx}"
-        if name in init_mod._buffers:
-            buf = init_mod._buffers.pop(name)
-            init_mod._parameters[name] = torch.nn.Parameter(buf)
-            promoted += 1
-
-    logger.debug(
-        f"Promoted {promoted} buffers to parameters. "
-        f"Total trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-    )
-    return model
-
-
 def load_model(
     model: torch.nn.Module | PathLike | str | None = None,
     device: str | None = None,
@@ -148,9 +96,6 @@ def train( |
     """
     model = load_model(model, device)
 
-    # Promote ONNX initializer buffers (dense head) to trainable parameters
-    model = promote_buffers_to_parameters(model)
-
     # Parse datasets; setup loaders
     train_loader = DataLoader(
         train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers