@@ -185,8 +185,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
185185 else :
186186 loss = sigmoid_focal_loss (input , target , self .gamma , alpha_arg )
187187
188- num_of_classes = target .shape [1 ]
189-
190188 if mask is not None :
191189 loss = loss * mask
192190
@@ -213,3 +211,112 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
213211 else :
214212 broadcast_shape = [1 , num_classes ] + [1 ] * (loss .ndim - 2 )
215213 loss = loss * cw .view (broadcast_shape )
214+
215+ if self .reduction == LossReduction .SUM .value :
216+ # Previously there was a mean over the last dimension, which did not
217+ # return a compatible BCE loss. To maintain backwards compatible
218+ # behavior we have a flag that performs this extra step, disable or
219+ # parameterize if necessary. (Or justify why the mean should be there)
220+ average_spatial_dims = True
221+ if average_spatial_dims :
222+ loss = loss .mean (dim = list (range (2 , len (target .shape ))))
223+ loss = loss .sum ()
224+
225+ elif self .reduction == LossReduction .MEAN .value :
226+ if mask is not None :
227+ # Ensure we only sum the loss where the mask is 1
228+ # Then divide by the actual number of 1s in the mask
229+ loss = (loss * mask ).sum () / mask .sum ().clamp (min = 1e-5 )
230+ else :
231+ loss = loss .mean ()
232+
233+ elif self .reduction == LossReduction .NONE .value :
234+ pass
235+
236+ return loss
237+
238+
239+ def softmax_focal_loss (
240+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch .Tensor | None = None
241+ ) -> torch .Tensor :
242+ """
243+ FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
244+
245+ where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
246+ s_j is the unnormalized score for class j.
247+ """
248+ input_ls = input .log_softmax (1 )
249+ loss : torch .Tensor = - (1 - input_ls .exp ()).pow (gamma ) * input_ls * target
250+
251+ if alpha is not None :
252+ if isinstance (alpha , torch .Tensor ):
253+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
254+ else :
255+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
256+
257+ if alpha_t .ndim == 0 : # scalar
258+ alpha_val = alpha_t .item ()
259+ # (1-alpha) for the background class and alpha for the other classes
260+ alpha_fac = torch .tensor ([1 - alpha_val ] + [alpha_val ] * (target .shape [1 ] - 1 )).to (loss )
261+ else : # tensor (sequence)
262+ if alpha_t .shape [0 ] != target .shape [1 ]:
263+ raise ValueError (
264+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
265+ )
266+ alpha_fac = alpha_t
267+
268+ broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
269+ alpha_fac = alpha_fac .view (broadcast_dims )
270+ loss = alpha_fac * loss
271+
272+ return loss
273+
274+
275+ def sigmoid_focal_loss (
276+ input : torch .Tensor , target : torch .Tensor , gamma : float = 2.0 , alpha : float | torch .Tensor | None = None
277+ ) -> torch .Tensor :
278+ """
279+ FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
280+
281+ where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
282+ """
283+ # computing binary cross entropy with logits
284+ # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')
285+ # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363
286+ loss : torch .Tensor = input - input * target - F .logsigmoid (input )
287+
288+ # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>
289+ # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>
290+ # 1-p if t==1; p if t==0 <=>
291+ # pfac, that is, the term (1 - pt)
292+ invprobs = F .logsigmoid (- input * (target * 2 - 1 )) # reduced chance of overflow
293+ # (pfac.log() * gamma).exp() <=>
294+ # pfac.log().exp() ^ gamma <=>
295+ # pfac ^ gamma
296+ loss = (invprobs * gamma ).exp () * loss
297+
298+ if alpha is not None :
299+ if isinstance (alpha , torch .Tensor ):
300+ alpha_t = alpha .to (device = input .device , dtype = input .dtype )
301+ else :
302+ alpha_t = torch .tensor (alpha , device = input .device , dtype = input .dtype )
303+
304+ if alpha_t .ndim == 0 : # scalar
305+ # alpha if t==1; (1-alpha) if t==0
306+ alpha_factor = target * alpha_t + (1 - target ) * (1 - alpha_t )
307+ else : # tensor (sequence)
308+ if alpha_t .shape [0 ] != target .shape [1 ]:
309+ raise ValueError (
310+ f"The length of alpha ({ alpha_t .shape [0 ]} ) must match the number of classes ({ target .shape [1 ]} )."
311+ )
312+ # Reshape alpha for broadcasting: (1, C, 1, 1...)
313+ broadcast_dims = [- 1 ] + [1 ] * len (target .shape [2 :])
314+ alpha_t = alpha_t .view (broadcast_dims )
315+ # Apply per-class weight only to positive samples
316+ # For positive samples (target==1): multiply by alpha[c]
317+ # For negative samples (target==0): keep weight as 1.0
318+ alpha_factor = torch .where (target == 1 , alpha_t , torch .ones_like (alpha_t ))
319+
320+ loss = alpha_factor * loss
321+
322+ return loss
0 commit comments