Skip to content

Commit fd53394

Browse files
authored
Merge pull request graphnet-team#847 from sevmag/bce_with_logits
Bce with logits
2 parents de50e86 + 0a89fa8 commit fd53394

1 file changed

Lines changed: 22 additions & 7 deletions

File tree

src/graphnet/training/loss_functions.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.nn.functional import (
1616
one_hot,
1717
binary_cross_entropy,
18+
binary_cross_entropy_with_logits,
1819
softplus,
1920
)
2021

@@ -203,16 +204,30 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
203204

204205

205206
class BinaryCrossEntropyLoss(LossFunction):
206-
"""Compute binary cross entropy loss.
207+
"""Compute binary cross entropy loss."""
207208

208-
Predictions are vector probabilities (i.e., values between 0 and 1),
209-
and targets should be 0 and 1.
210-
"""
209+
def __init__(self, from_logits: bool = False, *args: Any, **kwargs: Any):
210+
"""Construct BinaryCrossEntropyLoss.
211+
212+
Args:
213+
from_logits: Whether the predictions are logits.
214+
NOTE: If True, the predictions are expected to be raw scores
215+
(i.e., not passed through a sigmoid function). If False, the
216+
predictions are expected to be probabilities
217+
(i.e., passed through a sigmoid function).
218+
"""
219+
super().__init__(*args, **kwargs)
220+
self._from_logits = from_logits
211221

212222
def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
213-
return binary_cross_entropy(
214-
prediction.float(), target.float(), reduction="none"
215-
)
223+
if self._from_logits:
224+
return binary_cross_entropy_with_logits(
225+
prediction.float(), target.float(), reduction="none"
226+
)
227+
else:
228+
return binary_cross_entropy(
229+
prediction.float(), target.float(), reduction="none"
230+
)
216231

217232

218233
class LogCMK(torch.autograd.Function):

0 commit comments

Comments
 (0)