From f09f1fd121540de11f0045435b948241c5d61278 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Mon, 17 Jul 2023 21:18:26 +0900 Subject: [PATCH 1/2] fix softmax --- .../adapters/mmdet/hooks/det_class_probability_map_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py b/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py index 2613990d2e5..7931e234091 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py +++ b/src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py @@ -62,7 +62,7 @@ def func( # Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects, # it would highlight one of the class maps as a background class - if self.use_cls_softmax: + if self.use_cls_softmax and self._num_cls_out_channels > 1: cls_scores = [torch.softmax(t, dim=1) for t in cls_scores] batch_size, _, height, width = cls_scores[-1].size() From 9b4d8a1bc4c9c2fa3300dd862b8039016cf24ce9 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 13 Jul 2023 18:59:46 +0900 Subject: [PATCH 2/2] fix validity tests --- .../classification/test_xai_classification_validity.py | 4 +++- .../unit/algorithms/detection/test_xai_detection_validity.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/algorithms/classification/test_xai_classification_validity.py b/tests/unit/algorithms/classification/test_xai_classification_validity.py index 1ec20d0c2c1..675cd53b89c 100644 --- a/tests/unit/algorithms/classification/test_xai_classification_validity.py +++ b/tests/unit/algorithms/classification/test_xai_classification_validity.py @@ -54,4 +54,6 @@ def test_saliency_map_cls(self, template): assert len(saliency_maps) == 2 assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == (1000, 7, 7) - assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_cls[template.name]) <= 1) + actual_sal_vals = saliency_maps[0][0][0].astype(np.int8) + ref_sal_vals = self.ref_saliency_vals_cls[template.name].astype(np.int8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) diff --git a/tests/unit/algorithms/detection/test_xai_detection_validity.py b/tests/unit/algorithms/detection/test_xai_detection_validity.py index 6d61beed752..6f684376064 100644 --- a/tests/unit/algorithms/detection/test_xai_detection_validity.py +++ b/tests/unit/algorithms/detection/test_xai_detection_validity.py @@ -80,7 +80,9 @@ def test_saliency_map_det(self, template): assert len(saliency_maps) == 2 assert saliency_maps[0].ndim == 3 assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name] - assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_det[template.name]) <= 1) + actual_sal_vals = saliency_maps[0][0][0].astype(np.int8) + ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.int8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) @e2e_pytest_unit @pytest.mark.parametrize("template", templates_det, ids=templates_det_ids)