|
15 | 15 | from torch.nn.functional import ( |
16 | 16 | one_hot, |
17 | 17 | binary_cross_entropy, |
| 18 | + binary_cross_entropy_with_logits, |
18 | 19 | softplus, |
19 | 20 | ) |
20 | 21 |
|
@@ -203,16 +204,30 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: |
203 | 204 |
|
204 | 205 |
|
class BinaryCrossEntropyLoss(LossFunction):
    """Compute binary cross entropy loss.

    Targets should be 0 and 1. Predictions are probabilities in [0, 1]
    unless constructed with `from_logits=True`, in which case they are
    raw (pre-sigmoid) scores.
    """

    def __init__(self, from_logits: bool = False, *args: Any, **kwargs: Any):
        """Construct BinaryCrossEntropyLoss.

        Args:
            from_logits: Whether the predictions are logits.
                NOTE: If True, the predictions are expected to be raw scores
                (i.e., not passed through a sigmoid function). If False, the
                predictions are expected to be probabilities
                (i.e., passed through a sigmoid function).
        """
        super().__init__(*args, **kwargs)
        self._from_logits = from_logits

    def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
        # Both variants take identical arguments; select the function once
        # instead of duplicating the call in each branch.
        loss_fn = (
            binary_cross_entropy_with_logits
            if self._from_logits
            else binary_cross_entropy
        )
        # `reduction="none"` keeps the per-element loss; any reduction is
        # presumably applied by the LossFunction base class — TODO confirm.
        return loss_fn(prediction.float(), target.float(), reduction="none")
216 | 231 |
|
217 | 232 |
|
218 | 233 | class LogCMK(torch.autograd.Function): |
|
0 commit comments