DenseCL init weights copy query encoder weights to key encoder. (#411)
* DenseCL init weights copy query encoder weights to key encoder.

* Logger prints that key encoder is initialized with query encoder.
lorinczszabolcs authored and fangyixiao18 committed Sep 1, 2022
1 parent 1bb524b commit 6e6bec1
Showing 2 changed files with 26 additions and 5 deletions.
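
For context, this commit follows the MoCo-style momentum-encoder convention that DenseCL inherits: the key encoder must start as an exact, frozen copy of the query encoder. Below is a minimal sketch of that invariant (standalone names, not the repository's API; the actual fix lives in `init_weights` in the diff that follows):

```python
import torch.nn as nn

def copy_query_to_key(encoder_q: nn.Module, encoder_k: nn.Module) -> None:
    """Copy query-encoder weights into the key encoder and freeze it.

    Doing this in ``__init__`` is too early: a later ``init_weights()``
    call re-initializes the query encoder and the two encoders fall out
    of sync. This commit fixes that by moving the copy into
    ``init_weights()`` itself.
    """
    for param_q, param_k in zip(encoder_q.parameters(),
                                encoder_k.parameters()):
        param_k.data.copy_(param_q.data)
        # The key encoder receives momentum updates, not gradients.
        param_k.requires_grad = False
```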
25 changes: 20 additions & 5 deletions mmselfsup/models/algorithms/densecl.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
+from mmcv.utils.logging import logger_initialized, print_log
 
 from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                              concat_all_gather)
@@ -49,11 +50,6 @@ def __init__(self,
         self.encoder_k = nn.Sequential(
             build_backbone(backbone), build_neck(neck))
 
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data.copy_(param_q.data)
-            param_k.requires_grad = False
-
         self.backbone = self.encoder_q[0]
         assert head is not None
         self.head = build_head(head)
@@ -71,6 +67,25 @@ def __init__(self,
         self.queue2 = nn.functional.normalize(self.queue2, dim=0)
         self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))
 
+    def init_weights(self):
+        """Init weights and copy query encoder init weights to key encoder."""
+        super().init_weights()
+
+        # Get the initialized logger; if none exists,
+        # create a logger named `mmselfsup`.
+        logger_names = list(logger_initialized.keys())
+        logger_name = logger_names[0] if logger_names else 'mmselfsup'
+
+        # Log that the key encoder is initialized from the query encoder.
+        print_log(
+            'Key encoder is initialized by the query encoder.',
+            logger=logger_name)
+
+        for param_q, param_k in zip(self.encoder_q.parameters(),
+                                    self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)
+            param_k.requires_grad = False
+
     @torch.no_grad()
     def _momentum_update_key_encoder(self):
         """Momentum update of the key encoder."""
6 changes: 6 additions & 0 deletions tests/test_models/test_algorithms/test_densecl.py
@@ -57,6 +57,12 @@ def test_densecl():
     assert alg.queue.size() == torch.Size([feat_dim, queue_len])
     assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
 
+    alg.init_weights()
+    for param_q, param_k in zip(alg.encoder_q.parameters(),
+                                alg.encoder_k.parameters()):
+        assert torch.equal(param_q, param_k)
+        assert param_k.requires_grad is False
+
     fake_input = torch.randn((2, 3, 224, 224))
     with pytest.raises(AssertionError):
         fake_out = alg.forward_train(fake_input)
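
The new assertions guard the fix directly: after `init_weights()`, every key-encoder parameter must equal its query-encoder counterpart and must be frozen. To run just this test locally, an invocation along the lines of `pytest tests/test_models/test_algorithms/test_densecl.py::test_densecl` (path as shown in the diff) should work.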
