loss.py
import Levenshtein
import torch
import torch.nn.functional as F


def edit_distance(word1, word2):
    """Return the Levenshtein (edit) distance between two strings."""
    return Levenshtein.distance(word1, word2)
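
# Typical use (an assumption, not shown in this file): build one edit distance
# per sample in the batch, e.g.
#   lev_distance = [edit_distance(w1, w2) for w1, w2 in word_pairs]
# and pass that list to embedding_loss / contrastive_loss together with t_max.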

def embedding_loss(obj_num, margin, x1, c1, x2=None, c2=None, lev_distance=None, t_max=None):
    """Margin-based contrastive loss on cosine distances.

    The positive distance is always 1 - cos(x1, c1); obj_num selects the
    pair used for the negative distance:
      0: x1 vs. c2, with the margin scaled per sample by lev_distance
         capped at t_max and normalised to [0, 1]
      1: c1 vs. c2
      2: c1 vs. x2
      3: x1 vs. x2
    """
    # Cosine distance between the anchor embedding and its positive context.
    dis1 = 1.0 - F.cosine_similarity(x1, c1)
    if obj_num == 0:
        if c2 is None or lev_distance is None:
            raise ValueError("obj_num=0 requires both c2 and lev_distance, but None was provided")
        dis2 = 1.0 - F.cosine_similarity(x1, c2)
        # Scale the margin by the per-sample edit distance, capped at t_max
        # and normalised to [0, 1].
        lev_distance = torch.as_tensor(lev_distance, dtype=torch.float32, device=x1.device)
        margin_tensor = margin * torch.clamp(lev_distance, max=t_max) / t_max
        return torch.mean(F.relu(margin_tensor + dis1 - dis2))
    if obj_num == 1:
        if c2 is None:
            raise ValueError(f"c2 of shape {c1.shape} is required for obj_num=1, but None was provided")
        dis2 = 1.0 - F.cosine_similarity(c1, c2)
        return torch.mean(F.relu(margin + dis1 - dis2))
    if obj_num == 2:
        if x2 is None:
            raise ValueError(f"x2 of shape {x1.shape} is required for obj_num=2, but None was provided")
        dis2 = 1.0 - F.cosine_similarity(c1, x2)
        return torch.mean(F.relu(margin + dis1 - dis2))
    if obj_num == 3:
        if x2 is None:
            raise ValueError(f"x2 of shape {x1.shape} is required for obj_num=3, but None was provided")
        dis2 = 1.0 - F.cosine_similarity(x1, x2)
        return torch.mean(F.relu(margin + dis1 - dis2))
    raise ValueError(f"unknown objective number: {obj_num}")


def contrastive_loss(obj, margin, x1, c1, x2=None, c2=None, lev_distance=None, t_max=None):
    """Sum the embedding losses for every objective number in obj."""
    loss = 0
    for obj_num in obj:
        loss += embedding_loss(obj_num, margin, x1, c1, x2, c2, lev_distance, t_max)
    return loss


if __name__ == '__main__':
    obj_num = 0
    margin = 0.5
    lev_distance = list(range(32))  # one edit distance per sample in the batch
    t_max = 9
    x1 = torch.randn((32, 1024))
    c1 = torch.randn((32, 1024))
    c2 = torch.randn((32, 1024))
    loss = embedding_loss(obj_num=obj_num, margin=margin, x1=x1, c1=c1, c2=c2,
                          lev_distance=lev_distance, t_max=t_max)
    print(loss)
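
    # Illustrative sketch (not in the original file): combining several
    # objectives with contrastive_loss. The objective list [0, 1, 3], the
    # extra negative tensor x2, and the reuse of the tensors above are
    # assumptions chosen only to demonstrate the call signature.
    x2 = torch.randn((32, 1024))
    combined = contrastive_loss(obj=[0, 1, 3], margin=margin, x1=x1, c1=c1,
                                x2=x2, c2=c2, lev_distance=lev_distance,
                                t_max=t_max)
    print(combined)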