Skip to content

Commit

Permalink
updated nt_xent_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Nov 16, 2020
1 parent b1cb38a commit b8595eb
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,28 +264,42 @@ def optimizer_step(
else:
optimizer.step(closure=optimizer_closure)

def nt_xent_loss(self, out_1, out_2, temperature):
def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
"""
assume out_1 and out_2 are normalized
out_1: [batch_size, dim]
out_2: [batch_size, dim]
"""
# gather representations in case of distributed training
# out_1_dist: [batch_size * world_size, dim]
# out_2_dist: [batch_size * world_size, dim]
if torch.distributed.is_available() and torch.distributed.is_initialized():
out_1 = _gather_representations(out_1)
out_2 = _gather_representations(out_2)
out_1_dist = _gather_representations(out_1)
out_2_dist = _gather_representations(out_2)
else:
out_1_dist = out_1
out_2_dist = out_2

# out: [2 * batch_size, dim]
# out_dist: [2 * batch_size * world_size, dim]
out = torch.cat([out_1, out_2], dim=0)
n_samples = len(out)
out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)

# Full similarity matrix
cov = torch.mm(out, out.t().contiguous())
# cov and sim: [2 * batch_size, 2 * batch_size * world_size]
# neg: [2 * batch_size]
cov = torch.mm(out, out_dist.t().contiguous())
sim = torch.exp(cov / temperature)
neg = sim.sum(dim=-1)

# Negative similarity
mask = ~torch.eye(n_samples, device=sim.device).bool()
neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)
# from each row, subtract e^1 to remove similarity measure for x1.x1
row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device)
neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability

# Positive similarity :
# Positive similarity, pos becomes [2 * batch_size]
pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)

loss = -torch.log(pos / neg).mean()
loss = -torch.log(pos / (neg + eps)).mean()

return loss

Expand Down

0 comments on commit b8595eb

Please sign in to comment.