fix: correct param counts for modules shared across parents (#327) #396
Mikyx-1 wants to merge 2 commits into
…#327)

A module instance shared by several parents (e.g. one nn.ReLU() passed into every block, as in the reported VNet) was counted incorrectly, inflating the total parameters, especially when combined with nested ModuleLists.

Two root causes, both stemming from a shared module having one parent recorded instead of many:

1. Hierarchy (torchinfo.py): the pre-hook captured (var_name, depth, parent_info) at registration time and kept only the last parent, so every execution of a shared module reported the wrong parent. This scrambled the layer tree and mis-grouped children. Fixed by resolving the parent dynamically at execution time: accumulate every structural context a module is reached through, maintain a runtime call stack via the pre/post hooks, and select the context whose nearest executing ancestor is the current stack top. Single-parent modules are unchanged.

2. Counting (layer_info.py): leftover_params() excluded recursive children from its subtraction, re-attributing a recursive child's params (already counted at their real occurrence) to the parent, which counted a shared parameterized module once per parent. Fixed with a shared _leftover() helper that subtracts each distinct child once (keyed by layer_id) and skips recursive subtrees.

Adds the SharedModuleInNestedList fixture and a regression test. Verified no behavioral change for existing models (RecursiveNet, ReuseReLU, ReuseLinear, SimpleRNN, etc. all produce identical output).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
    return [input_size]

    LayerContext = tuple[str, int, "LayerInfo | None"]

move new types to the top of the file
AI-generated code review:

Critical Feedback & Code Review

✅ Correctness: Generally Sound, But Assumptions Warrant Scrutiny

1. Hook Ordering Assumption – POTENTIAL RISK

    # From construct_hook (line 647-650)
    if module_stack:
        module_stack.pop()

Issue: The code assumes LIFO (Last-In-First-Out) hook firing order, but PyTorch's hook documentation doesn't guarantee this formally. The comment says:

Critical Question: What if a user registers custom hooks? PyTorch 2.0+ lets a hook be registered ahead of existing ones (`prepend=True`), and hooks from different sources could fire in unexpected orders. Scenarios where this could break:

Recommendation:

    # Safer: match the module explicitly
    if module_stack and module_stack[-1].module is module:
        module_stack.pop()
    else:
        # Log warning or raise error
        raise RuntimeError(f"Stack mismatch: expected {module}, got {module_stack[-1].module if module_stack else 'empty'}")

2. Container Module Assumption – ARCHITECTURAL FLAW

    # From resolve_layer_context (line 596-598)
    # Walk up past containers (ModuleList/ModuleDict have no forward and never
    # appear on the runtime stack) to the nearest currently-executing ancestor.

Hard-coded assumption: The code assumes that `ModuleList`/`ModuleDict` never appear on the runtime stack. This is correct for standard PyTorch, but breaks with custom containers.

Example failure case:

    class ExecutableModuleList(nn.ModuleList):
        def forward(self, x):
            # Custom logic
            return x
    # This would appear on the stack, breaking the assumption

Recommendation: Make it configurable or add runtime detection:

    # Check if module has a non-trivial forward (not just nn.ModuleList default)
    NON_EXECUTING_CONTAINERS = (nn.ModuleList, nn.ModuleDict, nn.Sequential)
    # But note: Sequential *can* execute its forward

3. Parent Resolution Logic – EDGE CASE NOT COVERED

    # From resolve_layer_context (line 595-608)
    for context in contexts:
        parent = context[2]
        while parent is not None and id(parent.module) not in on_stack:
            parent = parent.parent_info
        if parent is not None and parent.module is top.module:
            return context
    return contexts[0]  # Fallback to first context

Problem: If none of the contexts match, it silently falls back to the first one. This can mask bugs. Scenario:

    class Diamond(nn.Module):
        def __init__(self):
            super().__init__()
            shared = nn.Linear(4, 4)
            # shared is reached through TWO independent parent chains
            self.path_a = shared
            self.path_b = shared
        def forward(self, x):
            x = self.path_a(x)
            x = self.path_b(x)  # During second call, which context is "correct"?
            return x

Both paths are equally valid, but the algorithm picks the first context, which may be misleading. Recommendation:

    matched_context = contexts[0]  # default
    for context in contexts:
        parent = context[2]
        while parent is not None and id(parent.module) not in on_stack:
            parent = parent.parent_info
        if parent is not None and parent.module is top.module:
            matched_context = context
            break
    # Optional: warn if we couldn't find an exact match
    if matched_context is contexts[0] and len(contexts) > 1:
        # Could log: "Ambiguous parent context, picking first"
        pass
    return matched_context

🔴 Code Readability Issues

4. Unclear Variable Naming:

| Issue | Severity | Impact | Mitigation |
|---|---|---|---|
| Hook firing order assumption | HIGH | Silent failures if hooks interleave | Runtime verification |
| Container module hard-coding | MEDIUM | Breaks with custom containers | Configuration/detection |
| Silent fallback in context resolution | MEDIUM | Masking ambiguity | Log warnings |
| Manual loop index management | LOW | Maintenance burden | Refactor to iterator |
| No type-safe context passing | LOW | IDE support, debugging | Use dataclass |
| Missing invariant checks | MEDIUM | Hard-to-debug failures | Add assertions |
| Implicit state management | MEDIUM | Resource leaks | Use context managers |
Recommendation Before Merge
- (Minor) Move type definitions to file top [already noted in review]
- 🔴 (Must-have) Add defensive check: `assert module_stack[-1].module is module` in post-hook
- 🟡 (Should-have) Document the LIFO hook assumption + failure modes
- 🟡 (Nice-to-have) Convert `LayerContext` to dataclass for clarity
- 🟡 (Nice-to-have) Add test for diamond-shaped sharing patterns
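
Two of these recommendations, the explicit post-hook check and a dataclass-based `LayerContext`, could look roughly like the sketch below. It is illustrative only: `LayerInfo` is reduced to a two-field stand-in, `checked_pop` is a hypothetical helper name, and the field names simply follow the `(var_name, depth, parent_info)` tuple described in the PR.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional


@dataclass(eq=False)
class LayerInfo:
    """Hypothetical stand-in: only the fields this sketch needs."""

    module: object
    parent_info: Optional["LayerInfo"] = None


@dataclass(frozen=True)
class LayerContext:
    """Named replacement for the (var_name, depth, parent_info) tuple."""

    var_name: str
    depth: int
    parent_info: Optional[LayerInfo]


def checked_pop(module_stack: List[LayerInfo], module: object) -> LayerInfo:
    """Pop the stack top, but only if it belongs to the module that just finished."""
    if not module_stack or module_stack[-1].module is not module:
        found = module_stack[-1].module if module_stack else "empty stack"
        raise RuntimeError(f"Stack mismatch: expected {module!r}, got {found!r}")
    return module_stack.pop()
```

Raising instead of silently popping turns an interleaved-hook scenario into an immediate, debuggable failure; whether a warning would be preferable in a summary tool is a judgment call for the maintainers.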
Problem
Fixes #327.
`summary()` over-counts parameters (and corrupts the layer tree) for models where a single module instance is shared across several parents, most commonly one `activation` module passed into every block, especially when combined with nested `ModuleList`s. The reporter's VNet encoder shows `2,903,728` params instead of the true `2,524,144` (off by `+379,584`).

Minimal reproduction:
Root cause
Two independent issues, both stemming from a shared module having one parent recorded instead of its several real parents:

1. Hierarchy (`torchinfo.py`). The forward pre-hook captured `(var_name, depth, parent_info)` at registration time, and the per-module hook dedup kept only the last parent encountered during traversal. So every runtime execution of a shared module reported the same (wrong) parent for all but one call site. This scrambled the layer tree, and `get_children_layers()` then grouped the wrong nodes.
2. Counting (`layer_info.py`). `leftover_params()` excluded recursive children from its subtraction. A module shared across parents is marked `(recursive)` at every parent after the first, so its params (already counted at the first occurrence) were re-attributed to each subsequent parent's "leftover", counting them once per parent.

Fix
Resolve the parent dynamically at execution time. Accumulate every structural context a module is reached through, maintain a runtime call stack via the pre/post hooks, and select the context whose nearest executing ancestor is the current stack top. Single-parent modules (the overwhelmingly common case) are unchanged.
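
The resolution step above can be sketched in plain Python, without torch. This is a simplified illustration under stated assumptions, not the PR's code: `Module`, `Node`, and `resolve_context` are hypothetical names, and a context is the `(var_name, depth, parent_info)` tuple described earlier.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple


class Module:
    """Hypothetical stand-in for an nn.Module; only object identity matters here."""

    def __init__(self, name: str) -> None:
        self.name = name


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for a recorded layer: a module plus its structural parent."""

    module: Module
    parent_info: Optional["Node"] = None


# (var_name, depth, parent_info), the tuple shape described in the PR.
Context = Tuple[str, int, Optional[Node]]


def resolve_context(contexts: List[Context], stack: List[Node]) -> Context:
    """Pick the context whose nearest executing ancestor is the current stack top."""
    if len(contexts) == 1 or not stack:
        return contexts[0]  # single-parent modules need no resolution
    on_stack = {id(node.module) for node in stack}
    top = stack[-1]
    for context in contexts:
        parent = context[2]
        # Walk past ancestors that are not executing (e.g. plain containers).
        while parent is not None and id(parent.module) not in on_stack:
            parent = parent.parent_info
        if parent is not None and parent.module is top.module:
            return context
    return contexts[0]  # no match: fall back to the first recorded context


root = Node(Module("net"))
block_a = Node(Module("block_a"), root)
block_b = Node(Module("block_b"), root)
layers_b = Node(Module("block_b.layers"), block_b)  # container, never on the stack

# One shared activation, reachable directly from block_a and via block_b.layers.
contexts = [("act", 2, block_a), ("0", 3, layers_b)]

during_a = resolve_context(contexts, [root, block_a])
during_b = resolve_context(contexts, [root, block_b])
print(during_a[0], during_b[0])  # act 0
```

While `block_b` is executing, the second context wins because walking up from the container `block_b.layers` reaches `block_b`, which is the stack top; the first context walks up to `net`, which is on the stack but is not the top.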
`_leftover()` helper shared by `leftover_params` / `leftover_trainable_params` / `leftover_param_bytes`: subtract each distinct child once (keyed by `layer_id`) and skip recursive subtrees. This also fixes weight-shared parameterized modules (e.g. one `Conv`/`Linear` reused across blocks), which previously counted once per parent.

After the fix the reported total matches `sum(p.numel() for p in model.parameters())` for these models.

Testing
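
The invariant the regression test asserts (column total equals the deduplicated param count) can be illustrated without torch. The sketch below is a simplified model, not torchinfo's actual data structures: `Layer`, `leftover_old`, `leftover_new`, and `column_total` are hypothetical names, children are direct children only, and nested recursive subtrees are not modeled.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Layer:
    name: str
    layer_id: int        # identity of the underlying module; shared modules share it
    num_params: int      # params reachable from this module, each counted once
    is_recursive: bool = False  # True for every occurrence after the first
    children: List["Layer"] = field(default_factory=list)


def leftover_old(layer: Layer) -> int:
    """Pre-fix rule: recursive children are excluded from the subtraction."""
    return layer.num_params - sum(
        c.num_params for c in layer.children if not c.is_recursive
    )


def leftover_new(layer: Layer) -> int:
    """Post-fix rule: subtract each distinct child once, keyed by layer_id."""
    seen = set()
    total = layer.num_params
    for child in layer.children:
        if child.layer_id not in seen:
            seen.add(child.layer_id)
            total -= child.num_params
    return total


def column_total(layers: List[Layer], leftover: Callable[[Layer], int]) -> int:
    """Sum leaf params plus each parent's leftover, skipping recursive rows."""
    total = 0
    for layer in layers:
        if layer.is_recursive:
            continue
        total += layer.num_params if not layer.children else max(leftover(layer), 0)
    return total


# One 10-param conv shared by two blocks that own 4 and 6 params themselves.
conv = Layer("conv", layer_id=1, num_params=10)
conv_again = Layer("conv", layer_id=1, num_params=10, is_recursive=True)
block_a = Layer("block_a", layer_id=2, num_params=14, children=[conv])
block_b = Layer("block_b", layer_id=3, num_params=16, children=[conv_again])
layers = [block_a, conv, block_b, conv_again]

print(column_total(layers, leftover_old))  # 30: conv counted once per parent
print(column_total(layers, leftover_new))  # 20: equals 4 + 6 + 10
```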
- Adds `SharedModuleInNestedList` + `test_shared_module_in_nested_list` (with snapshot), asserting the column total equals the deduplicated param count.
- No behavioral change for existing models (`RecursiveNet`, `ReuseReLU`, `ReuseLinear`, `SimpleRNN`, etc. produce byte-identical output).
- Also checked `resnet18`, `densenet121`, `mobilenet_v3_small`, `maxvit_t` (unchanged) and the reporter's VNet (now correct).

Scope
This addresses over-counting from shared module instances. The separate over-counting caused by tied parameter tensors across distinct modules (e.g. tied embeddings in #322 / #377) is a different root cause and is intentionally out of scope for this PR.
🤖 Generated with Claude Code