Skip to content

Commit 4ad65e4

Browse files
fix(losses): annotate preterm/bin_centers buffers to satisfy mypy
register_buffer leaves mypy inferring the broad ``Tensor | Module`` union, which fails the arithmetic on these attributes in parzen_windowing_gaussian. Declare them as ``torch.Tensor`` (the type at the gaussian-kernel use sites). Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
1 parent aa35cb1 commit 4ad65e4

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

monai/losses/image_dissimilarity.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ def __init__(
233233
self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
234234
self.num_bins = num_bins
235235
self.kernel_type = kernel_type
236+
# declared as buffers so they move with the module (e.g. ``.to(device)``); only populated for the
237+
# gaussian kernel, hence the ``Tensor`` annotation reflects the type at the use sites in that path.
238+
self.preterm: torch.Tensor
239+
self.bin_centers: torch.Tensor
236240
self.register_buffer("preterm", None, persistent=False)
237241
self.register_buffer("bin_centers", None, persistent=False)
238242
if self.kernel_type == "gaussian":

0 commit comments

Comments
 (0)