Add unit tests for algo/segmentation (#3469)
* implement seg backbone unit tests

* update seg head and module unit tests
eunwoosh authored May 8, 2024
1 parent 2e11325 commit 4cf4cc5
Showing 6 changed files with 365 additions and 0 deletions.
86 changes: 86 additions & 0 deletions tests/unit/algo/segmentation/backbones/test_dinov2.py
@@ -0,0 +1,86 @@
from __future__ import annotations

from unittest.mock import MagicMock

import pytest
import torch
from otx.algo.segmentation.backbones import dinov2 as target_file
from otx.algo.segmentation.backbones.dinov2 import DinoVisionTransformer


class TestDinoVisionTransformer:
    @pytest.fixture()
    def mock_backbone_named_parameters(self) -> dict[str, MagicMock]:
        named_parameter = {}
        for i in range(3):
            parameter = MagicMock()
            parameter.requires_grad = True
            named_parameter[f"layer_{i}"] = parameter
        return named_parameter

    @pytest.fixture()
    def mock_backbone(self, mock_backbone_named_parameters) -> MagicMock:
        backbone = MagicMock()
        backbone.named_parameters.return_value = list(mock_backbone_named_parameters.items())
        return backbone

    @pytest.fixture(autouse=True)
    def mock_torch_hub_load(self, mocker, mock_backbone):
        return mocker.patch("otx.algo.segmentation.backbones.dinov2.torch.hub.load", return_value=mock_backbone)

    def test_init(self, mock_backbone, mock_backbone_named_parameters):
        dino = DinoVisionTransformer(name="dinov2_vits14_reg", freeze_backbone=True, out_index=[8, 9, 10, 11])

        assert dino.backbone == mock_backbone
        for parameter in mock_backbone_named_parameters.values():
            assert parameter.requires_grad is False

    @pytest.fixture()
    def mock_init_cfg(self) -> MagicMock:
        return MagicMock()

    @pytest.fixture()
    def dino_vit(self, mock_init_cfg) -> DinoVisionTransformer:
        return DinoVisionTransformer(
            name="dinov2_vits14_reg",
            freeze_backbone=True,
            out_index=[8, 9, 10, 11],
            init_cfg=mock_init_cfg,
        )

    def test_forward(self, dino_vit, mock_backbone):
        tensor = torch.rand(10, 3, 3, 3)
        dino_vit.forward(tensor)

        mock_backbone.assert_called_once_with(tensor)

    @pytest.fixture()
    def mock_load_from_http(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_from_http")

    @pytest.fixture()
    def mock_load_checkpoint_to_model(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_checkpoint_to_model")

    @pytest.fixture()
    def pretrained_weight(self, tmp_path) -> str:
        weight = tmp_path / "pretrained.pth"
        weight.touch()
        return str(weight)

    @pytest.fixture()
    def mock_torch_load(self, mocker) -> MagicMock:
        return mocker.patch("otx.algo.segmentation.backbones.dinov2.torch.load")

    def test_load_pretrained_weights(self, dino_vit, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model):
        dino_vit.load_pretrained_weights(pretrained=pretrained_weight)

        mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()

    def test_load_pretrained_weights_from_url(self, dino_vit, mock_load_from_http, mock_load_checkpoint_to_model):
        pretrained_weight = "www.fake.com/fake.pth"
        dino_vit.load_pretrained_weights(pretrained=pretrained_weight)

        mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()
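This pair of loading tests recurs in all three backbone suites in this commit: a local path is read with torch.load(path, "cpu"), a URL goes through load_from_http(url, "cpu"), and either checkpoint is handed to load_checkpoint_to_model. A minimal sketch of the dispatch the tests imply (the helper names come from the tests themselves; the stub bodies and the URL heuristic are assumptions, not the real OTX implementation):

from __future__ import annotations

import torch


def load_from_http(url: str, map_location: str):  # stub for the OTX helper
    raise NotImplementedError


def load_checkpoint_to_model(model, checkpoint) -> None:  # stub for the OTX helper
    raise NotImplementedError


def load_pretrained_weights(model, pretrained: str | None = None) -> None:
    if not pretrained:
        return
    if pretrained.startswith(("http://", "https://", "www.")):  # assumed URL heuristic
        checkpoint = load_from_http(pretrained, "cpu")  # remote weight file
    else:
        checkpoint = torch.load(pretrained, "cpu")  # local .pth file
    load_checkpoint_to_model(model, checkpoint)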
58 changes: 58 additions & 0 deletions tests/unit/algo/segmentation/backbones/test_litehrnet.py
@@ -1,7 +1,9 @@
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
import torch
from otx.algo.segmentation.backbones import litehrnet as target_file
from otx.algo.segmentation.backbones.litehrnet import LiteHRNet, NeighbourSupport, SpatialWeightingV2, StemV2


@@ -75,6 +77,21 @@ def extra_cfg(self) -> dict:
                    (40, 80, 160, 320),
                ],
            },
            "out_modules": {
                "conv": {
                    "enable": True,
                    "channels": 320,
                },
                "position_att": {
                    "enable": True,
                    "key_channels": 128,
                    "value_channels": 320,
                    "psp_size": [1, 3, 6, 8],
                },
                "local_att": {
                    "enable": False,
                },
            },
        }

    @pytest.fixture()
@@ -111,3 +128,44 @@ def test_forward(self, extra_cfg, backbone) -> None:
        model = LiteHRNet(extra=extra)
        outputs = model(inputs)
        assert outputs is not None

    @pytest.fixture()
    def mock_load_from_http(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_from_http")

    @pytest.fixture()
    def mock_load_checkpoint_to_model(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_checkpoint_to_model")

    @pytest.fixture()
    def pretrained_weight(self, tmp_path) -> str:
        weight = tmp_path / "pretrained.pth"
        weight.touch()
        return str(weight)

    @pytest.fixture()
    def mock_torch_load(self, mocker) -> MagicMock:
        return mocker.patch("otx.algo.segmentation.backbones.litehrnet.torch.load")

    def test_load_pretrained_weights(
        self,
        extra_cfg,
        pretrained_weight,
        mock_torch_load,
        mock_load_checkpoint_to_model,
    ):
        extra_cfg["add_stem_features"] = True
        model = LiteHRNet(extra=extra_cfg)
        model.load_pretrained_weights(pretrained=pretrained_weight)

        mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()

    def test_load_pretrained_weights_from_url(self, extra_cfg, mock_load_from_http, mock_load_checkpoint_to_model):
        pretrained_weight = "www.fake.com/fake.pth"
        extra_cfg["add_stem_features"] = True
        model = LiteHRNet(extra=extra_cfg)
        model.load_pretrained_weights(pretrained=pretrained_weight)

        mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()
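Both this file and test_dinov2.py patch the helpers on the module under test, via mocker.patch.object(target_file, ...), rather than on the module where the helpers are defined, following the usual "patch where it is looked up" rule. A self-contained illustration of that rule with throwaway stand-in modules (all names here are invented for the demo):

import types
from unittest.mock import patch

utils = types.ModuleType("utils")
utils.load_from_http = lambda url: "real download"

backbone = types.ModuleType("backbone")
backbone.load_from_http = utils.load_from_http  # as if: from utils import load_from_http
backbone.load_weights = lambda url: backbone.load_from_http(url)

# Patching the name inside 'backbone' intercepts its call; patching it on
# 'utils' would not, since 'backbone' already holds its own reference.
with patch.object(backbone, "load_from_http", return_value="mocked"):
    assert backbone.load_weights("https://example.com/w.pth") == "mocked"
assert backbone.load_weights("https://example.com/w.pth") == "real download"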
105 changes: 105 additions & 0 deletions tests/unit/algo/segmentation/backbones/test_mscan.py
@@ -0,0 +1,105 @@
from unittest.mock import MagicMock

import pytest
import torch
from otx.algo.segmentation.backbones import mscan as target_file
from otx.algo.segmentation.backbones.mscan import MSCAN, DropPath, drop_path


@pytest.mark.parametrize("dim", [1, 2, 3, 4])
def test_drop_path(dim: int):
    size = [10] + [2] * dim
    x = torch.ones(size)
    out = drop_path(x, 0.5, True)

    assert out.size() == x.size()
    assert out.dtype == x.dtype
    assert out.device == x.device


def test_drop_path_not_train():
    x = torch.ones(2, 2, 2, 2)
    out = drop_path(x, 0.5, False)

    assert (x == out).all()
    assert out.dtype == x.dtype
    assert out.device == x.device


def test_drop_path_zero_prob():
    x = torch.ones(2, 2, 2, 2)
    out = drop_path(x, 0.0, True)

    assert (x == out).all()
    assert out.dtype == x.dtype
    assert out.device == x.device
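# The three tests above describe the standard stochastic-depth drop_path:
# per-sample masking while training, identity in eval mode, identity at zero
# probability. A reference sketch of that behavior, following the widely used
# timm formulation (OTX's actual implementation may differ in detail):
def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    if drop_prob == 0.0 or not training:
        return x  # identity outside training or at zero probability
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one Bernoulli draw per sample
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x * mask / keep_prob  # rescale so the expected value is unchanged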


class TestDropPath:
    def test_init(self):
        drop_prob = 0.3
        drop_path = DropPath(drop_prob)

        assert drop_path.drop_prob == drop_prob

    def test_forward(self):
        drop_prob = 0.5
        drop_path = DropPath(drop_prob)
        drop_path.train()
        x = torch.ones(2, 2, 2, 2)

        out = drop_path.forward(x)

        assert out.size() == x.size()
        assert out.dtype == x.dtype
        assert out.device == x.device


class TestMSCABlock:
    def test_init(self):
        num_stages = 4
        mscan = MSCAN(num_stages=num_stages)

        for i in range(num_stages):
            assert hasattr(mscan, f"patch_embed{i + 1}")
            assert hasattr(mscan, f"block{i + 1}")
            assert hasattr(mscan, f"norm{i + 1}")

    def test_forward(self):
        num_stages = 4
        mscan = MSCAN(num_stages=num_stages)
        x = torch.rand(8, 3, 3, 3)
        out = mscan.forward(x)

        assert len(out) == num_stages

    @pytest.fixture()
    def mock_load_from_http(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_from_http")

    @pytest.fixture()
    def mock_load_checkpoint_to_model(self, mocker) -> MagicMock:
        return mocker.patch.object(target_file, "load_checkpoint_to_model")

    @pytest.fixture()
    def pretrained_weight(self, tmp_path) -> str:
        weight = tmp_path / "pretrained.pth"
        weight.touch()
        return str(weight)

    @pytest.fixture()
    def mock_torch_load(self, mocker) -> MagicMock:
        return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load")

    def test_load_pretrained_weights(self, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model):
        MSCAN(pretrained_weights=pretrained_weight)

        mock_torch_load.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()

    def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_checkpoint_to_model):
        pretrained_weight = "www.fake.com/fake.pth"
        MSCAN(pretrained_weights=pretrained_weight)

        mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu")
        mock_load_checkpoint_to_model.assert_called_once()
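TestMSCABlock pins down the backbone's per-stage layout (patch_embed{i}, block{i}, norm{i}) and that forward returns one feature map per stage. A toy skeleton of that layout; the attribute names follow the tests, while the layer choices and channel widths are invented for illustration:

import torch
from torch import nn


class TinyStagedBackbone(nn.Module):
    def __init__(self, num_stages: int = 4, channels: int = 8) -> None:
        super().__init__()
        self.num_stages = num_stages
        for i in range(num_stages):
            in_ch = 3 if i == 0 else channels
            setattr(self, f"patch_embed{i + 1}", nn.Conv2d(in_ch, channels, 1))
            setattr(self, f"block{i + 1}", nn.Identity())
            setattr(self, f"norm{i + 1}", nn.GroupNorm(1, channels))

    def forward(self, x: torch.Tensor) -> list:
        outs = []
        for i in range(self.num_stages):
            x = getattr(self, f"patch_embed{i + 1}")(x)
            x = getattr(self, f"block{i + 1}")(x)
            x = getattr(self, f"norm{i + 1}")(x)
            outs.append(x)
        return outs  # one feature map per stage, as test_forward expects


assert len(TinyStagedBackbone()(torch.rand(2, 3, 8, 8))) == 4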
47 changes: 47 additions & 0 deletions tests/unit/algo/segmentation/heads/test_ham_head.py
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import Any

import pytest
import torch
from otx.algo.segmentation.heads.ham_head import LightHamHead


class TestLightHamHead:
    @pytest.fixture()
    def head_config(self) -> dict[str, Any]:
        return {
            "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6},
            "in_channels": [128, 320, 512],
            "in_index": [1, 2, 3],
            "norm_cfg": {"num_groups": 32, "requires_grad": True, "type": "GN"},
            "align_corners": False,
            "channels": 512,
            "dropout_ratio": 0.1,
            "ham_channels": 512,
            "num_classes": 2,
        }

    def test_init(self, head_config):
        light_ham_head = LightHamHead(**head_config)
        assert light_ham_head.ham_channels == head_config["ham_channels"]

    @pytest.fixture()
    def batch_size(self) -> int:
        return 8

    @pytest.fixture()
    def fake_input(self, batch_size) -> list[torch.Tensor]:
        return [
            torch.rand(batch_size, 64, 128, 128),
            torch.rand(batch_size, 128, 64, 64),
            torch.rand(batch_size, 320, 32, 32),
            torch.rand(batch_size, 512, 16, 16),
        ]

    def test_forward(self, head_config, fake_input, batch_size):
        light_ham_head = LightHamHead(**head_config)
        out = light_ham_head.forward(fake_input)
        assert out.size()[0] == batch_size
        assert out.size()[2] == fake_input[head_config["in_index"][0]].size()[2]
        assert out.size()[3] == fake_input[head_config["in_index"][0]].size()[3]
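test_forward pins only the output's batch and spatial dimensions: the prediction comes back at the scale of the first selected input (in_index[0], i.e. the 64x64 map). A toy stand-in for that resize-and-fuse pattern; the real LightHamHead fuses features through a Hamburger matrix-decomposition module rather than this plain 1x1-convolution pipeline:

import torch
import torch.nn.functional as F
from torch import nn


class ToyFusionHead(nn.Module):
    def __init__(self, in_channels: list, channels: int, num_classes: int) -> None:
        super().__init__()
        self.squeeze = nn.Conv2d(sum(in_channels), channels, 1)
        self.classify = nn.Conv2d(channels, num_classes, 1)

    def forward(self, feats: list) -> torch.Tensor:
        target = feats[0].shape[2:]  # the first selected scale sets the output size
        aligned = [F.interpolate(f, size=target, mode="bilinear", align_corners=False) for f in feats]
        return self.classify(self.squeeze(torch.cat(aligned, dim=1)))


head = ToyFusionHead([128, 320, 512], channels=512, num_classes=2)
out = head([torch.rand(8, 128, 64, 64), torch.rand(8, 320, 32, 32), torch.rand(8, 512, 16, 16)])
assert out.shape == (8, 2, 64, 64)  # dims 0, 2 and 3 match what the test asserts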
3 changes: 3 additions & 0 deletions tests/unit/algo/segmentation/modules/__init__.py
@@ -0,0 +1,3 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of custom algo modules of OTX segmentation task."""
66 changes: 66 additions & 0 deletions tests/unit/algo/segmentation/modules/test_blocks.py
@@ -0,0 +1,66 @@
from __future__ import annotations

from typing import Any

import pytest
import torch
from otx.algo.segmentation.modules.blocks import AsymmetricPositionAttentionModule, LocalAttentionModule


class TestAsymmetricPositionAttentionModule:
    @pytest.fixture()
    def init_cfg(self) -> dict[str, Any]:
        return {
            "in_channels": 320,
            "key_channels": 128,
            "value_channels": 320,
            "psp_size": [1, 3, 6, 8],
            "conv_cfg": {"type": "Conv2d"},
            "norm_cfg": {"type": "BN"},
        }

    def test_init(self, init_cfg):
        module = AsymmetricPositionAttentionModule(**init_cfg)

        assert module.in_channels == init_cfg["in_channels"]
        assert module.key_channels == init_cfg["key_channels"]
        assert module.value_channels == init_cfg["value_channels"]
        assert module.conv_cfg == init_cfg["conv_cfg"]
        assert module.norm_cfg == init_cfg["norm_cfg"]

    @pytest.fixture()
    def fake_input(self) -> torch.Tensor:
        return torch.rand(8, 320, 16, 16)

    def test_forward(self, init_cfg, fake_input):
        module = AsymmetricPositionAttentionModule(**init_cfg)
        out = module.forward(fake_input)

        assert out.size() == fake_input.size()


class TestLocalAttentionModule:
    @pytest.fixture()
    def init_cfg(self) -> dict[str, Any]:
        return {
            "num_channels": 320,
            "conv_cfg": {"type": "Conv2d"},
            "norm_cfg": {"type": "BN"},
        }

    def test_init(self, init_cfg):
        module = LocalAttentionModule(**init_cfg)

        assert module.num_channels == init_cfg["num_channels"]
        assert module.conv_cfg == init_cfg["conv_cfg"]
        assert module.norm_cfg == init_cfg["norm_cfg"]

    @pytest.fixture()
    def fake_input(self) -> torch.Tensor:
        return torch.rand(8, 320, 16, 16)

    def test_forward(self, init_cfg, fake_input):
        module = LocalAttentionModule(**init_cfg)

        out = module.forward(fake_input)
        assert out.size() == fake_input.size()
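Both attention-module tests assert a single contract: the output tensor has exactly the input's shape. A toy shape-preserving attention block showing the usual reason this holds (project to key/value space, attend, project back to the input width, add residually); this is illustrative only, not the actual math of the OTX modules:

import torch
from torch import nn


class ShapePreservingAttention(nn.Module):
    def __init__(self, in_channels: int, key_channels: int, value_channels: int) -> None:
        super().__init__()
        self.query = nn.Conv2d(in_channels, key_channels, 1)
        self.key = nn.Conv2d(in_channels, key_channels, 1)
        self.value = nn.Conv2d(in_channels, value_channels, 1)
        self.out = nn.Conv2d(value_channels, in_channels, 1)  # back to the input width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        q = self.query(x).flatten(2).transpose(1, 2)  # (b, h*w, key_channels)
        k = self.key(x).flatten(2)  # (b, key_channels, h*w)
        v = self.value(x).flatten(2).transpose(1, 2)  # (b, h*w, value_channels)
        attn = torch.softmax(q @ k / k.shape[1] ** 0.5, dim=-1)  # (b, h*w, h*w)
        y = (attn @ v).transpose(1, 2).reshape(b, -1, h, w)
        return x + self.out(y)  # residual connection keeps the shape


x = torch.rand(2, 320, 16, 16)
assert ShapePreservingAttention(320, 128, 320)(x).shape == x.shape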
