# triplets_loss.py -- batch-all triplet loss helpers for PyTorch.
import torch
def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.

    The computation runs on whatever device `embeddings` lives on (CPU or GPU).

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    # Dot product between all pairs of embeddings, shape (batch_size, batch_size).
    dot_product = torch.matmul(embeddings, embeddings.t())

    # The squared L2 norm of each embedding is the diagonal of `dot_product`.
    # Reusing it (instead of recomputing the norm) makes the diagonal of the
    # result exactly 0. Shape (batch_size,).
    square_norm = torch.diag(dot_product)

    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2, shape (batch_size, batch_size).
    distances = (
        torch.unsqueeze(square_norm, 1)
        - 2.0 * dot_product
        + torch.unsqueeze(square_norm, 0)
    )

    # Rounding errors can leave slightly negative values; floor them at 0.
    # The floor tensor is created on the input's device rather than forced
    # onto CUDA, so the function also works on CPU-only hosts.
    zero = torch.tensor([0.0], device=distances.device)
    distances = torch.max(distances, zero)

    if not squared:
        # The gradient of sqrt is infinite at 0 (e.g. on the diagonal), so add
        # a small epsilon exactly where the distance is 0 before the sqrt...
        mask = torch.eq(distances, 0.0).float()
        distances = distances + mask * 1e-16
        distances = torch.sqrt(distances)
        # ...and then force those entries back to exactly 0.
        distances = distances * (1.0 - mask)

    return distances
def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is 1.0 iff the triplet (a, p, n) is valid.

    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]

    The mask is built on the same device as `labels`.

    Args:
        labels: 1D tensor of labels with shape (batch_size,)

    Returns:
        mask: float tensor of shape (batch_size, batch_size, batch_size)
              holding 1.0 for valid triplets and 0.0 elsewhere
    """
    batch_size = labels.shape[0]

    # Check that i, j and k are distinct. The identity matrix is created on
    # the labels' device rather than forced onto CUDA.
    indices_equal = torch.eye(batch_size, device=labels.device)
    indices_not_equal = 1.0 - indices_equal
    i_not_equal_j = torch.unsqueeze(indices_not_equal, 2)
    i_not_equal_k = torch.unsqueeze(indices_not_equal, 1)
    j_not_equal_k = torch.unsqueeze(indices_not_equal, 0)
    distinct_indices = i_not_equal_j * i_not_equal_k * j_not_equal_k

    # Check that labels[i] == labels[j] and labels[i] != labels[k].
    label_equal = torch.eq(
        torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)
    ).float()
    i_equal_j = torch.unsqueeze(label_equal, 2)
    i_equal_k = torch.unsqueeze(label_equal, 1)
    valid_labels = i_equal_j * (1.0 - i_equal_k)

    # A triplet is valid only when both conditions hold.
    return distinct_indices * valid_labels
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.

    We generate all the valid triplets and average the loss over the positive ones.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Pairwise distance matrix, shape (batch_size, batch_size).
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)

    # shape (batch_size, batch_size, 1)
    anchor_positive_dist = torch.unsqueeze(pairwise_dist, 2)
    # shape (batch_size, 1, batch_size)
    anchor_negative_dist = torch.unsqueeze(pairwise_dist, 1)

    # Broadcasting yields a (batch_size, batch_size, batch_size) tensor where
    # triplet_loss[i, j, k] is the loss for anchor=i, positive=j, negative=k.
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Put to zero the invalid triplets
    # (where label(a) != label(p) or label(n) == label(a) or a == p).
    mask = _get_triplet_mask(labels)
    triplet_loss = mask * triplet_loss

    # Remove negative losses (i.e. the easy triplets). The floor tensor is
    # created on the loss's device rather than forced onto CUDA.
    zero = torch.tensor([0.0], device=triplet_loss.device)
    triplet_loss = torch.max(triplet_loss, zero)

    # Count the triplets that still contribute (loss > 0).
    valid_triplets = torch.gt(triplet_loss, 1e-16).float()
    num_positive_triplets = torch.sum(valid_triplets)

    # Mean over the positive triplets; the epsilon avoids 0/0 when none remain.
    return torch.sum(triplet_loss) / (num_positive_triplets + 1e-16)