Add additional otx algo det unit tests #3467

Merged · 2 commits · May 8, 2024
48 changes: 1 addition & 47 deletions src/otx/algo/detection/heads/base_sampler.py
@@ -75,6 +75,7 @@ def _sample_pos(self, assign_result: AssignResult, num_expected: int, **kwargs)
def _sample_neg(self, assign_result: AssignResult, num_expected: int, **kwargs) -> torch.Tensor:
"""Sample negative samples."""

@abstractmethod
def sample(
self,
assign_result: AssignResult,
@@ -103,53 +104,6 @@ def sample(
Returns:
:obj:`SamplingResult`: Sampling result.
"""
gt_bboxes = gt_instances.bboxes # type: ignore[attr-defined]
priors = pred_instances.priors # type: ignore[attr-defined]
gt_labels = gt_instances.labels # type: ignore[attr-defined]
if len(priors.shape) < 2:
priors = priors[None, :]

gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
gt_bboxes_ = gt_bboxes
priors = torch.cat([gt_bboxes_, priors], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])

num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos( # noqa: SLF001
assign_result,
num_expected_pos,
bboxes=priors,
**kwargs,
)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg( # noqa: SLF001
assign_result,
num_expected_neg,
bboxes=priors,
**kwargs,
)
neg_inds = neg_inds.unique()

return SamplingResult(
pos_inds=pos_inds,
neg_inds=neg_inds,
priors=priors,
gt_bboxes=gt_bboxes,
assign_result=assign_result,
gt_flags=gt_flags,
)


class PseudoSampler(BaseSampler):
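Note: with the concrete body removed and sample() now marked @abstractmethod, a natural companion unit test is to check that BaseSampler can no longer be instantiated directly. A minimal sketch, assuming BaseSampler is an ABC (uses ABCMeta) and keeps a (num, pos_fraction, neg_pos_ub, add_gt_as_proposals) constructor; neither detail is shown in this diff:

import pytest

from otx.algo.detection.heads.base_sampler import BaseSampler  # module path taken from the diff header above


def test_base_sampler_is_abstract() -> None:
    # If BaseSampler is an ABC, leaving sample() abstract makes direct
    # instantiation raise TypeError; the constructor arguments are assumptions.
    with pytest.raises(TypeError):
        BaseSampler(num=256, pos_fraction=0.5, neg_pos_ub=-1, add_gt_as_proposals=True)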
8 changes: 1 addition & 7 deletions src/otx/algo/detection/utils/utils.py
@@ -267,13 +267,7 @@ def empty_instances(

results_list = []
for img_id in range(len(batch_img_metas)):
if instance_results is not None:
results = instance_results[img_id]
if not isinstance(results, InstanceData):
msg = f"instance_results should be InstanceData, but got {type(results)}"
raise TypeError(msg)
else:
results = InstanceData()
results = instance_results[img_id] if instance_results is not None else InstanceData()

if task_type == "bbox":
bboxes = torch.zeros(0, 4, device=device)
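Note: the refactor above collapses an explicit isinstance guard into a single conditional expression. A standalone illustration of the same control flow (hypothetical names, not the OTX API):

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TypeVar

T = TypeVar("T")


def select_or_default(items: Sequence[T] | None, index: int, default_factory: Callable[[], T]) -> T:
    # Same shape as the refactored line: index into the provided results when
    # they exist, otherwise construct a fresh empty container.
    return items[index] if items is not None else default_factory()


# e.g. select_or_default(None, 0, dict) == {}, select_or_default([{"a": 1}], 0, dict) == {"a": 1}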
19 changes: 19 additions & 0 deletions tests/unit/algo/detection/heads/test_custom_anchor_generator.py
@@ -25,3 +25,22 @@ def anchor_generator(self) -> SSDAnchorGeneratorClustered:
def test_gen_base_anchors(self, anchor_generator) -> None:
assert anchor_generator.base_anchors[0].shape == torch.Size([4, 4])
assert anchor_generator.base_anchors[1].shape == torch.Size([5, 4])

def test_sparse_priors(self, anchor_generator) -> None:
assert anchor_generator.sparse_priors(torch.IntTensor([0]), [32, 32], 0, device="cpu").shape == torch.Size(
[1, 4],
)

def test_grid_anchors(self, anchor_generator) -> None:
out = anchor_generator.grid_anchors([(8, 8), (16, 16)], device="cpu")
assert len(out) == 2
assert out[0].shape == torch.Size([256, 4])
assert out[1].shape == torch.Size([1280, 4])

def test_repr(self, anchor_generator) -> None:
assert "strides" in str(anchor_generator)
assert "widths" in str(anchor_generator)
assert "heights" in str(anchor_generator)
assert "num_levels" in str(anchor_generator)
assert "centers" in str(anchor_generator)
assert "center_offset" in str(anchor_generator)
12 changes: 12 additions & 0 deletions tests/unit/algo/detection/heads/test_max_iou_assigner.py
@@ -0,0 +1,12 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test Max Iou Assigner ."""

import torch
from otx.algo.detection.heads.max_iou_assigner import perm_repeat_bboxes


def test_perm_repeat_bboxes() -> None:
sample = torch.randn(1, 4)
inputs = torch.stack([sample for _ in range(10)])
assert perm_repeat_bboxes(inputs, {}).shape == torch.Size([10, 1, 4])
4 changes: 4 additions & 0 deletions tests/unit/algo/detection/test_atss.py
@@ -43,3 +43,7 @@ def test_export(self, model):
model.eval()
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 2

model.explain_mode = True
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 4
45 changes: 34 additions & 11 deletions tests/unit/algo/detection/test_ssd.py
@@ -9,6 +9,8 @@
from lightning import Trainer
from otx.algo.detection.ssd import SSD
from otx.core.data.entity.detection import DetBatchPredEntity
from otx.core.exporter.native import OTXModelExporter
from otx.core.types.export import TaskLevelExportParameters


class TestSSD:
@@ -33,30 +35,51 @@ def fxt_checkpoint(self, fxt_model, fxt_data_module, tmpdir, monkeypatch: pytest

return checkpoint_path

def test_init(self, fxt_model):
assert isinstance(fxt_model._export_parameters, TaskLevelExportParameters)
assert isinstance(fxt_model._exporter, OTXModelExporter)

def test_save_and_load_anchors(self, fxt_checkpoint) -> None:
loaded_model = SSD.load_from_checkpoint(checkpoint_path=fxt_checkpoint)

assert loaded_model.model.bbox_head.anchor_generator.widths[0][0] == 40
assert loaded_model.model.bbox_head.anchor_generator.heights[0][0] == 50

def test_loss(self, fxt_data_module):
model = SSD(3)
def test_load_state_dict_pre_hook(self, fxt_model) -> None:
prev_model = SSD(2)
state_dict = prev_model.state_dict()
fxt_model.model_classes = [1, 2, 3]
fxt_model.ckpt_classes = [1, 2]
fxt_model.load_state_dict_pre_hook(state_dict, "")
keys = [
key
for key in prev_model.state_dict()
if prev_model.state_dict()[key].shape != state_dict[key].shape
or torch.all(prev_model.state_dict()[key] != state_dict[key])
]

for key in keys:
assert key in fxt_model.classification_layers

def test_loss(self, fxt_model, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
output = model(data)
output = fxt_model(data)
assert "loss_cls" in output
assert "loss_bbox" in output

def test_predict(self, fxt_data_module):
model = SSD(3)
def test_predict(self, fxt_model, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
model.eval()
output = model(data)
fxt_model.eval()
output = fxt_model(data)
assert isinstance(output, DetBatchPredEntity)

def test_export(self):
model = SSD(3)
model.eval()
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
def test_export(self, fxt_model):
fxt_model.eval()
output = fxt_model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 2

fxt_model.explain_mode = True
output = fxt_model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 4
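Note: the fxt_model fixture used by these SSD tests is defined in the collapsed portion of TestSSD above. Judging from the inline SSD(3) calls it replaces, it presumably builds a small three-class model; a hedged sketch:

@pytest.fixture()
def fxt_model(self) -> SSD:
    # Assumption: mirrors the SSD(3) instantiation the tests previously used inline.
    return SSD(3)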
4 changes: 4 additions & 0 deletions tests/unit/algo/detection/test_yolox.py
@@ -58,3 +58,7 @@ def test_export(self, model):
model.eval()
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 2

model.explain_mode = True
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 4
34 changes: 34 additions & 0 deletions tests/unit/algo/instance_segmentation/conftest.py
@@ -0,0 +1,34 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of custom algo modules of OTX Detection task."""
import pytest
from otx.core.config.data import DataModuleConfig, SubsetConfig
from otx.core.data.module import OTXDataModule
from otx.core.types.task import OTXTaskType
from torchvision.transforms.v2 import Resize


@pytest.fixture()
def fxt_data_module():
return OTXDataModule(
task=OTXTaskType.INSTANCE_SEGMENTATION,
config=DataModuleConfig(
data_format="coco_instances",
data_root="tests/assets/car_tree_bug",
train_subset=SubsetConfig(
batch_size=2,
subset_name="train",
transforms=[Resize(320)],
),
val_subset=SubsetConfig(
batch_size=2,
subset_name="val",
transforms=[Resize(320)],
),
test_subset=SubsetConfig(
batch_size=2,
subset_name="test",
transforms=[Resize(320)],
),
),
)
57 changes: 57 additions & 0 deletions tests/unit/algo/instance_segmentation/test_maskrcnn.py
@@ -0,0 +1,57 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of OTX MaskRCNN architecture."""

import pytest
import torch
from otx.algo.instance_segmentation.maskrcnn import MaskRCNNEfficientNet, MaskRCNNResNet50, MaskRCNNSwinT
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
from otx.core.types.export import TaskLevelExportParameters


class TestMaskRCNN:
def test_load_weights(self, mocker) -> None:
model = MaskRCNNResNet50(2)
mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_iseg_ckpt")
model.load_from_otx_v1_ckpt({})
mock_load_ckpt.assert_called_once_with({}, "model.")

assert isinstance(model._export_parameters, TaskLevelExportParameters)

@pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
def test_loss(self, model, mocker, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
data.images = torch.randn([2, 3, 32, 32])

def mock_data_preprocessor_forward(data: torch.Tensor, training: bool) -> torch.Tensor:
return data

mocker.patch.object(model.model.data_preprocessor, "forward", side_effect=mock_data_preprocessor_forward)

output = model(data)
assert "loss_cls" in output
assert "loss_bbox" in output
assert "loss_mask" in output
assert "loss_rpn_cls" in output
assert "loss_rpn_bbox" in output

@pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
def test_predict(self, model, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
model.eval()
output = model(data)
assert isinstance(output, InstanceSegBatchPredEntity)

@pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
def test_export(self, model):
model.eval()
output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
assert len(output) == 3

# TODO(Eugene): Explain should return proper output.
# After enabling explain for maskrcnn, the code below should pass
# model.explain_mode = True # noqa: ERA001
# output = model.forward_for_tracing(torch.randn(1, 3, 32, 32)) # noqa: ERA001
# assert len(output) == 5 # noqa: ERA001