From 05340f34ffaa4ac7a28139bede460e0235d9fce7 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Fri, 3 May 2024 16:51:12 +0900 Subject: [PATCH] Add unit tests for detectors' forward function --- tests/unit/algo/detection/test_atss.py | 28 ++++++++++++++++++++++++- tests/unit/algo/detection/test_ssd.py | 24 +++++++++++++++++++++ tests/unit/algo/detection/test_yolox.py | 28 ++++++++++++++++++++++++- 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/tests/unit/algo/detection/test_atss.py b/tests/unit/algo/detection/test_atss.py index 007fda1d181..2108620b8a7 100644 --- a/tests/unit/algo/detection/test_atss.py +++ b/tests/unit/algo/detection/test_atss.py @@ -2,8 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 """Test of OTX SSD architecture.""" -from otx.algo.detection.atss import MobileNetV2ATSS +import pytest +import torch +from otx.algo.detection.atss import MobileNetV2ATSS, ResNeXt101ATSS from otx.algo.utils.support_otx_v1 import OTXv1Helper +from otx.core.data.entity.detection import DetBatchPredEntity from otx.core.exporter.native import OTXModelExporter from otx.core.types.export import TaskLevelExportParameters @@ -17,3 +20,26 @@ def test(self, mocker) -> None: assert isinstance(model._export_parameters, TaskLevelExportParameters) assert isinstance(model._exporter, OTXModelExporter) + + @pytest.mark.parametrize("model", [MobileNetV2ATSS(3), ResNeXt101ATSS(3)]) + def test_loss(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + output = model(data) + assert "loss_cls" in output + assert "loss_bbox" in output + assert "loss_centerness" in output + + @pytest.mark.parametrize("model", [MobileNetV2ATSS(3), ResNeXt101ATSS(3)]) + def test_predict(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + model.eval() + output = model(data) + assert isinstance(output, DetBatchPredEntity) + + @pytest.mark.parametrize("model", [MobileNetV2ATSS(3), ResNeXt101ATSS(3)]) + def test_export(self, model): + model.eval() + output = model.forward_for_tracing(torch.randn(1, 3, 32, 32)) + assert len(output) == 2 diff --git a/tests/unit/algo/detection/test_ssd.py b/tests/unit/algo/detection/test_ssd.py index 36018446e87..7a69dc4b172 100644 --- a/tests/unit/algo/detection/test_ssd.py +++ b/tests/unit/algo/detection/test_ssd.py @@ -5,8 +5,10 @@ from pathlib import Path import pytest +import torch from lightning import Trainer from otx.algo.detection.ssd import SSD +from otx.core.data.entity.detection import DetBatchPredEntity class TestSSD: @@ -36,3 +38,25 @@ def test_save_and_load_anchors(self, fxt_checkpoint) -> None: assert loaded_model.model.bbox_head.anchor_generator.widths[0][0] == 40 assert loaded_model.model.bbox_head.anchor_generator.heights[0][0] == 50 + + def test_loss(self, fxt_data_module): + model = SSD(3) + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + output = model(data) + assert "loss_cls" in output + assert "loss_bbox" in output + + def test_predict(self, fxt_data_module): + model = SSD(3) + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + model.eval() + output = model(data) + assert isinstance(output, DetBatchPredEntity) + + def test_export(self): + model = SSD(3) + model.eval() + output = model.forward_for_tracing(torch.randn(1, 3, 32, 32)) + assert len(output) == 2 diff --git a/tests/unit/algo/detection/test_yolox.py b/tests/unit/algo/detection/test_yolox.py index 25d35efca9f..c5ba277c1da 100644 --- a/tests/unit/algo/detection/test_yolox.py +++ b/tests/unit/algo/detection/test_yolox.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 """Test of OTX YOLOX architecture.""" +import pytest +import torch from otx.algo.detection.backbones.csp_darknet import CSPDarknet from otx.algo.detection.heads.yolox_head import YOLOXHead from otx.algo.detection.necks.yolox_pafpn import YOLOXPAFPN -from otx.algo.detection.yolox import YOLOXL, YOLOXTINY +from otx.algo.detection.yolox import YOLOXL, YOLOXS, YOLOXTINY, YOLOXX +from otx.core.data.entity.detection import DetBatchPredEntity from otx.core.exporter.native import OTXNativeModelExporter @@ -32,3 +35,26 @@ def test_exporter(self) -> None: otx_yolox_tiny_exporter = otx_yolox_tiny._exporter assert isinstance(otx_yolox_tiny_exporter, OTXNativeModelExporter) assert otx_yolox_tiny_exporter.swap_rgb is False + + @pytest.mark.parametrize("model", [YOLOXTINY(3), YOLOXS(3), YOLOXL(3), YOLOXX(3)]) + def test_loss(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + output = model(data) + assert "loss_cls" in output + assert "loss_bbox" in output + assert "loss_obj" in output + + @pytest.mark.parametrize("model", [YOLOXTINY(3), YOLOXS(3), YOLOXL(3), YOLOXX(3)]) + def test_predict(self, model, fxt_data_module): + data = next(iter(fxt_data_module.train_dataloader())) + data.images = [torch.randn(3, 32, 32), torch.randn(3, 48, 48)] + model.eval() + output = model(data) + assert isinstance(output, DetBatchPredEntity) + + @pytest.mark.parametrize("model", [YOLOXTINY(3), YOLOXS(3), YOLOXL(3), YOLOXX(3)]) + def test_export(self, model): + model.eval() + output = model.forward_for_tracing(torch.randn(1, 3, 32, 32)) + assert len(output) == 2