-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathCOT.py
executable file
·29 lines (25 loc) · 1.01 KB
/
COT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import torch.nn.functional as F
# Complement Entropy (CE)
# Class count the original implementation hard-coded into the loss.  Kept as a
# module-level name for backward compatibility with code that imports it;
# ComplementEntropy.forward now reads the class count from the logits' second
# dimension, so the loss also works for models whose output size is not 100.
classes = 100


class ComplementEntropy(nn.Module):
    """Complement Entropy (CE) loss.

    With ``p = softmax(yHat)`` and ground-truth class ``g_i`` for sample ``i``,
    the probabilities of the non-ground-truth ("complement") classes are
    renormalised by ``1 - p_ig`` and the loss accumulates ``q * log(q)`` over
    those classes:

        q_ij = p_ij / (1 - p_ig)
        loss = 1 / (N * K) * sum_i sum_{j != g_i} q_ij * log(q_ij)

    where ``N`` is the batch size and ``K`` the number of classes.  This is the
    negative of the (scaled) entropy of the complement distribution, so
    minimising it spreads the mass assigned to wrong classes more evenly.
    The formula presumably follows "Complement Objective Training"
    (Chen et al., ICLR 2019), as the file name suggests — confirm.
    """

    def __init__(self):
        super(ComplementEntropy, self).__init__()

    # Implemented step by step to mirror the formula described in the paper.
    def forward(self, yHat, y):
        """Return the complement-entropy loss as a 0-dim tensor.

        Args:
            yHat: Unnormalised logits of shape ``(N, K)``.
            y: Integer (int64) ground-truth class indices of shape ``(N,)``.

        Returns:
            Scalar tensor on ``yHat``'s device; ``0`` for an empty batch.

        Raises:
            ValueError: if ``yHat`` is not 2-D or ``y`` does not hold exactly
                one label per row of ``yHat``.
        """
        if yHat.dim() != 2 or y.dim() != 1 or y.size(0) != yHat.size(0):
            raise ValueError(
                "expected yHat of shape (N, K) and y of shape (N,), got "
                f"{tuple(yHat.shape)} and {tuple(y.shape)}"
            )
        # Kept as attributes because the original implementation exposed them.
        self.batch_size = len(y)
        # Derive K from the logits instead of the hard-coded module constant.
        self.classes = yHat.size(1)

        probs = F.softmax(yHat, dim=1)
        target = y.unsqueeze(1)  # (N, 1) index for gather/scatter
        Yg = torch.gather(probs, 1, target)  # p_ig, shape (N, 1)
        Yg_ = (1 - Yg) + 1e-7  # avoiding numerical issues (first)
        Px = probs / Yg_  # broadcasts over the class dimension
        Px_log = torch.log(Px + 1e-10)  # avoiding numerical issues (second)
        # 1 everywhere except the ground-truth column, built on the same
        # device/dtype as the predictions (no CPU round-trip, no forced .cuda()).
        complement_mask = torch.ones_like(Px).scatter_(1, target, 0)
        output = Px * Px_log * complement_mask

        loss = torch.sum(output)
        # max(..., 1): an empty batch yields 0 instead of 0/0 = NaN.
        loss = loss / float(max(self.batch_size, 1))
        loss = loss / float(self.classes)
        return loss