From b8595eb58468812f6b793e5ef20ddecae1db74e4 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Mon, 16 Nov 2020 13:24:56 -0500
Subject: [PATCH] updated nt_xent_loss

---
 .../self_supervised/simclr/simclr_module.py  | 36 +++++++++++++------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py
index 2964394701..6969abd4cc 100644
--- a/pl_bolts/models/self_supervised/simclr/simclr_module.py
+++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -264,28 +264,42 @@ def optimizer_step(
         else:
             optimizer.step(closure=optimizer_closure)
 
-    def nt_xent_loss(self, out_1, out_2, temperature):
+    def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
+        """
+        assume out_1 and out_2 are normalized
+        out_1: [batch_size, dim]
+        out_2: [batch_size, dim]
+        """
         # gather representations in case of distributed training
+        # out_1_dist: [batch_size * world_size, dim]
+        # out_2_dist: [batch_size * world_size, dim]
         if torch.distributed.is_available() and torch.distributed.is_initialized():
-            out_1 = _gather_representations(out_1)
-            out_2 = _gather_representations(out_2)
+            out_1_dist = _gather_representations(out_1)
+            out_2_dist = _gather_representations(out_2)
+        else:
+            out_1_dist = out_1
+            out_2_dist = out_2
 
+        # out: [2 * batch_size, dim]
+        # out_dist: [2 * batch_size * world_size, dim]
         out = torch.cat([out_1, out_2], dim=0)
-        n_samples = len(out)
+        out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)
 
-        # Full similarity matrix
-        cov = torch.mm(out, out.t().contiguous())
+        # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
+        # neg: [2 * batch_size]
+        cov = torch.mm(out, out_dist.t().contiguous())
         sim = torch.exp(cov / temperature)
+        neg = sim.sum(dim=-1)
 
-        # Negative similarity
-        mask = ~torch.eye(n_samples, device=sim.device).bool()
-        neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)
+        # from each row, subtract e^(1 / temperature) to remove the self-similarity x_i . x_i
+        row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
+        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability
 
-        # Positive similarity :
+        # Positive similarity, pos becomes [2 * batch_size]
         pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
         pos = torch.cat([pos, pos], dim=0)
 
-        loss = -torch.log(pos / neg).mean()
+        loss = -torch.log(pos / (neg + eps)).mean()
 
         return loss
 
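
For reference, below is a minimal standalone sketch of the patched loss (an illustration, not part of the patch). It takes the single-process path, where the distributed gather is skipped and out_dist is just out, so the similarity matrix is square. Subtracting exp(1 / temperature) from each row removes the self-similarity term that the old diagonal mask handled; unlike the mask, the subtraction still works when gathering makes the matrix rectangular.

# Standalone sketch of the patched nt_xent_loss; assumes a single process,
# so the _gather_representations branch from the patch is omitted.
import math

import torch
import torch.nn.functional as F


def nt_xent_loss(out_1, out_2, temperature, eps=1e-6):
    # out_1, out_2: [batch_size, dim], assumed L2-normalized
    out = torch.cat([out_1, out_2], dim=0)        # [2B, dim]

    cov = torch.mm(out, out.t().contiguous())     # [2B, 2B] cosine similarities
    sim = torch.exp(cov / temperature)
    neg = sim.sum(dim=-1)                         # row sums, self-similarity included

    # each row's own embedding contributes exp(1 / temperature); remove it
    neg = torch.clamp(neg - math.e ** (1 / temperature), min=eps)

    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)  # [B]
    pos = torch.cat([pos, pos], dim=0)                               # [2B]

    return -torch.log(pos / (neg + eps)).mean()


batch = F.normalize(torch.randn(8, 128), dim=1)
augmented = F.normalize(torch.randn(8, 128), dim=1)
print(nt_xent_loss(batch, augmented, temperature=0.5))  # scalar loss tensor

Note that the patch subtracts math.e ** (1 / temperature) rather than math.e: since the embeddings are normalized, each diagonal entry of sim is exp(x . x / temperature) = exp(1 / temperature), which only equals e when temperature is 1.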