diff --git a/src/otx/algo/detection/heads/base_sampler.py b/src/otx/algo/detection/heads/base_sampler.py
index fcc0ed5520b..d608c2a61cb 100644
--- a/src/otx/algo/detection/heads/base_sampler.py
+++ b/src/otx/algo/detection/heads/base_sampler.py
@@ -75,6 +75,7 @@ def _sample_pos(self, assign_result: AssignResult, num_expected: int, **kwargs)
     def _sample_neg(self, assign_result: AssignResult, num_expected: int, **kwargs) -> torch.Tensor:
         """Sample negative samples."""
 
+    @abstractmethod
     def sample(
         self,
         assign_result: AssignResult,
@@ -103,53 +104,6 @@ def sample(
         Returns:
             :obj:`SamplingResult`: Sampling result.
         """
-        gt_bboxes = gt_instances.bboxes  # type: ignore[attr-defined]
-        priors = pred_instances.priors  # type: ignore[attr-defined]
-        gt_labels = gt_instances.labels  # type: ignore[attr-defined]
-        if len(priors.shape) < 2:
-            priors = priors[None, :]
-
-        gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8)
-        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
-            gt_bboxes_ = gt_bboxes
-            priors = torch.cat([gt_bboxes_, priors], dim=0)
-            assign_result.add_gt_(gt_labels)
-            gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
-            gt_flags = torch.cat([gt_ones, gt_flags])
-
-        num_expected_pos = int(self.num * self.pos_fraction)
-        pos_inds = self.pos_sampler._sample_pos(  # noqa: SLF001
-            assign_result,
-            num_expected_pos,
-            bboxes=priors,
-            **kwargs,
-        )
-        # We found that sampled indices have duplicated items occasionally.
-        # (may be a bug of PyTorch)
-        pos_inds = pos_inds.unique()
-        num_sampled_pos = pos_inds.numel()
-        num_expected_neg = self.num - num_sampled_pos
-        if self.neg_pos_ub >= 0:
-            _pos = max(1, num_sampled_pos)
-            neg_upper_bound = int(self.neg_pos_ub * _pos)
-            if num_expected_neg > neg_upper_bound:
-                num_expected_neg = neg_upper_bound
-        neg_inds = self.neg_sampler._sample_neg(  # noqa: SLF001
-            assign_result,
-            num_expected_neg,
-            bboxes=priors,
-            **kwargs,
-        )
-        neg_inds = neg_inds.unique()
-
-        return SamplingResult(
-            pos_inds=pos_inds,
-            neg_inds=neg_inds,
-            priors=priors,
-            gt_bboxes=gt_bboxes,
-            assign_result=assign_result,
-            gt_flags=gt_flags,
-        )
 
 
 class PseudoSampler(BaseSampler):
diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py
index d9ab24ab26b..b46eab604ec 100644
--- a/src/otx/algo/detection/utils/utils.py
+++ b/src/otx/algo/detection/utils/utils.py
@@ -267,13 +267,7 @@ def empty_instances(
 
     results_list = []
     for img_id in range(len(batch_img_metas)):
-        if instance_results is not None:
-            results = instance_results[img_id]
-            if not isinstance(results, InstanceData):
-                msg = f"instance_results should be InstanceData, but got {type(results)}"
-                raise TypeError(msg)
-        else:
-            results = InstanceData()
+        results = instance_results[img_id] if instance_results is not None else InstanceData()
 
         if task_type == "bbox":
             bboxes = torch.zeros(0, 4, device=device)
diff --git a/tests/unit/algo/detection/heads/test_custom_anchor_generator.py b/tests/unit/algo/detection/heads/test_custom_anchor_generator.py
index 956c345bc74..01a5eb92946 100644
--- a/tests/unit/algo/detection/heads/test_custom_anchor_generator.py
+++ b/tests/unit/algo/detection/heads/test_custom_anchor_generator.py
@@ -25,3 +25,22 @@ def anchor_generator(self) -> SSDAnchorGeneratorClustered:
     def test_gen_base_anchors(self, anchor_generator) -> None:
         assert anchor_generator.base_anchors[0].shape == torch.Size([4, 4])
         assert anchor_generator.base_anchors[1].shape == torch.Size([5, 4])
+
+    def test_sparse_priors(self, anchor_generator) -> None:
+        assert anchor_generator.sparse_priors(torch.IntTensor([0]), [32, 32], 0, device="cpu").shape == torch.Size(
+            [1, 4],
+        )
+
+    def test_grid_anchors(self, anchor_generator) -> None:
+        out = anchor_generator.grid_anchors([(8, 8), (16, 16)], device="cpu")
+        assert len(out) == 2
+        assert out[0].shape == torch.Size([256, 4])
+        assert out[1].shape == torch.Size([1280, 4])
+
+    def test_repr(self, anchor_generator) -> None:
+        assert "strides" in str(anchor_generator)
+        assert "widths" in str(anchor_generator)
+        assert "heights" in str(anchor_generator)
+        assert "num_levels" in str(anchor_generator)
+        assert "centers" in str(anchor_generator)
+        assert "center_offset" in str(anchor_generator)
diff --git a/tests/unit/algo/detection/heads/test_max_iou_assigner.py b/tests/unit/algo/detection/heads/test_max_iou_assigner.py
new file mode 100644
index 00000000000..6fa0a47faf3
--- /dev/null
+++ b/tests/unit/algo/detection/heads/test_max_iou_assigner.py
@@ -0,0 +1,12 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""Test MaxIoUAssigner."""
+
+import torch
+from otx.algo.detection.heads.max_iou_assigner import perm_repeat_bboxes
+
+
+def test_perm_repeat_bboxes() -> None:
+    sample = torch.randn(1, 4)
+    inputs = torch.stack([sample for i in range(10)])
+    assert perm_repeat_bboxes(inputs, {}).shape == torch.Size([10, 1, 4])
diff --git a/tests/unit/algo/detection/test_atss.py b/tests/unit/algo/detection/test_atss.py
index 9dfa7659be7..4b61d22757a 100644
--- a/tests/unit/algo/detection/test_atss.py
+++ b/tests/unit/algo/detection/test_atss.py
@@ -43,3 +43,7 @@ def test_export(self, model):
         model.eval()
         output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
         assert len(output) == 2
+
+        model.explain_mode = True
+        output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
+        assert len(output) == 4
diff --git a/tests/unit/algo/detection/test_ssd.py b/tests/unit/algo/detection/test_ssd.py
index 7a69dc4b172..e30daa163af 100644
--- a/tests/unit/algo/detection/test_ssd.py
+++ b/tests/unit/algo/detection/test_ssd.py
@@ -9,6 +9,8 @@
 from lightning import Trainer
 from otx.algo.detection.ssd import SSD
 from otx.core.data.entity.detection import DetBatchPredEntity
+from otx.core.exporter.native import OTXModelExporter
+from otx.core.types.export import TaskLevelExportParameters
 
 
 class TestSSD:
@@ -33,30 +35,51 @@ def fxt_checkpoint(self, fxt_model, fxt_data_module, tmpdir, monkeypatch: pytest
 
         return checkpoint_path
 
+    def test_init(self, fxt_model):
+        assert isinstance(fxt_model._export_parameters, TaskLevelExportParameters)
+        assert isinstance(fxt_model._exporter, OTXModelExporter)
+
     def test_save_and_load_anchors(self, fxt_checkpoint) -> None:
         loaded_model = SSD.load_from_checkpoint(checkpoint_path=fxt_checkpoint)
 
         assert loaded_model.model.bbox_head.anchor_generator.widths[0][0] == 40
         assert loaded_model.model.bbox_head.anchor_generator.heights[0][0] == 50
 
-    def test_loss(self, fxt_data_module):
-        model = SSD(3)
+    def test_load_state_dict_pre_hook(self, fxt_model) -> None:
+        prev_model = SSD(2)
+        state_dict = prev_model.state_dict()
+        fxt_model.model_classes = [1, 2, 3]
+        fxt_model.ckpt_classes = [1, 2]
+        fxt_model.load_state_dict_pre_hook(state_dict, "")
+        keys = [
+            key
+            for key in prev_model.state_dict()
+            if prev_model.state_dict()[key].shape != state_dict[key].shape
+            or torch.all(prev_model.state_dict()[key] != state_dict[key])
+        ]
+
+        for key in keys:
+            assert key in fxt_model.classification_layers
+
+    def test_loss(self, fxt_model, fxt_data_module):
         data = next(iter(fxt_data_module.train_dataloader()))
         data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
-        output = model(data)
+        output = fxt_model(data)
         assert "loss_cls" in output
         assert "loss_bbox" in output
 
-    def test_predict(self, fxt_data_module):
-        model = SSD(3)
+    def test_predict(self, fxt_model, fxt_data_module):
         data = next(iter(fxt_data_module.train_dataloader()))
         data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
-        model.eval()
-        output = model(data)
+        fxt_model.eval()
+        output = fxt_model(data)
         assert isinstance(output, DetBatchPredEntity)
 
-    def test_export(self):
-        model = SSD(3)
-        model.eval()
-        output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
+    def test_export(self, fxt_model):
+        fxt_model.eval()
+        output = fxt_model.forward_for_tracing(torch.randn(1, 3, 32, 32))
         assert len(output) == 2
+
+        fxt_model.explain_mode = True
+        output = fxt_model.forward_for_tracing(torch.randn(1, 3, 32, 32))
+        assert len(output) == 4
diff --git a/tests/unit/algo/detection/test_yolox.py b/tests/unit/algo/detection/test_yolox.py
index c5ba277c1da..b24aa51f820 100644
--- a/tests/unit/algo/detection/test_yolox.py
+++ b/tests/unit/algo/detection/test_yolox.py
@@ -58,3 +58,7 @@ def test_export(self, model):
         model.eval()
         output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
         assert len(output) == 2
+
+        model.explain_mode = True
+        output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
+        assert len(output) == 4
diff --git a/tests/unit/algo/instance_segmentation/conftest.py b/tests/unit/algo/instance_segmentation/conftest.py
new file mode 100644
index 00000000000..85ec1803946
--- /dev/null
+++ b/tests/unit/algo/instance_segmentation/conftest.py
@@ -0,0 +1,34 @@
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""Test of custom algo modules of OTX Instance Segmentation task."""
+import pytest
+from otx.core.config.data import DataModuleConfig, SubsetConfig
+from otx.core.data.module import OTXDataModule
+from otx.core.types.task import OTXTaskType
+from torchvision.transforms.v2 import Resize
+
+
+@pytest.fixture()
+def fxt_data_module():
+    return OTXDataModule(
+        task=OTXTaskType.INSTANCE_SEGMENTATION,
+        config=DataModuleConfig(
+            data_format="coco_instances",
+            data_root="tests/assets/car_tree_bug",
+            train_subset=SubsetConfig(
+                batch_size=2,
+                subset_name="train",
+                transforms=[Resize(320)],
+            ),
+            val_subset=SubsetConfig(
+                batch_size=2,
+                subset_name="val",
+                transforms=[Resize(320)],
+            ),
+            test_subset=SubsetConfig(
+                batch_size=2,
+                subset_name="test",
+                transforms=[Resize(320)],
+            ),
+        ),
+    )
diff --git a/tests/unit/algo/instance_segmentation/test_maskrcnn.py b/tests/unit/algo/instance_segmentation/test_maskrcnn.py
new file mode 100644
index 00000000000..7ca358d3c21
--- /dev/null
+++ b/tests/unit/algo/instance_segmentation/test_maskrcnn.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+"""Test of OTX MaskRCNN architecture."""
+
+import pytest
+import torch
+from otx.algo.instance_segmentation.maskrcnn import MaskRCNNEfficientNet, MaskRCNNResNet50, MaskRCNNSwinT
+from otx.algo.utils.support_otx_v1 import OTXv1Helper
+from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
+from otx.core.types.export import TaskLevelExportParameters
+
+
+class TestMaskRCNN:
+    def test_load_weights(self, mocker) -> None:
+        model = MaskRCNNResNet50(2)
+        mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_iseg_ckpt")
+        model.load_from_otx_v1_ckpt({})
+        mock_load_ckpt.assert_called_once_with({}, "model.")
+
+        assert isinstance(model._export_parameters, TaskLevelExportParameters)
+
+    @pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
+    def test_loss(self, model, mocker, fxt_data_module):
+        data = next(iter(fxt_data_module.train_dataloader()))
+        data.images = torch.randn([2, 3, 32, 32])
+
+        def mock_data_preprocessor_forward(data: torch.Tensor, training: bool) -> torch.Tensor:
+            return data
+
+        mocker.patch.object(model.model.data_preprocessor, "forward", side_effect=mock_data_preprocessor_forward)
+
+        output = model(data)
+        assert "loss_cls" in output
+        assert "loss_bbox" in output
+        assert "loss_mask" in output
+        assert "loss_rpn_cls" in output
+        assert "loss_rpn_bbox" in output
+
+    @pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
+    def test_predict(self, model, fxt_data_module):
+        data = next(iter(fxt_data_module.train_dataloader()))
+        data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)]
+        model.eval()
+        output = model(data)
+        assert isinstance(output, InstanceSegBatchPredEntity)
+
+    @pytest.mark.parametrize("model", [MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)])
+    def test_export(self, model):
+        model.eval()
+        output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))
+        assert len(output) == 3
+
+        # TODO(Eugene): Explain mode should return proper outputs.
+        # After enabling explain for maskrcnn, the code below should pass:
+        # model.explain_mode = True  # noqa: ERA001
+        # output = model.forward_for_tracing(torch.randn(1, 3, 32, 32))  # noqa: ERA001
+        # assert len(output) == 5  # noqa: ERA001