-
Notifications
You must be signed in to change notification settings - Fork 447
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests for algo/segmentation (#3469)
* implement seg backbone unit test * update seg head, module unit test
- Loading branch information
Showing
6 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from __future__ import annotations | ||
|
||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
import torch | ||
from otx.algo.segmentation.backbones import dinov2 as target_file | ||
from otx.algo.segmentation.backbones.dinov2 import DinoVisionTransformer | ||
|
||
|
||
class TestDinoVisionTransformer: | ||
@pytest.fixture() | ||
def mock_backbone_named_parameters(self) -> dict[str, MagicMock]: | ||
named_parameter = {} | ||
for i in range(3): | ||
parameter = MagicMock() | ||
parameter.requires_grad = True | ||
named_parameter[f"layer_{i}"] = parameter | ||
return named_parameter | ||
|
||
@pytest.fixture() | ||
def mock_backbone(self, mock_backbone_named_parameters) -> MagicMock: | ||
backbone = MagicMock() | ||
backbone.named_parameters.return_value = list(mock_backbone_named_parameters.items()) | ||
return backbone | ||
|
||
@pytest.fixture(autouse=True) | ||
def mock_torch_hub_load(self, mocker, mock_backbone): | ||
return mocker.patch("otx.algo.segmentation.backbones.dinov2.torch.hub.load", return_value=mock_backbone) | ||
|
||
def test_init(self, mock_backbone, mock_backbone_named_parameters): | ||
dino = DinoVisionTransformer(name="dinov2_vits14_reg", freeze_backbone=True, out_index=[8, 9, 10, 11]) | ||
|
||
assert dino.backbone == mock_backbone | ||
for parameter in mock_backbone_named_parameters.values(): | ||
assert parameter.requires_grad is False | ||
|
||
@pytest.fixture() | ||
def mock_init_cfg(self) -> MagicMock: | ||
return MagicMock() | ||
|
||
@pytest.fixture() | ||
def dino_vit(self, mock_init_cfg) -> DinoVisionTransformer: | ||
return DinoVisionTransformer( | ||
name="dinov2_vits14_reg", | ||
freeze_backbone=True, | ||
out_index=[8, 9, 10, 11], | ||
init_cfg=mock_init_cfg, | ||
) | ||
|
||
def test_forward(self, dino_vit, mock_backbone): | ||
tensor = torch.rand(10, 3, 3, 3) | ||
dino_vit.forward(tensor) | ||
|
||
mock_backbone.assert_called_once_with(tensor) | ||
|
||
@pytest.fixture() | ||
def mock_load_from_http(self, mocker) -> MagicMock: | ||
return mocker.patch.object(target_file, "load_from_http") | ||
|
||
@pytest.fixture() | ||
def mock_load_checkpoint_to_model(self, mocker) -> MagicMock: | ||
return mocker.patch.object(target_file, "load_checkpoint_to_model") | ||
|
||
@pytest.fixture() | ||
def pretrained_weight(self, tmp_path) -> str: | ||
weight = tmp_path / "pretrained.pth" | ||
weight.touch() | ||
return str(weight) | ||
|
||
@pytest.fixture() | ||
def mock_torch_load(self, mocker) -> MagicMock: | ||
return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load") | ||
|
||
def test_load_pretrained_weights(self, dino_vit, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model): | ||
dino_vit.load_pretrained_weights(pretrained=pretrained_weight) | ||
|
||
mock_torch_load.assert_called_once_with(pretrained_weight, "cpu") | ||
mock_load_checkpoint_to_model.assert_called_once() | ||
|
||
def test_load_pretrained_weights_from_url(self, dino_vit, mock_load_from_http, mock_load_checkpoint_to_model): | ||
pretrained_weight = "www.fake.com/fake.pth" | ||
dino_vit.load_pretrained_weights(pretrained=pretrained_weight) | ||
|
||
mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu") | ||
mock_load_checkpoint_to_model.assert_called_once() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
import torch | ||
from otx.algo.segmentation.backbones import mscan as target_file | ||
from otx.algo.segmentation.backbones.mscan import MSCAN, DropPath, drop_path | ||
|
||
|
||
@pytest.mark.parametrize("dim", [1, 2, 3, 4]) | ||
def test_drop_path(dim: int): | ||
size = [10] + [2] * dim | ||
x = torch.ones(size) | ||
out = drop_path(x, 0.5, True) | ||
|
||
assert out.size() == x.size() | ||
assert out.dtype == x.dtype | ||
assert out.device == x.device | ||
|
||
|
||
def test_drop_path_not_train(): | ||
x = torch.ones(2, 2, 2, 2) | ||
out = drop_path(x, 0.5, False) | ||
|
||
assert (x == out).all() | ||
assert out.dtype == x.dtype | ||
assert out.device == x.device | ||
|
||
|
||
def test_drop_path_zero_prob(): | ||
x = torch.ones(2, 2, 2, 2) | ||
out = drop_path(x, 0.0, True) | ||
|
||
assert (x == out).all() | ||
assert out.dtype == x.dtype | ||
assert out.device == x.device | ||
|
||
|
||
class TestDropPath: | ||
def test_init(self): | ||
drop_prob = 0.3 | ||
drop_path = DropPath(drop_prob) | ||
|
||
assert drop_path.drop_prob == drop_prob | ||
|
||
def test_forward(self): | ||
drop_prob = 0.5 | ||
drop_path = DropPath(drop_prob) | ||
drop_path.train() | ||
x = torch.ones(2, 2, 2, 2) | ||
|
||
out = drop_path.forward(x) | ||
|
||
assert out.size() == x.size() | ||
assert out.dtype == x.dtype | ||
assert out.device == x.device | ||
|
||
|
||
class TestMSCABlock: | ||
def test_init(self): | ||
num_stages = 4 | ||
mscan = MSCAN(num_stages=num_stages) | ||
|
||
for i in range(num_stages): | ||
assert hasattr(mscan, f"patch_embed{i + 1}") | ||
assert hasattr(mscan, f"block{i + 1}") | ||
assert hasattr(mscan, f"norm{i + 1}") | ||
|
||
def test_forward(self): | ||
num_stages = 4 | ||
mscan = MSCAN(num_stages=num_stages) | ||
x = torch.rand(8, 3, 3, 3) | ||
out = mscan.forward(x) | ||
|
||
assert len(out) == num_stages | ||
|
||
@pytest.fixture() | ||
def mock_load_from_http(self, mocker) -> MagicMock: | ||
return mocker.patch.object(target_file, "load_from_http") | ||
|
||
@pytest.fixture() | ||
def mock_load_checkpoint_to_model(self, mocker) -> MagicMock: | ||
return mocker.patch.object(target_file, "load_checkpoint_to_model") | ||
|
||
@pytest.fixture() | ||
def pretrained_weight(self, tmp_path) -> str: | ||
weight = tmp_path / "pretrained.pth" | ||
weight.touch() | ||
return str(weight) | ||
|
||
@pytest.fixture() | ||
def mock_torch_load(self, mocker) -> MagicMock: | ||
return mocker.patch("otx.algo.segmentation.backbones.mscan.torch.load") | ||
|
||
def test_load_pretrained_weights(self, pretrained_weight, mock_torch_load, mock_load_checkpoint_to_model): | ||
MSCAN(pretrained_weights=pretrained_weight) | ||
|
||
mock_torch_load.assert_called_once_with(pretrained_weight, "cpu") | ||
mock_load_checkpoint_to_model.assert_called_once() | ||
|
||
def test_load_pretrained_weights_from_url(self, mock_load_from_http, mock_load_checkpoint_to_model): | ||
pretrained_weight = "www.fake.com/fake.pth" | ||
MSCAN(pretrained_weights=pretrained_weight) | ||
|
||
mock_load_from_http.assert_called_once_with(pretrained_weight, "cpu") | ||
mock_load_checkpoint_to_model.assert_called_once() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pytest | ||
import torch | ||
from otx.algo.segmentation.heads.ham_head import LightHamHead | ||
|
||
|
||
class TestLightHamHead: | ||
@pytest.fixture() | ||
def head_config(self) -> dict[str, Any]: | ||
return { | ||
"ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6}, | ||
"in_channels": [128, 320, 512], | ||
"in_index": [1, 2, 3], | ||
"norm_cfg": {"num_groups": 32, "requires_grad": True, "type": "GN"}, | ||
"align_corners": False, | ||
"channels": 512, | ||
"dropout_ratio": 0.1, | ||
"ham_channels": 512, | ||
"num_classes": 2, | ||
} | ||
|
||
def test_init(self, head_config): | ||
light_ham_head = LightHamHead(**head_config) | ||
assert light_ham_head.ham_channels == head_config["ham_channels"] | ||
|
||
@pytest.fixture() | ||
def batch_size(self) -> int: | ||
return 8 | ||
|
||
@pytest.fixture() | ||
def fake_input(self, batch_size) -> list[torch.Tensor]: | ||
return [ | ||
torch.rand(batch_size, 64, 128, 128), | ||
torch.rand(batch_size, 128, 64, 64), | ||
torch.rand(batch_size, 320, 32, 32), | ||
torch.rand(batch_size, 512, 16, 16), | ||
] | ||
|
||
def test_forward(self, head_config, fake_input, batch_size): | ||
light_ham_head = LightHamHead(**head_config) | ||
out = light_ham_head.forward(fake_input) | ||
assert out.size()[0] == batch_size | ||
assert out.size()[2] == fake_input[head_config["in_index"][0]].size()[2] | ||
assert out.size()[3] == fake_input[head_config["in_index"][0]].size()[3] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Test of custom algo modules of OTX segmentation task.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import pytest | ||
import torch | ||
from otx.algo.segmentation.modules.blocks import AsymmetricPositionAttentionModule, LocalAttentionModule | ||
|
||
|
||
class TestAsymmetricPositionAttentionModule: | ||
@pytest.fixture() | ||
def init_cfg(self) -> dict[str, Any]: | ||
return { | ||
"in_channels": 320, | ||
"key_channels": 128, | ||
"value_channels": 320, | ||
"psp_size": [1, 3, 6, 8], | ||
"conv_cfg": {"type": "Conv2d"}, | ||
"norm_cfg": {"type": "BN"}, | ||
} | ||
|
||
def test_init(self, init_cfg): | ||
module = AsymmetricPositionAttentionModule(**init_cfg) | ||
|
||
assert module.in_channels == init_cfg["in_channels"] | ||
assert module.key_channels == init_cfg["key_channels"] | ||
assert module.value_channels == init_cfg["value_channels"] | ||
assert module.conv_cfg == init_cfg["conv_cfg"] | ||
assert module.norm_cfg == init_cfg["norm_cfg"] | ||
|
||
@pytest.fixture() | ||
def fake_input(self) -> torch.Tensor: | ||
return torch.rand(8, 320, 16, 16) | ||
|
||
def test_forward(self, init_cfg, fake_input): | ||
module = AsymmetricPositionAttentionModule(**init_cfg) | ||
out = module.forward(fake_input) | ||
|
||
assert out.size() == fake_input.size() | ||
|
||
|
||
class TestLocalAttentionModule: | ||
@pytest.fixture() | ||
def init_cfg(self) -> dict[str, Any]: | ||
return { | ||
"num_channels": 320, | ||
"conv_cfg": {"type": "Conv2d"}, | ||
"norm_cfg": {"type": "BN"}, | ||
} | ||
|
||
def test_init(self, init_cfg): | ||
module = LocalAttentionModule(**init_cfg) | ||
|
||
assert module.num_channels == init_cfg["num_channels"] | ||
assert module.conv_cfg == init_cfg["conv_cfg"] | ||
assert module.norm_cfg == init_cfg["norm_cfg"] | ||
|
||
@pytest.fixture() | ||
def fake_input(self) -> torch.Tensor: | ||
return torch.rand(8, 320, 16, 16) | ||
|
||
def test_forward(self, init_cfg, fake_input): | ||
module = LocalAttentionModule(**init_cfg) | ||
|
||
out = module.forward(fake_input) | ||
assert out.size() == fake_input.size() |