-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathMNL_Loss.py
87 lines (63 loc) · 2.91 KB
/
MNL_Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import itertools
from random import sample
# Coarse tolerance constant.
# NOTE(review): EPS does not appear to be referenced by the losses in this
# file -- confirm it is unused elsewhere before removing it.
EPS = 0.01
# Small positive offset added inside every torch.sqrt (and to the variances)
# in the loss classes so the square-root arguments stay away from zero.
eps = 1.0e-8
class Fidelity_Loss(torch.nn.Module):
    """Mean fidelity loss between predicted and target values.

    Computes ``mean(1 - (sqrt(p*g + eps) + sqrt((1-p)*(1-g) + eps)))`` after
    flattening both inputs to a column with ``view(-1, 1)``. The inputs are
    presumably probabilities in [0, 1] -- TODO confirm with callers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, p, g):
        """Return the scalar mean fidelity loss of predictions ``p`` vs targets ``g``."""
        target = g.view(-1, 1)
        pred = p.view(-1, 1)
        # eps keeps each sqrt argument off zero, where sqrt's gradient is
        # unbounded; as a side effect p == g == 1 yields about -1e-4, not 0.
        overlap = torch.sqrt(pred * target + eps) + torch.sqrt((1 - pred) * (1 - target) + eps)
        per_element = 1 - overlap
        return per_element.mean()
class Focal_Fidelity_Loss(torch.nn.Module):
    """Fidelity loss with an exponential focal re-weighting.

    With ``S = sqrt(p*g + eps) + sqrt((1-p)*(1-g) + eps)`` and ``F = 1 - S``
    (the plain fidelity loss), returns ``mean(alpha * (1 - exp(-gamma*F) * S))``.

    Args:
        gamma: focusing strength used in ``exp(-gamma * F)``.
    """

    def __init__(self, gamma=1):
        super().__init__()
        self.gamma = gamma

    def forward(self, p, g, alpha=1):
        """Return the scalar focal fidelity loss.

        Args:
            p: predictions, flattened to a column with ``view(-1, 1)``.
            g: targets, flattened the same way.
            alpha: multiplier on the per-element loss; presumably a scalar or
                something broadcastable to the (N, 1) column -- TODO confirm.
        """
        target = g.view(-1, 1)
        pred = p.view(-1, 1)
        # Similarity term; computed once and reused for both F and the final product.
        similarity = torch.sqrt(pred * target + eps) + torch.sqrt((1 - pred) * (1 - target) + eps)
        fidelity = 1 - similarity
        focal_weight = torch.exp(-self.gamma * fidelity)
        per_element = alpha * (1 - focal_weight * similarity)
        return per_element.mean()
class Sigma_Fidelity_Loss(torch.nn.Module):
    """Fidelity loss scaled by, and regularised with, a combined variance.

    With ``F`` the per-element fidelity loss and
    ``V = sigma1**2 + sigma2**2 + eps``, returns
    ``mean(0.5 * F / V + 0.5 * V)``. Note the regulariser is ``V`` itself,
    not ``log(V)``.
    """

    def __init__(self):
        super(Sigma_Fidelity_Loss, self).__init__()

    def forward(self, p, g, sigma1, sigma2):
        """Return the scalar variance-weighted fidelity loss.

        Args:
            p: predictions, flattened to a column with ``view(-1, 1)``.
            g: targets, flattened the same way.
            sigma1, sigma2: standard-deviation-like terms; Python scalars,
                0-d tensors, or per-element tensors (``(N,)`` or ``(N, 1)``).
        """
        g = g.view(-1, 1)
        p = p.view(-1, 1)
        loss = 1 - (torch.sqrt(p * g + eps) + torch.sqrt((1 - p) * (1 - g) + eps))
        var = sigma1 * sigma1 + sigma2 * sigma2 + eps
        # `loss` is an (N, 1) column. A 1-D variance of shape (N,) would
        # broadcast against it to (N, N) and average every item's loss with
        # every item's variance, so align non-scalar variances as a column too.
        if torch.is_tensor(var) and var.dim() > 0:
            var = var.reshape(-1, 1)
        loss = 0.5 * loss / var + 0.5 * var
        return torch.mean(loss)
class Pairwise_Fidelity_Loss(torch.nn.Module):
    """Fidelity loss on pairwise preference probabilities derived from scores.

    For a pair of items ``(i, j)`` the probability that ``i`` ranks above
    ``j`` is ``0.5 * (1 + erf((mos_i - mos_j) / sqrt(std_i**2 + std_j**2 + eps)))``,
    computed once from the predictions (``pmos``/``pstd``) and once from the
    ground truth (``gmos``/``gstd``). The two probabilities are compared with
    the fidelity loss ``1 - (sqrt(p*g + eps) + sqrt((1-p)*(1-g) + eps))``.
    """

    def __init__(self):
        super(Pairwise_Fidelity_Loss, self).__init__()

    def forward(self, pmos, gmos, gstd, pstd=None, ratio=1):
        """Compute the fidelity loss for a random subset of item pairs.

        Args:
            pmos: predicted scores, one entry per item along dim 0.
            gmos: ground-truth scores, indexed like ``pmos``.
            gstd: ground-truth standard deviations, indexed like ``pmos``.
            pstd: optional predicted standard deviations. When ``None`` the
                predicted score difference is divided by ``sqrt(2)``.
            ratio: fraction of all ``C(n, 2)`` pairs to sample.

        Returns:
            A Python list with one loss tensor (shaped by ``view(-1, 1)``) per
            sampled pair; empty when fewer than two items are given.

        Raises:
            ValueError: from ``random.sample`` when ``ratio`` asks for more
                pairs than exist (e.g. ``ratio > 1``) or a negative count.
        """
        pairs = list(itertools.combinations(range(pmos.size(0)), 2))
        # random.sample both sub-samples and shuffles, so the order of the
        # returned losses is random even when ratio == 1.
        pairs = sample(pairs, int(ratio * len(pairs)))

        constant = None
        if pstd is None:
            # Built once rather than once per pair.
            constant = torch.sqrt(torch.Tensor([2])).to(pmos.device)

        loss = []
        for idx1, idx2 in pairs:
            if pstd is None:
                p = 0.5 * (1 + torch.erf((pmos[idx1] - pmos[idx2]) / constant))
            else:
                # Treating the two scores as independent Gaussians, the
                # variance of their difference is std1**2 + std2**2. This is
                # symmetric in (idx1, idx2), so p(i, j) + p(j, i) == 1.
                p_var = pstd[idx1] * pstd[idx1] + pstd[idx2] * pstd[idx2] + eps
                p = 0.5 * (1 + torch.erf((pmos[idx1] - pmos[idx2]) / torch.sqrt(p_var)))
            g_var = gstd[idx1] * gstd[idx1] + gstd[idx2] * gstd[idx2] + eps
            g = 0.5 * (1 + torch.erf((gmos[idx1] - gmos[idx2]) / torch.sqrt(g_var)))
            g = g.view(-1, 1)
            p = p.view(-1, 1)
            loss_item = 1 - (torch.sqrt(p * g + eps) + torch.sqrt((1 - p) * (1 - g) + eps))
            loss.append(loss_item)
        return loss