Skip to content

Commit b80beca

Browse files
committed
fix: centralize ignore_index masking in metrics
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 20d13a6 commit b80beca

2 files changed

Lines changed: 111 additions & 3 deletions

File tree

monai/losses/focal_loss.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
185185
else:
186186
loss = sigmoid_focal_loss(input, target, self.gamma, alpha_arg)
187187

188-
num_of_classes = target.shape[1]
189-
190188
if mask is not None:
191189
loss = loss * mask
192190

@@ -213,3 +211,112 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
213211
else:
214212
broadcast_shape = [1, num_classes] + [1] * (loss.ndim - 2)
215213
loss = loss * cw.view(broadcast_shape)
214+
215+
if self.reduction == LossReduction.SUM.value:
216+
# Previously there was a mean over the last dimension, which did not
217+
# return a compatible BCE loss. To maintain backwards compatible
218+
# behavior we have a flag that performs this extra step, disable or
219+
# parameterize if necessary. (Or justify why the mean should be there)
220+
average_spatial_dims = True
221+
if average_spatial_dims:
222+
loss = loss.mean(dim=list(range(2, len(target.shape))))
223+
loss = loss.sum()
224+
225+
elif self.reduction == LossReduction.MEAN.value:
226+
if mask is not None:
227+
# Ensure we only sum the loss where the mask is 1
228+
# Then divide by the actual number of 1s in the mask
229+
loss = (loss * mask).sum() / mask.sum().clamp(min=1e-5)
230+
else:
231+
loss = loss.mean()
232+
233+
elif self.reduction == LossReduction.NONE.value:
234+
pass
235+
236+
return loss
237+
238+
239+
def softmax_focal_loss(
240+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
241+
) -> torch.Tensor:
242+
"""
243+
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
244+
245+
where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
246+
s_j is the unnormalized score for class j.
247+
"""
248+
input_ls = input.log_softmax(1)
249+
loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target
250+
251+
if alpha is not None:
252+
if isinstance(alpha, torch.Tensor):
253+
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
254+
else:
255+
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
256+
257+
if alpha_t.ndim == 0: # scalar
258+
alpha_val = alpha_t.item()
259+
# (1-alpha) for the background class and alpha for the other classes
260+
alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)
261+
else: # tensor (sequence)
262+
if alpha_t.shape[0] != target.shape[1]:
263+
raise ValueError(
264+
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
265+
)
266+
alpha_fac = alpha_t
267+
268+
broadcast_dims = [-1] + [1] * len(target.shape[2:])
269+
alpha_fac = alpha_fac.view(broadcast_dims)
270+
loss = alpha_fac * loss
271+
272+
return loss
273+
274+
275+
def sigmoid_focal_loss(
276+
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None
277+
) -> torch.Tensor:
278+
"""
279+
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
280+
281+
where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
282+
"""
283+
# computing binary cross entropy with logits
284+
# equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
285+
# see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
286+
loss: torch.Tensor = input - input * target - F.logsigmoid(input)
287+
288+
# sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
289+
# 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
290+
# 1-p if t==1; p if t==0 <=>
291+
# pfac, that is, the term (1 - pt)
292+
invprobs = F.logsigmoid(-input * (target * 2 - 1)) # reduced chance of overflow
293+
# (pfac.log() * gamma).exp() <=>
294+
# pfac.log().exp() ^ gamma <=>
295+
# pfac ^ gamma
296+
loss = (invprobs * gamma).exp() * loss
297+
298+
if alpha is not None:
299+
if isinstance(alpha, torch.Tensor):
300+
alpha_t = alpha.to(device=input.device, dtype=input.dtype)
301+
else:
302+
alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)
303+
304+
if alpha_t.ndim == 0: # scalar
305+
# alpha if t==1; (1-alpha) if t==0
306+
alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)
307+
else: # tensor (sequence)
308+
if alpha_t.shape[0] != target.shape[1]:
309+
raise ValueError(
310+
f"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]})."
311+
)
312+
# Reshape alpha for broadcasting: (1, C, 1, 1...)
313+
broadcast_dims = [-1] + [1] * len(target.shape[2:])
314+
alpha_t = alpha_t.view(broadcast_dims)
315+
# Apply per-class weight only to positive samples
316+
# For positive samples (target==1): multiply by alpha[c]
317+
# For negative samples (target==0): keep weight as 1.0
318+
alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))
319+
320+
loss = alpha_factor * loss
321+
322+
return loss

monai/metrics/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def create_ignore_mask(y: torch.Tensor, ignore_index: int | None) -> torch.Tenso
116116
num_classes = y.shape[1]
117117
if 0 <= ignore_index < num_classes:
118118
# Valid class index: exclude that channel
119-
return 1.0 - y[:, ignore_index : ignore_index + 1] # type: ignore[no-any-return]
119+
return 1.0 - y[:, ignore_index : ignore_index + 1] # type: ignore[no-any-return]
120120
else:
121121
# Sentinel value: exclude where all channels are zero
122122
return (y.sum(dim=1, keepdim=True) > 0).float()
@@ -353,6 +353,7 @@ def get_edge_surface_distance(
353353
use_subvoxels: bool = False,
354354
symmetric: bool = False,
355355
class_index: int = -1,
356+
mask: torch.Tensor | None = None,
356357
) -> tuple[
357358
tuple[torch.Tensor, torch.Tensor],
358359
tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor],

0 commit comments

Comments
 (0)