Skip to content

Commit

Permalink
reduce the unittests input of timm backbone
Browse files Browse the repository at this point in the history
  • Loading branch information
Junjun2016 committed Oct 27, 2021
1 parent 500f22e commit 0db1bd1
Showing 1 changed file with 0 additions and 48 deletions.
48 changes: 0 additions & 48 deletions tests/test_models/test_backbones/test_timm_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,51 +131,3 @@ def test_timm_backbone():
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))

# Test resnetv2_101x3_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_101x3_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 192, 4, 4))
assert feats[1] == torch.Size((1, 768, 2, 2))
assert feats[2] == torch.Size((1, 1536, 1, 1))
assert feats[3] == torch.Size((1, 3072, 1, 1))
assert feats[4] == torch.Size((1, 6144, 1, 1))

# Test resnetv2_152x2_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_152x2_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 128, 4, 4))
assert feats[1] == torch.Size((1, 512, 2, 2))
assert feats[2] == torch.Size((1, 1024, 1, 1))
assert feats[3] == torch.Size((1, 2048, 1, 1))
assert feats[4] == torch.Size((1, 4096, 1, 1))

# Test resnetv2_152x4_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_152x4_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 256, 4, 4))
assert feats[1] == torch.Size((1, 1024, 2, 2))
assert feats[2] == torch.Size((1, 2048, 1, 1))
assert feats[3] == torch.Size((1, 4096, 1, 1))
assert feats[4] == torch.Size((1, 8192, 1, 1))

0 comments on commit 0db1bd1

Please sign in to comment.