-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
21 lines (19 loc) · 881 Bytes
/
loss.py
File metadata and controls
21 lines (19 loc) · 881 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
class CEL_Jaccard:
    """Negative log-likelihood blended with a per-class soft-Jaccard log term.

    The value returned by calling an instance is::

        (1 - jaccard_weight) * NLL(outputs, targets)
            - jaccard_weight * sum_over_classes(log(J_c))

    where, for every class ``c``, ``J_c = (I + eps) / (U - I + eps)`` with
    ``I = sum(p_c * t_c)`` and ``U = sum(p_c) + sum(t_c)``; ``p_c`` is the
    exponentiated class channel of ``outputs`` and ``t_c`` the 0/1 mask
    ``targets == c``.

    Args:
        jaccard_weight: Weight of the Jaccard term; the NLL term is weighted
            by ``1 - jaccard_weight``. A falsy value (``0``) disables the
            Jaccard term entirely.
        num_classes: Number of class channels to accumulate over. Channels
            ``0 .. num_classes - 1`` of ``outputs`` (dim 1) are indexed, so
            it must not exceed ``outputs.shape[1]``.
        eps: Small constant added to numerator and denominator of the
            Jaccard ratio for numerical stability. The default ``1e-15``
            rounds to zero in float16, so pass a larger value when the
            inputs are half precision.
    """

    def __init__(self, jaccard_weight=0.3, num_classes=8, eps=1e-15):
        self.nll_loss = nn.NLLLoss()
        self.jaccard_weight = jaccard_weight
        self.num_classes = num_classes
        self.eps = eps

    def __call__(self, outputs, targets):
        """Compute the combined loss.

        Args:
            outputs: Log-probabilities with the class channel on dim 1; the
                tensor is passed to ``nn.NLLLoss`` and each class channel is
                exponentiated to recover probabilities.
            targets: Integer class indices as accepted by ``nn.NLLLoss``.
                Presumably the same shape as ``outputs`` without the class
                dimension (it is multiplied elementwise with
                ``outputs[:, cls]``) -- confirm against the caller.

        Returns:
            A scalar tensor holding the weighted loss.
        """
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)

        # Guard clause: with a zero weight only the NLL term contributes.
        if not self.jaccard_weight:
            return loss

        eps = self.eps  # numerical stability: avoids 0/0 and log(0)
        for cls in range(self.num_classes):
            jaccard_target = (targets == cls).float()
            jaccard_output = outputs[:, cls].exp()

            intersection = (jaccard_output * jaccard_target).sum()
            union = jaccard_output.sum() + jaccard_target.sum()
            # sum(p) + sum(t) counts the overlap twice; remove it once to get
            # the true soft union used as the Jaccard denominator.
            denom = union - intersection

            # Out-of-place update so the tensor holding the NLL term is not
            # mutated while accumulating the per-class penalties.
            loss = loss - torch.log((intersection + eps) / (denom + eps)) * self.jaccard_weight

        return loss