fix: correct param counts for modules shared across parents (#327) #396

Open
Mikyx-1 wants to merge 2 commits into TylerYep:main from Mikyx-1:fix/issue-327

Mikyx-1 (Contributor) commented Jun 7, 2026

Problem

Fixes #327.

summary() over-counts parameters (and corrupts the layer tree) for models where a single module instance is shared across several parents — most commonly one activation module passed into every block — especially when combined with nested ModuleLists. The reporter's VNet encoder shows 2,903,728 params instead of the true 2,524,144 (off by +379,584).

Minimal reproduction:

import torch.nn as nn, torchinfo

class Block(nn.Module):
    def __init__(self, act):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        self.act = act                      # shared instance

    def forward(self, x):
        for layer in self.layers:
            x = self.act(layer(x))
        return x

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        shared = nn.ReLU()                  # one instance...
        self.blocks = nn.ModuleList([Block(shared) for _ in range(3)])  # ...reused

    def forward(self, x):
        for b in self.blocks:
            x = b(x)
        return x

torchinfo.summary(Net(), input_size=(1, 4))   # reports 160, true is 120
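
The expected total can be checked by hand. This is a small arithmetic sketch, assuming the default `bias=True` for `nn.Linear` (so each `nn.Linear(4, 4)` holds 4*4 weights plus 4 biases) and that `nn.ReLU` contributes no parameters:

```python
# Hand count for the reproduction above; no torch needed.
in_features, out_features = 4, 4
params_per_linear = in_features * out_features + out_features  # weights + biases
linears_per_block = 2
num_blocks = 3

true_total = params_per_linear * linears_per_block * num_blocks
reported_total = 160  # figure quoted in the reproduction comment
overcount = reported_total - true_total
print(true_total, overcount)  # 120 40
```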

Root cause

Two independent issues, both stemming from a shared module having one parent recorded instead of its several real parents:

  1. Hierarchy (torchinfo.py). The forward pre-hook captured (var_name, depth, parent_info) at registration time, and the per-module hook dedup kept only the last parent encountered during traversal. So every runtime execution of a shared module reported the same (wrong) parent for all but one call site. This scrambled the layer tree, and get_children_layers() then grouped the wrong nodes.

  2. Counting (layer_info.py). leftover_params() excluded recursive children from its subtraction. A module shared across parents is marked (recursive) at every parent after the first, so its params (already counted at the first occurrence) were re-attributed to each subsequent parent's "leftover" — counting them once per parent.
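
The "last parent wins" behaviour in item 1 can be illustrated with a toy traversal. This is only a sketch: `Node` and `register_last_parent` are hypothetical names, not torchinfo's code, and the real hook registration is more involved.

```python
# Toy model of recording one parent per module id during traversal.
class Node:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def register_last_parent(root):
    """Walk the tree; a later visit to the same object overwrites the parent."""
    parent_of = {}

    def walk(node):
        for child in node.children:
            parent_of[id(child)] = node.name
            walk(child)

    walk(root)
    return parent_of

shared = Node("relu")                                   # one instance...
blocks = [Node(f"block{i}", [shared]) for i in range(3)]  # ...reused by 3 parents
net = Node("net", blocks)

parent_of = register_last_parent(net)
print(parent_of[id(shared)])  # block2
```

Every call to the shared node would then be attributed to `block2`, even when it actually runs inside `block0` or `block1`.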

Fix

  1. Resolve the parent dynamically at execution time. Accumulate every structural context a module is reached through, maintain a runtime call stack via the pre/post hooks, and select the context whose nearest executing ancestor is the current stack top. Single-parent modules (the overwhelmingly common case) are unchanged.

  2. _leftover() helper shared by leftover_params / leftover_trainable_params / leftover_param_bytes: subtract each distinct child once (keyed by layer_id) and skip recursive subtrees. This also fixes weight-shared parameterized modules (e.g. one Conv/Linear reused across blocks), which previously counted once per parent.
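
A simplified sketch of the call-site resolution in item 1 follows. It assumes parents are plain strings and that the direct parent is always on the stack, so it omits the walk past containers; `resolve_parent` is a hypothetical name used only for illustration.

```python
# Toy sketch: pick the recorded parent that is currently executing.
def resolve_parent(candidate_parents, call_stack):
    """Return the candidate that sits on top of the runtime call stack."""
    if len(candidate_parents) == 1:
        return candidate_parents[0]          # single-parent fast path
    top = call_stack[-1] if call_stack else None
    for parent in candidate_parents:
        if parent == top:
            return parent
    return candidate_parents[0]              # fallback when nothing matches

candidates = ["block0", "block1", "block2"]  # every parent seen during traversal
seen_parents = []
for block in candidates:
    call_stack = ["net", block]              # state a pre-hook would have pushed
    seen_parents.append(resolve_parent(candidates, call_stack))

print(seen_parents)  # ['block0', 'block1', 'block2']
```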

After the fix the reported total matches sum(p.numel() for p in model.parameters()) for these models.
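
For the counting half of the fix, "subtract each distinct child once" can be pictured with a toy function. This is a sketch under the assumption that children are given as `(layer_id, num_params)` pairs; it is not the PR's `_leftover()` implementation.

```python
# Toy sketch: subtract each distinct child once, keyed by an id.
def leftover(own_total, children):
    """children: list of (layer_id, num_params) pairs, possibly with repeats."""
    seen = set()
    total = own_total
    for layer_id, num_params in children:
        if layer_id in seen:
            continue                      # already subtracted at first occurrence
        seen.add(layer_id)
        total -= num_params
    return total

# A parent with 50 params in total: a 20-param child reached twice, plus a 30-param child.
remaining = leftover(50, [(1, 20), (1, 20), (2, 30)])
print(remaining)  # 0
```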

Testing

  • New fixture SharedModuleInNestedList + test_shared_module_in_nested_list (with snapshot), asserting the column total equals the deduplicated param count.
  • Verified the full existing suite has no change in pass/fail status (the fix is behavior-preserving for all non-shared models — RecursiveNet, ReuseReLU, ReuseLinear, SimpleRNN, etc. produce byte-identical output).
  • Verified against real torchvision models (resnet18, densenet121, mobilenet_v3_small, maxvit_t) — unchanged — and the reporter's VNet, now correct.

Scope

This addresses over-counting from shared module instances. The separate over-counting caused by tied parameter tensors across distinct modules (e.g. tied embeddings in #322 / #377) is a different root cause and is intentionally out of scope for this PR.
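
The difference between the two root causes can be sketched with toy stand-ins (`Tensor` and `Module` below are hypothetical classes, not PyTorch's). Deduplicating by module identity does not help when two distinct modules hold the same tensor; that case needs deduplication by tensor identity.

```python
# Toy stand-ins illustrating tied tensors across distinct modules.
class Tensor:
    def __init__(self, numel):
        self.numel = numel

class Module:
    def __init__(self, weight):
        self.weight = weight

tied = Tensor(1000)
embedding, output_head = Module(tied), Module(tied)  # two modules, one tensor
modules = [embedding, output_head]

# Dedup by module identity: both modules are distinct, so the tensor counts twice.
distinct_modules = {id(m): m for m in modules}.values()
per_module = sum(m.weight.numel for m in distinct_modules)

# Dedup by tensor identity: the shared tensor counts once.
distinct_tensors = {id(m.weight): m.weight for m in modules}.values()
per_tensor = sum(t.numel for t in distinct_tensors)

print(per_module, per_tensor)  # 2000 1000
```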

🤖 Generated with Claude Code

…#327)

A module instance shared by several parents (e.g. one nn.ReLU() passed
into every block, as in the reported VNet) was counted incorrectly,
inflating the total parameters — especially when combined with nested
ModuleLists.

Two root causes, both stemming from a shared module having one parent
recorded instead of many:

1. Hierarchy (torchinfo.py): the pre-hook captured (var_name, depth,
   parent_info) at registration time and kept only the last parent, so
   every execution of a shared module reported the wrong parent. This
   scrambled the layer tree and mis-grouped children. Fixed by resolving
   the parent dynamically at execution time: accumulate every structural
   context a module is reached through, maintain a runtime call stack via
   the pre/post hooks, and select the context whose nearest executing
   ancestor is the current stack top. Single-parent modules are unchanged.

2. Counting (layer_info.py): leftover_params() excluded recursive children
   from its subtraction, re-attributing a recursive child's params (already
   counted at their real occurrence) to the parent — counting a shared
   parameterized module once per parent. Fixed with a shared _leftover()
   helper that subtracts each distinct child once (keyed by layer_id) and
   skips recursive subtrees.

Adds the SharedModuleInNestedList fixture and a regression test. Verified
no behavioral change for existing models (RecursiveNet, ReuseReLU,
ReuseLinear, SimpleRNN, etc. all produce identical output).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Comment thread torchinfo/torchinfo.py
return [input_size]


LayerContext = tuple[str, int, "LayerInfo | None"]

Owner

move new types to the top of the file

TylerYep commented Jun 7, 2026

Owner

AI-generated code review:

Critical Feedback & Code Review

Correctness: Generally Sound, But Assumptions Warrant Scrutiny

1. Hook Ordering Assumption – POTENTIAL RISK

# From construct_hook (line 647-650)
if module_stack:
    module_stack.pop()

Issue: The code assumes LIFO (Last-In-First-Out) hook firing order, but PyTorch's hook documentation doesn't guarantee this formally. The comment says:

"Hooks fire LIFO, so the top of the stack is always this module's own info."

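Critical Question: What if a user registers custom hooks? Since PyTorch 2.0, register_forward_pre_hook and register_forward_hook accept prepend=True, which makes a hook fire before the hooks already registered on that module, so hooks from different sources could fire in an order the stack logic does not expect.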

Scenarios where this could break:

  • User's custom register_forward_hook registered before this one
  • Hooks registered globally or via metaclasses
  • Experimental PyTorch features with async hooks

Recommendation:

# Safer: match the module explicitly
if module_stack and module_stack[-1].module is module:
    module_stack.pop()
else:
    # Log warning or raise error
    raise RuntimeError(f"Stack mismatch: expected {module}, got {module_stack[-1].module if module_stack else 'empty'}")

2. Container Module Assumption – ARCHITECTURAL FLAW

# From resolve_layer_context (line 596-598)
# Walk up past containers (ModuleList/ModuleDict have no forward and never
# appear on the runtime stack) to the nearest currently-executing ancestor.

Hard-coded assumption: The code assumes that ModuleList/ModuleDict never execute forward(), so they won't appear on the stack.

This is correct for standard PyTorch, but breaks with:

  • Custom ModuleList subclasses that override forward()
  • Third-party libraries that create container modules with custom logic (e.g., fairseq, huggingface transformers)
  • User code that adds forward() logic to containers

Example failure case:

class ExecutableModuleList(nn.ModuleList):
    def forward(self, x):
        # Custom logic
        return x

# This would appear on the stack, breaking the assumption

Recommendation: Make it configurable or add runtime detection:

# Check if module has a non-trivial forward (not just nn.ModuleList default)
NON_EXECUTING_CONTAINERS = (nn.ModuleList, nn.ModuleDict, nn.Sequential)
# But note: Sequential *can* execute its forward
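
One possible form of runtime detection is to check whether a module's class overrides the base `forward`. The sketch below uses stand-in classes rather than PyTorch; whether the same identity comparison against `nn.Module.forward` behaves this way for real containers is an assumption that would need to be verified.

```python
# Stand-in classes; not PyTorch's nn.Module / nn.ModuleList.
class Module:
    def forward(self, *args):
        raise NotImplementedError

class ModuleList(Module):          # container: inherits the stub forward
    pass

class ExecutableModuleList(ModuleList):
    def forward(self, x):          # custom container that does execute
        return x

def defines_forward(module):
    """True if the module's class overrides the base stub forward."""
    return type(module).forward is not Module.forward

print(defines_forward(ModuleList()), defines_forward(ExecutableModuleList()))
# False True
```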

3. Parent Resolution Logic – EDGE CASE NOT COVERED

# From resolve_layer_context (line 595-608)
for context in contexts:
    parent = context[2]
    while parent is not None and id(parent.module) not in on_stack:
        parent = parent.parent_info
    if parent is not None and parent.module is top.module:
        return context
return contexts[0]  # Fallback to first context

Problem: If none of the contexts match, it silently falls back to the first one. This can mask bugs.

Scenario:

class Diamond(nn.Module):
    def __init__(self):
        super().__init__()
        shared = nn.Linear(4, 4)
        # shared is reached through TWO independent parent chains
        self.path_a = shared
        self.path_b = shared
        
    def forward(self, x):
        x = self.path_a(x)
        x = self.path_b(x)  # During second call, which context is "correct"?
        return x

Both paths are equally valid, but the algorithm picks the first context, which may be misleading.

Recommendation:

matched_context = None
for context in contexts:
    parent = context[2]
    while parent is not None and id(parent.module) not in on_stack:
        parent = parent.parent_info
    if parent is not None and parent.module is top.module:
        matched_context = context
        break

if matched_context is None:
    # No context matched the current call site: fall back to the first one.
    # Optional: warn here, e.g. "Ambiguous parent context, picking first".
    matched_context = contexts[0]

return matched_context

🔴 Code Readability Issues

4. Unclear Variable Naming: on_stack vs module_stack

on_stack = {id(frame.module) for frame in module_stack}

The variable on_stack stores IDs, while module_stack stores LayerInfo objects. This is a semantic mismatch that causes confusion.

Better:

executing_module_ids = {id(layer_info.module) for layer_info in module_stack}

5. Type Hint Ambiguity: LayerContext

LayerContext = tuple[str, int, "LayerInfo | None"]

Issues:

  • Uses string literal for forward reference to LayerInfo, but it's already imported
  • No docstring explaining what each tuple element means
  • Calling code unpacks it without names: var_name, curr_depth, parent_info = resolve_layer_context(...)

Better approach:

from dataclasses import dataclass

@dataclass(frozen=True)
class LayerContext:
    """Structural context for a module during hierarchy traversal."""
    var_name: str
    depth: int
    parent_info: LayerInfo | None

This also:

  • Eliminates tuple unpacking errors
  • Makes intent clearer
  • Enables better IDE support
  • Improves testability
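
A short usage sketch of the dataclass form, with a hypothetical stand-in for `LayerInfo` so the example runs on its own:

```python
from dataclasses import dataclass, FrozenInstanceError
from typing import Optional

class LayerInfo:                     # stand-in, not torchinfo's LayerInfo
    def __init__(self, name):
        self.name = name

@dataclass(frozen=True)
class LayerContext:
    """Structural context for a module during hierarchy traversal."""
    var_name: str
    depth: int
    parent_info: Optional[LayerInfo]

ctx = LayerContext(var_name="act", depth=2, parent_info=LayerInfo("block0"))
print(ctx.var_name, ctx.depth, ctx.parent_info.name)  # act 2 block0

# frozen=True rejects accidental mutation after construction.
try:
    ctx.depth = 3
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Fields are read by name (`ctx.depth`) instead of by tuple position (`context[2]`), which is the main readability gain.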

6. The _leftover() Helper – Unclear Loop Logic

i, n = 0, len(children)
while i < n:
    child = children[i]
    if child.is_recursive:
        if child.layer_id not in seen:
            seen.add(child.layer_id)
            total -= cast("int", getattr(child, attr))
        # Skip the rest of the recursive subtree; it is part of `child`.
        i += 1
        while i < n and children[i].depth > child.depth:
            i += 1
        continue
    seen.add(child.layer_id)
    total -= (
        cast("int", getattr(child, attr))
        if child.is_leaf_layer
        else child._leftover(attr)
    )
    i += 1
return total

Readability problems:

  1. Manual index management (Python anti-pattern; should use for loop with enumerate)
  2. Nested loop to skip subtree is hard to follow
  3. Mix of continue with manual increment is error-prone

Cleaner version:

def _leftover(self, attr: str) -> int:
    total = cast("int", getattr(self, attr))
    seen: set[int] = set()
    
    i = 0
    while i < len(self.children):
        child = self.children[i]
        
        if child.is_recursive:
            # Subtract each recursive child once
            if child.layer_id not in seen:
                seen.add(child.layer_id)
                total -= cast("int", getattr(child, attr))
            
            # Skip recursive subtree (all deeper entries up to the next entry at the same or shallower depth)
            start = i + 1
            while start < len(self.children) and self.children[start].depth > child.depth:
                start += 1
            i = start
        else:
            # Regular child: subtract its content
            seen.add(child.layer_id)
            total -= (
                cast("int", getattr(child, attr))
                if child.is_leaf_layer
                else child._leftover(attr)
            )
            i += 1
    
    return total

Still not ideal; better would be to separate the tree traversal concern from parameter counting.
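
One way that separation might look is a generator that handles the subtree skipping, leaving the counting loop as a plain `for`. This is a sketch over a hypothetical flattened `Layer` record (pre-order list with depths), not the PR's `LayerInfo`, and it only sums a per-entry `params` field rather than reproducing `_leftover()`.

```python
from dataclasses import dataclass

@dataclass
class Layer:                      # hypothetical flattened entry
    layer_id: int
    depth: int
    params: int
    is_recursive: bool = False

def skip_recursive_subtrees(children):
    """Yield entries in order, dropping descendants of any recursive entry."""
    skip_below = None
    for child in children:
        if skip_below is not None and child.depth > skip_below:
            continue              # still inside a recursive subtree
        skip_below = child.depth if child.is_recursive else None
        yield child

def count_distinct(children):
    """Sum params once per layer_id over the filtered entries."""
    seen, total = set(), 0
    for child in skip_recursive_subtrees(children):
        if child.layer_id not in seen:
            seen.add(child.layer_id)
            total += child.params
    return total

flat = [
    Layer(1, 1, 20),
    Layer(2, 1, 20, is_recursive=True),
    Layer(3, 2, 5),               # descendant of the recursive entry: skipped
    Layer(4, 1, 7),
]
visited = [c.layer_id for c in skip_recursive_subtrees(flat)]
print(visited, count_distinct(flat))  # [1, 2, 4] 47
```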


🟡 Maintainability & Future-Proofing Issues

7. Tight Coupling: Module Stack State is Implicit

The module_stack is shared state across hook instances. If a hook is removed/re-registered, state can leak.

Scenario:

# User manually calls model(x1) then model(x2)
# If exception occurs during first forward, module_stack isn't cleared
# Second call has corrupted state

Recommendation: Use a context manager:

from contextlib import contextmanager

@contextmanager
def _hook_execution_context():
    module_stack: list[LayerInfo] = []
    yield module_stack
    assert len(module_stack) == 0, "Hook stack not cleaned up"
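
Note that in a generator-based context manager, code placed after `yield` without `try`/`finally` does not run when the body raises. A variant that resets the stack in that case might look like the sketch below; `hook_stack` is a hypothetical name, and clearing (rather than asserting) is a design choice made here for illustration.

```python
from contextlib import contextmanager

@contextmanager
def hook_stack():
    stack = []
    try:
        yield stack
    finally:
        stack.clear()             # runs even when the body raises

leaked = None
try:
    with hook_stack() as module_stack:
        leaked = module_stack
        module_stack.append("block0")   # a pre-hook pushed a frame
        raise ValueError("forward failed before the post-hook ran")
except ValueError:
    pass

print(len(leaked))  # 0
```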

8. No Invariant Checks

The code assumes:

  • module_contexts[module_id] always has ≥1 entry when hook fires
  • layer_ids set tracks all recursive modules consistently
  • global_layer_info[id(module)] lookup never fails

None of these are validated.

Add defensive checks:

def resolve_layer_context(contexts, module_stack):
    if not contexts:
        raise RuntimeError("No contexts registered for module")
    # ... rest of logic

9. Documentation Gaps

Missing:

  • High-level explanation of how module_contexts and module_stack interact
  • Why module_stack must be maintained in lock-step with hook calls
  • Why the algorithm can't distinguish between diamond-shaped parent graphs
  • Performance implications of on_stack set building on every call

Suggested docstring:

def resolve_layer_context(contexts: list[LayerContext], module_stack: list[LayerInfo]) -> LayerContext:
    """
    Resolve the correct parent context for a shared module at runtime.
    
    Background:
        Shared modules are registered during static traversal with *all* possible 
        (parent, depth) pairs. During forward pass execution, we need to pick the 
        context matching the current call site.
    
    Algorithm:
        1. If only one context exists (non-shared case), return it (fast path)
        2. Build set of currently-executing module IDs from the stack
        3. For each registered context, walk up its parent chain to find the nearest
           ancestor that's currently executing (skipping non-executing containers)
        4. Return the context whose ancestor is the top of the execution stack
    
    Edge cases:
        - Diamond-shaped parent hierarchies: ambiguous; picks first matching context
        - Custom containers with forward(): assumes standard containers are transparent
        - Async execution: ASSUMES synchronous hook firing (will break with async)
    
    Args:
        contexts: All (var_name, depth, parent) tuples for this module
        module_stack: Runtime stack of currently-executing modules
    
    Returns:
        The matching context for this call site
    """

📋 Summary Table: Risk Assessment

| Issue | Severity | Impact | Mitigation |
| --- | --- | --- | --- |
| Hook firing order assumption | HIGH | Silent failures if hooks interleave | Runtime verification |
| Container module hard-coding | MEDIUM | Breaks with custom containers | Configuration/detection |
| Silent fallback in context resolution | MEDIUM | Masking ambiguity | Log warnings |
| Manual loop index management | LOW | Maintenance burden | Refactor to iterator |
| No type-safe context passing | LOW | IDE support, debugging | Use dataclass |
| Missing invariant checks | MEDIUM | Hard-to-debug failures | Add assertions |
| Implicit state management | MEDIUM | Resource leaks | Use context managers |

Recommendation Before Merge

  1. (Minor) Move type definitions to file top [already noted in review]
  2. 🔴 (Must-have) Add defensive check: assert module_stack[-1].module is module in post-hook
  3. 🟡 (Should-have) Document the LIFO hook assumption + failure modes
  4. 🟡 (Nice-to-have) Convert LayerContext to dataclass for clarity
  5. 🟡 (Nice-to-have) Add test for diamond-shaped sharing patterns


Development

Successfully merging this pull request may close these issues.

networks with ModuleList

2 participants