Skip to content

fix(losses): register buffers in GlobalMutualInformationLoss#8872

Open
AlexanderSanin wants to merge 5 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer
Open

fix(losses): register buffers in GlobalMutualInformationLoss#8872
AlexanderSanin wants to merge 5 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer

Conversation

@AlexanderSanin
Copy link
Copy Markdown
Contributor

Summary

  • GlobalMutualInformationLoss stored preterm and bin_centers as plain tensor attributes when kernel_type="gaussian", so calling loss.to("cuda") or loss.cuda() did not move them to the target device
  • Replace the plain assignments with register_buffer(..., persistent=False), consistent with the pattern already applied to LocalNormalizedCrossCorrelationLoss in fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss #8818
  • The .to(img) calls in parzen_windowing_gaussian are retained for dtype coercion (e.g. float16 inference)

Test plan

  • python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -v — all existing tests still pass
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_registers_bufferspreterm and bin_centers are in _buffers and have requires_grad=False
  • TestGlobalMutualInformationLossBuffers::test_bspline_kernel_has_no_gaussian_buffers — b-spline mode is unaffected
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_forward_correct — forward pass returns a scalar loss

Closes #8819

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hey @ericspod @aymuos15. Could you, please, have a look at this?

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

GlobalMutualInformationLoss now registers preterm and bin_centers as non-persistent buffers initialized to None, then populates them when kernel_type == "gaussian". Tests were added to verify buffer registration and properties for the gaussian kernel, absence for b-spline, that a gaussian forward returns a scalar tensor, and that the gaussian buffers move with the module to CUDA.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly summarizes the main change: registering buffers in GlobalMutualInformationLoss to fix device placement.
Description check ✅ Passed Description covers the bug, solution, rationale for .to(img) retention, and test plan, though the template checkbox for 'New tests added' is unmarked despite tests being added.
Linked Issues check ✅ Passed Changes fully implement the objective from #8819: registers preterm and bin_centers as non-persistent buffers in GlobalMutualInformationLoss [#8819], ensuring proper device movement and no gradient tracking.
Out of Scope Changes check ✅ Passed All changes directly address the linked issue: buffer registration in GlobalMutualInformationLoss and corresponding test coverage for gaussian vs. b-spline modes and device movement.
Docstring Coverage ✅ Passed Docstring coverage is 83.33% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (3)

158-161: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_bspline_kernel_has_no_gaussian_buffers(self):
+    """Verify b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 158 - 161, The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.

163-168: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_gaussian_kernel_forward_correct(self):
+    """Verify gaussian kernel forward pass returns scalar loss tensor."""
     pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 163 - 168, Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.

149-156: ⚡ Quick win

Add docstring per coding guidelines.

Docstrings required for all test methods describing purpose and expectations. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Suggested docstring
 def test_gaussian_kernel_registers_buffers(self):
+    """Verify gaussian kernel registers preterm and bin_centers as non-trainable buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 156, Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
monai/losses/image_dissimilarity.py (1)

236-237: 💤 Low value

Type annotations declared unconditionally but attributes are conditionally assigned.

These annotations are defined outside the gaussian conditional block, but the actual attributes are only created when kernel_type == "gaussian". While runtime behavior is correct (attributes only accessed in gaussian path), static type checkers may flag potential AttributeError for b-spline mode.

Consider either:

  • Moving annotations inside the conditional, or
  • Initializing to None and using Optional[torch.Tensor] type
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/image_dissimilarity.py` around lines 236 - 237, The attributes
self.preterm and self.bin_centers are only created when kernel_type ==
"gaussian" but currently annotated unconditionally; update their declarations to
reflect conditional creation by typing them as Optional[torch.Tensor] and
initialize them to None in the non-gaussian branch (or before the conditional)
so static type checkers know they may be absent, and ensure any gaussian-only
use sites (e.g., inside the gaussian branch) treat them as non-None; reference
the attributes self.preterm, self.bin_centers and the kernel_type == "gaussian"
conditional when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 148-169: Add a test that verifies Gaussian kernel buffers actually
move with the module: instantiate
GlobalMutualInformationLoss(kernel_type="gaussian"), if
torch.cuda.is_available() call loss_cuda = loss.to("cuda") (or loss.cuda()),
then assert loss_cuda.preterm.device.type == "cuda" and
loss_cuda.bin_centers.device.type == "cuda", create CUDA tensors for pred and
target and run result = loss_cuda(pred, target) and assert result.device.type ==
"cuda"; reference the GlobalMutualInformationLoss class and its buffers preterm
and bin_centers and add this as a new test method (e.g.,
test_gaussian_buffers_move_with_module) alongside the existing tests.

---

Nitpick comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 236-237: The attributes self.preterm and self.bin_centers are only
created when kernel_type == "gaussian" but currently annotated unconditionally;
update their declarations to reflect conditional creation by typing them as
Optional[torch.Tensor] and initialize them to None in the non-gaussian branch
(or before the conditional) so static type checkers know they may be absent, and
ensure any gaussian-only use sites (e.g., inside the gaussian branch) treat them
as non-None; reference the attributes self.preterm, self.bin_centers and the
kernel_type == "gaussian" conditional when making the change.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 158-161: The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.
- Around line 163-168: Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.
- Around line 149-156: Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 22b238ee-90d5-4b39-9f22-3df62dfea05d

📥 Commits

Reviewing files that changed from the base of the PR and between 0a8d945 and f20d3f6.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (2)

149-185: ⚡ Quick win

Use full Google-style docstrings for new test methods.

Current one-line docstrings don’t include the required sections (Args, Returns, Raises) from the repo guideline.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 185, The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.

149-163: ⚡ Quick win

Assert non-persistent buffer contract explicitly.

Please also verify preterm and bin_centers are excluded from state_dict() to lock in persistent=False behavior.

Proposed test additions
 def test_gaussian_kernel_registers_buffers(self):
     """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
     self.assertIn("preterm", loss._buffers)
     self.assertIn("bin_centers", loss._buffers)
     self.assertFalse(loss.preterm.requires_grad)
     self.assertFalse(loss.bin_centers.requires_grad)
     self.assertEqual(loss.bin_centers.ndim, 3)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)

 def test_bspline_kernel_has_no_gaussian_buffers(self):
     """b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
     self.assertIsNone(loss.preterm)
     self.assertIsNone(loss.bin_centers)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 163, Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 149-185: The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.
- Around line 149-163: Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 13a2592f-5cd6-4641-a30a-def8e3d5df80

📥 Commits

Reviewing files that changed from the base of the PR and between f20d3f6 and 20c702a.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

@aymuos15
Copy link
Copy Markdown
Contributor

Happy to go through this. Any idea why the CI is failing?

@AlexanderSanin AlexanderSanin force-pushed the fix/global-mutual-information-register-buffer branch from 20c702a to 379b8a8 Compare May 26, 2026 07:06
@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hi @aymuos15! The CI failures are pre-existing on the dev branch itself and are not related to this PR.

The root cause is that monai/networks/trt_compiler.py imports polygraphy.backend.common, which transitively imports cupy.testing._random, which does import pytest. In the packaging / full-dep (ubuntu-latest) CI jobs, pytest is not installed in the target environment (it's only in requirements-dev.txt), so any import of monai.networks triggers a ModuleNotFoundError and causes cascade failures across ~6759 tests.

You can verify by checking the most recent dev branch CI run — it shows the exact same FAILED (failures=17, errors=6759, skipped=1020) result on the same jobs.

All tests specific to this PR pass in CI (confirmed in the packaging job log):

test_bspline_kernel_has_no_gaussian_buffers ... ok
test_gaussian_buffers_move_with_module      ... ok
test_gaussian_kernel_forward_correct        ... ok
test_gaussian_kernel_registers_buffers      ... ok
test_ill_opts_{0..3}                        ... ok
test_ill_shape_{0..1}                       ... ok

I've also rebased the branch on the latest dev just now.

@aymuos15
Copy link
Copy Markdown
Contributor

aymuos15 commented May 26, 2026

Okay! Thanks a lot for the detailed reply.

Let's wait for that to get an upstream fix then? I think that is being tracked and worked on in real time.

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Okay! Thanks a lot for the detailed reply.

Let's wait for that to get an upstream fix then? I think that is being tracked and worked on in real time.

Hi @aymuos15 ,

I've addressed that in another PR: #8873
Please review , thanks

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

AlexanderSanin commented May 26, 2026

Status & Blockers

All code changes are complete and tests pass. Two things are blocking merge:

1. CI Failures (external regression — not our code)

The failing CI jobs (full-dep (ubuntu-latest), static-checks (pytype), mypy, codeformat) are caused by a cupy-cuda12x 14.1.0 regression released on May 26, 2026. That version introduced import pytest at module load time in cupy/testing/_random.py, which breaks any environment where pytest is not installed.

The failure chain is:
monai.networkstrt_compilerpolygraphy.backend.commonpolygraphy.util.utilcupy.testing._randomimport pytestModuleNotFoundError

This is pre-existing on the dev branch — the last fully green dev CI run (#26286623163) was on May 22 with cupy 14.0.1. All our PR-specific tests pass in the packaging job logs.

2. Awaiting Required Review

This PR needs at least one approving review from a code owner before it can be merged.

@KumoLiu @ericspod @Nic-Ma — could one of you take a look when you get a chance? The change is in monai/losses/image_dissimilarity.py and registers preterm / bin_centers as non-persistent buffers so they move correctly with .to(device) / .cuda(). Thanks!

@aymuos15
Copy link
Copy Markdown
Contributor

Hey, thanks for the fix there. But I think that is a deeper fix and one the maintainers should do. Will get to both when they are done.

@ericspod
Copy link
Copy Markdown
Member

Hi @AlexanderSanin a previous PR #8869 was looking at the same issue. Yours does something slightly different but we should merge that one first then integrate your changes and tests. I've updated so the tests should now run.

When kernel_type="gaussian", `preterm` and `bin_centers` were stored
as plain tensor attributes via simple assignment. This means they are
not registered in PyTorch's module buffer system, so calling
`loss.to("cuda")` or `loss.cuda()` does not move these tensors to the
target device. Each forward pass had to call `.to(img)` to patch the
device mismatch at runtime, which is both redundant and misleading.

Use `register_buffer(..., persistent=False)` so that both tensors are
properly tracked by the module and automatically move with `.to()` /
`.cuda()` / `.cpu()` calls, consistent with the pattern already used
by `LocalNormalizedCrossCorrelationLoss`.

The `.to(img)` calls in `parzen_windowing_gaussian` are retained for
dtype coercion (e.g. float16 inference).

Adds `TestGlobalMutualInformationLossBuffers` to verify buffer
registration and that b-spline mode does not create gaussian buffers.

Closes Project-MONAI#8819

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…uffer fix

Use register_buffer("preterm", None) / register_buffer("bin_centers", None)
unconditionally so that both buffers are always present in _buffers (with None
for b-spline). This avoids a KeyError that occurred when plain instance
attribute assignment conflicted with a subsequent register_buffer call.

Also add docstrings to the new test methods and a device-movement test that
verifies buffers follow the module when .cuda() is called.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
- Add state_dict assertions to verify non-persistent=False contract
- Update test docstrings to use Verify... format

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
register_buffer leaves mypy inferring the broad ``Tensor | Module`` union,
which fails the arithmetic on these attributes in parzen_windowing_gaussian.
Declare them as ``torch.Tensor`` (the type at the gaussian-kernel use sites).

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@AlexanderSanin AlexanderSanin force-pushed the fix/global-mutual-information-register-buffer branch from fa1a46e to 4ad65e4 Compare May 29, 2026 13:34
@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Thanks @ericspod, that plan works for me — happy to have #8869 go in first.

I've rebased this branch onto the latest dev (so it now picks up the CI fixes) and resolved the one remaining static-checks (mypy) failure: register_buffer left mypy inferring the broad Tensor | Module union for preterm/bin_centers, which broke the arithmetic in parzen_windowing_gaussian. Added explicit torch.Tensor annotations to fix it.

A couple of things this PR does on top of #8869, in case you'd like to fold them in after merging it:

  • Registers both preterm and bin_centers as non-persistent buffers (Register GlobalMutualInformationLoss bin_centers as buffer #8869 registers only bin_centers), so the whole gaussian-kernel state moves consistently with .to(device) / .cuda().
  • Adds a focused test class (TestGlobalMutualInformationLossBuffers) covering: the non-persistent contract (buffers present in _buffers but excluded from state_dict()), the b-spline kernel registering no gaussian buffers, forward correctness, and device movement.

I'm glad to rebase these tests/changes directly on top of #8869 once it lands so they apply cleanly — just let me know your preference.

@aymuos15
Copy link
Copy Markdown
Contributor

Since .to(img) has to stay for the dtype coercion (it handles device too) and the results are bitwise accurate irrespective of the change, maybe this change is not required because we cant get rid of the .to(img) anyways?

Bare `return` in unittest does not signal a skip — pytest emits a
warning ("Did you mean to use assert instead of return?"). Replace
with `self.skipTest("CUDA not available")` which correctly marks the
test as skipped and is consistent with MONAI conventions.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hi @aymuos15, great question! You're right that .to(img) handles device too — but register_buffer is still needed for correct module semantics, and here's why:

Without register_buffer:

loss = GlobalMutualInformationLoss(kernel_type="gaussian")
loss = loss.to("cuda")      # moves all registered parameters/buffers — but preterm/bin_centers are plain attributes → stay on CPU
pred = pred.cuda()
target = target.cuda()
result = loss(pred, target)  # .to(img) allocates a new CUDA copy on every forward call

Every forward pass silently allocates a fresh copy of preterm/bin_centers on GPU. For repeated inference this is wasteful, and loss.preterm.device would still report cpu even after loss.to("cuda") — misleading for anyone inspecting the module's state.

With register_buffer(..., persistent=False):

loss = GlobalMutualInformationLoss(kernel_type="gaussian")
loss = loss.to("cuda")      # moves preterm and bin_centers to CUDA once
result = loss(pred, target)  # .to(img) is now a device no-op; only coerces dtype if needed (e.g. float16)

loss.preterm.device correctly shows cuda, and the buffers are allocated once — not once per forward call.

Why keep .to(img) at all?
Exactly as you noted — for dtype coercion. If the user feeds float16 inputs the buffers need to match. After this change .to(img) only does a dtype cast (cheap), never a device transfer (because device already matches after loss.to("cuda")).

Summary: register_buffer upholds the contract that .to(device) fully moves a module's non-trainable state, while .to(img) remains as a safety net for dtype. Both are needed but for different reasons.

Also pushed a small follow-up commit: replaced the bare return in test_gaussian_buffers_move_with_module with self.skipTest("CUDA not available") so pytest correctly marks the test as skipped rather than emitting a warning.

@aymuos15
Copy link
Copy Markdown
Contributor

I do agree with the unnecessary calls, but that is completley neglible if you benchmark it right?

@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

You're right — the overhead is negligible in practice. A few bytes re-pinned to GPU on each forward call won't show up in any real benchmark.

The real reason for register_buffer isn't performance, it's API correctness: PyTorch's contract for nn.Module is that module.to(device) moves all non-trainable module state. Without register_buffer, a user who writes:

loss = GlobalMutualInformationLoss(kernel_type='gaussian').to('cuda')

is left with loss.preterm still reporting cpu — inconsistent with how every other nn.Module buffer behaves, and likely confusing to anyone inspecting the module or writing device-agnostic code that introspects loss.buffers().

register_buffer is the documented PyTorch mechanism for exactly this case: a non-trainable tensor that is conceptually part of the module's state and should move with it. The non-persistent flag keeps it out of state_dict() since it's deterministically derived from the constructor args and doesn't need to be saved.

Happy to defer to maintainer preference on whether that level of correctness is worth the diff, but that's the motivation behind the change.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] LocalNormalizedCrossCorrelationLoss: kernel not registered as buffer — silent gradient tracking + wrong device placement

3 participants