From 08fe9d6bd794d546da19c13decaa21cfca3bcd7b Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Mon, 14 Oct 2024 15:32:49 +0100 Subject: [PATCH] Fix empty anno - merge back develop (#4022) Refactor empty label workaround in iseg and mask_target.py --- src/otx/algo/common/utils/bbox_overlaps.py | 19 ++- .../utils/structures/mask/mask_target.py | 12 +- src/otx/core/data/pre_filtering.py | 6 +- .../unit/algo/common/test_iou2d_calculator.py | 123 ++++++++++++++++++ 4 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 tests/unit/algo/common/test_iou2d_calculator.py diff --git a/src/otx/algo/common/utils/bbox_overlaps.py b/src/otx/algo/common/utils/bbox_overlaps.py index 1df50c5cbd5..e6e05197ca1 100644 --- a/src/otx/algo/common/utils/bbox_overlaps.py +++ b/src/otx/algo/common/utils/bbox_overlaps.py @@ -8,6 +8,8 @@ from __future__ import annotations +import warnings + import torch from torch import Tensor @@ -142,15 +144,28 @@ def bbox_overlaps( >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) """ + if not (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0): + msg = "bboxes1 must have a last dimension of size 4 or be an empty tensor." + raise ValueError(msg) + + if not (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0): + msg = "bboxes2 must have a last dimension of size 4 or be an empty tensor." + raise ValueError(msg) + + if bboxes1.shape[:-2] != bboxes2.shape[:-2]: + msg = "The batch dimension of bboxes must be the same." + raise ValueError(msg) + batch_shape = bboxes1.shape[:-2] rows = bboxes1.size(-2) cols = bboxes2.size(-2) if rows * cols == 0: + warnings.warn("No bboxes are provided! Returning empty boxes!", stacklevel=2) if is_aligned: - return bboxes1.new((*batch_shape, rows)) - return bboxes1.new((*batch_shape, rows, cols)) + return bboxes1.new(batch_shape + (rows,)) # noqa: RUF005 + return bboxes1.new(batch_shape + (rows, cols)) # noqa: RUF005 area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) diff --git a/src/otx/algo/instance_segmentation/utils/structures/mask/mask_target.py b/src/otx/algo/instance_segmentation/utils/structures/mask/mask_target.py index b5ad04b6079..9a17a4b8e55 100644 --- a/src/otx/algo/instance_segmentation/utils/structures/mask/mask_target.py +++ b/src/otx/algo/instance_segmentation/utils/structures/mask/mask_target.py @@ -8,6 +8,8 @@ from __future__ import annotations +import warnings + import numpy as np import torch from datumaro.components.annotation import Polygon @@ -62,16 +64,20 @@ def mask_target_single( meta_info: dict, ) -> Tensor: """Compute mask target for each positive proposal in the image.""" + mask_size = _pair(mask_size) + if len(gt_masks) == 0: + warnings.warn("No ground truth masks are provided!", stacklevel=2) + return pos_proposals.new_zeros((0, *mask_size)) + if isinstance(gt_masks[0], Polygon): crop_and_resize = crop_and_resize_polygons elif isinstance(gt_masks, tv_tensors.Mask): crop_and_resize = crop_and_resize_masks else: - msg = f"Unsupported type of masks: {type(gt_masks[0])}" - raise NotImplementedError(msg) + warnings.warn("Unsupported ground truth mask type!", stacklevel=2) + return pos_proposals.new_zeros((0, *mask_size)) device = pos_proposals.device - mask_size = _pair(mask_size) num_pos = pos_proposals.size(0) if num_pos > 0: proposals_np = pos_proposals.cpu().numpy() diff --git a/src/otx/core/data/pre_filtering.py b/src/otx/core/data/pre_filtering.py index b3898a78f04..2d202f96e92 100644 --- a/src/otx/core/data/pre_filtering.py +++ b/src/otx/core/data/pre_filtering.py @@ -40,13 +40,15 @@ def pre_filtering( dataset = DmDataset.filter(dataset, is_valid_annot, filter_annotations=True) dataset = remove_unused_labels(dataset, data_format, ignore_index) if unannotated_items_ratio > 0: - empty_items = [item.id for item in dataset if item.subset == "train" and len(item.annotations) == 0] + empty_items = [ + item.id for item in dataset if item.subset in ("train", "TRAINING") and len(item.annotations) == 0 + ] used_background_items = set(sample(empty_items, int(len(empty_items) * unannotated_items_ratio))) return DmDataset.filter( dataset, lambda item: not ( - item.subset == "train" and len(item.annotations) == 0 and item.id not in used_background_items + item.subset in ("train", "TRAINING") and len(item.annotations) == 0 and item.id not in used_background_items ), ) diff --git a/tests/unit/algo/common/test_iou2d_calculator.py b/tests/unit/algo/common/test_iou2d_calculator.py new file mode 100644 index 00000000000..1bc6156f593 --- /dev/null +++ b/tests/unit/algo/common/test_iou2d_calculator.py @@ -0,0 +1,123 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import annotations + +import numpy as np +import pytest +import torch +from otx.algo.common.utils.assigners.iou2d_calculator import BboxOverlaps2D +from otx.algo.common.utils.bbox_overlaps import bbox_overlaps + + +def test_bbox_overlaps_2d(eps: float = 1e-7): + def _construct_bbox(num_bbox: int | None = None) -> tuple[torch.Tensor, int]: + img_h = int(np.random.randint(3, 1000)) + img_w = int(np.random.randint(3, 1000)) + if num_bbox is None: + num_bbox = np.random.randint(1, 10) + x1y1 = torch.rand((num_bbox, 2)) + x2y2 = torch.max(torch.rand((num_bbox, 2)), x1y1) + bboxes = torch.cat((x1y1, x2y2), -1) + bboxes[:, 0::2] *= img_w + bboxes[:, 1::2] *= img_h + return bboxes, num_bbox + + # Test where is_aligned is True, bboxes.size(-1) == 5 (include score) + self = BboxOverlaps2D() + bboxes1, num_bbox = _construct_bbox() + bboxes2, _ = _construct_bbox(num_bbox) + bboxes1 = torch.cat((bboxes1, torch.rand((num_bbox, 1))), 1) + bboxes2 = torch.cat((bboxes2, torch.rand((num_bbox, 1))), 1) + gious = self(bboxes1, bboxes2, "giou", True) + assert gious.size() == (num_bbox,), gious.size() + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + + # Test where is_aligned is True, bboxes1.size(-2) == 0 + bboxes1 = torch.empty((0, 4)) + bboxes2 = torch.empty((0, 4)) + gious = self(bboxes1, bboxes2, "giou", True) + assert gious.size() == (0,), gious.size() + assert torch.all(gious == torch.empty((0,))) + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + + # Test where is_aligned is True, and bboxes.ndims > 2 + bboxes1, num_bbox = _construct_bbox() + bboxes2, _ = _construct_bbox(num_bbox) + bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1) + # test assertion when batch dim is not the same + with pytest.raises(ValueError, match="The batch dimension of bboxes must be the same."): + self(bboxes1, bboxes2.unsqueeze(0).repeat(3, 1, 1), "giou", True) + bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1) + gious = self(bboxes1, bboxes2, "giou", True) + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + assert gious.size() == (2, num_bbox) + bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1, 1) + bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1, 1) + gious = self(bboxes1, bboxes2, "giou", True) + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + assert gious.size() == (2, 2, num_bbox) + + # Test where is_aligned is False + bboxes1, num_bbox1 = _construct_bbox() + bboxes2, num_bbox2 = _construct_bbox() + gious = self(bboxes1, bboxes2, "giou") + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + assert gious.size() == (num_bbox1, num_bbox2) + + # Test where is_aligned is False, and bboxes.ndims > 2 + bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1) + bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1) + gious = self(bboxes1, bboxes2, "giou") + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + assert gious.size() == (2, num_bbox1, num_bbox2) + bboxes1 = bboxes1.unsqueeze(0) + bboxes2 = bboxes2.unsqueeze(0) + gious = self(bboxes1, bboxes2, "giou") + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + assert gious.size() == (1, 2, num_bbox1, num_bbox2) + + # Test where is_aligned is False, bboxes1.size(-2) == 0 + gious = self(torch.empty(1, 2, 0, 4), bboxes2, "giou") + assert torch.all(gious == torch.empty(1, 2, 0, bboxes2.size(-2))) + assert torch.all(gious >= -1) + assert torch.all(gious <= 1) + + # test allclose between bbox_overlaps and the original official + # implementation. + bboxes1 = torch.FloatTensor( + [ + [0, 0, 10, 10], + [10, 10, 20, 20], + [32, 32, 38, 42], + ], + ) + bboxes2 = torch.FloatTensor( + [ + [0, 0, 10, 20], + [0, 10, 10, 19], + [10, 10, 20, 20], + ], + ) + gious = bbox_overlaps(bboxes1, bboxes2, "giou", is_aligned=True, eps=eps) + gious = gious.numpy().round(4) + # the gt is got with four decimal precision. + expected_gious = np.array([0.5000, -0.0500, -0.8214]) + assert np.allclose(gious, expected_gious, rtol=0, atol=eps) + + # test mode 'iof' + ious = bbox_overlaps(bboxes1, bboxes2, "iof", is_aligned=True, eps=eps) + assert torch.all(ious >= -1) + assert torch.all(ious <= 1) + assert ious.size() == (bboxes1.size(0),) + ious = bbox_overlaps(bboxes1, bboxes2, "iof", eps=eps) + assert torch.all(ious >= -1) + assert torch.all(ious <= 1) + assert ious.size() == (bboxes1.size(0), bboxes2.size(0))