From 125cf79839f41b2ebcaf766dbf92e0212d72e57d Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 8 Jun 2023 10:26:39 +0900 Subject: [PATCH 1/4] Add custom max iou assigner --- .../adapters/mmdet/models/__init__.py | 4 +- .../mmdet/models/assigners/__init__.py | 8 ++ .../assigners/custom_max_iou_assigner.py | 107 ++++++++++++++++++ .../efficientnetb2b_maskrcnn/model.py | 4 +- .../resnet50_maskrcnn/model.py | 4 +- .../assigners/test_custom_max_iou_assigner.py | 57 ++++++++++ 6 files changed, 178 insertions(+), 6 deletions(-) create mode 100644 otx/algorithms/detection/adapters/mmdet/models/assigners/__init__.py create mode 100644 otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py create mode 100644 tests/unit/algorithms/detection/adapters/mmdet/models/assigners/test_custom_max_iou_assigner.py diff --git a/otx/algorithms/detection/adapters/mmdet/models/__init__.py b/otx/algorithms/detection/adapters/mmdet/models/__init__.py index fc834370596..c73e3d4247e 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/__init__.py +++ b/otx/algorithms/detection/adapters/mmdet/models/__init__.py @@ -3,6 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # -from . import backbones, dense_heads, detectors, heads, losses, necks, roi_heads +from . import assigners, backbones, dense_heads, detectors, heads, losses, necks, roi_heads -__all__ = ["backbones", "dense_heads", "detectors", "heads", "losses", "necks", "roi_heads"] +__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "losses", "necks", "roi_heads"] diff --git a/otx/algorithms/detection/adapters/mmdet/models/assigners/__init__.py b/otx/algorithms/detection/adapters/mmdet/models/assigners/__init__.py new file mode 100644 index 00000000000..71418724251 --- /dev/null +++ b/otx/algorithms/detection/adapters/mmdet/models/assigners/__init__.py @@ -0,0 +1,8 @@ +"""Initial file for mmdetection assigners.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +from .custom_max_iou_assigner import CustomMaxIoUAssigner + +__all__ = ["CustomMaxIoUAssigner"] diff --git a/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py b/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py new file mode 100644 index 00000000000..eea406c7204 --- /dev/null +++ b/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py @@ -0,0 +1,107 @@ +"""Custom assigner for mmdet MaxIouAssigner.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import torch +from mmdet.core.bbox.assigners import MaxIoUAssigner +from mmdet.core.bbox.builder import BBOX_ASSIGNERS + + +@BBOX_ASSIGNERS.register_module() +class CustomMaxIoUAssigner(MaxIoUAssigner): + """Assign a corresponding gt bbox or background to each bbox. + + Each proposals will be assigned with `-1`, or a semi-positive integer + indicating the ground truth index. + + - -1: negative sample, no assigned gt + - semi-positive integer: positive sample, index (0-based) of assigned gt + + This CustomMaxIoUAssigner patches assign funtion of mmdet's MaxIouAssigner + so that it can prevent CPU OOM for images whose gt is extremely large + """ + + cpu_assign_thr = 1000 + + def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): + """Assign gt to bboxes. + + This method assign a gt bbox to every bbox (proposal/anchor), each bbox + will be assigned with -1, or a semi-positive number. -1 means negative + sample, semi-positive number is the index (0-based) of assigned gt. + The assignment is done in following steps, the order matters. + + Especially CustomMaxIoUAssigner split gt_bboxes tensor into small tensors + when gt_bboxes is too large. + + 1. assign every bbox to the background + 2. assign proposals whose iou with all gts < neg_iou_thr to 0 + 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr, + assign it to that bbox + 4. for each gt bbox, assign its nearest proposals (may be more than + one) to itself + + Args: + bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4). + gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + + Example: + >>> self = MaxIoUAssigner(0.5, 0.5) + >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]]) + >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]]) + >>> assign_result = self.assign(bboxes, gt_bboxes) + >>> expected_gt_inds = torch.LongTensor([1, 0]) + >>> assert torch.all(assign_result.gt_inds == expected_gt_inds) + """ + assign_on_cpu = True if (self.gpu_assign_thr > 0) and (gt_bboxes.shape[0] > self.gpu_assign_thr) else False + # compute overlap and assign gt on CPU when number of GT is large + if assign_on_cpu: + device = bboxes.device + bboxes = bboxes.cpu() + gt_bboxes = gt_bboxes.cpu() + if gt_bboxes_ignore is not None: + gt_bboxes_ignore = gt_bboxes_ignore.cpu() + if gt_labels is not None: + gt_labels = gt_labels.cpu() + + if assign_on_cpu and gt_bboxes.shape[0] > self.cpu_assign_thr: + split_length = gt_bboxes.shape[0] // self.cpu_assign_thr + 1 + overlaps = None + for i in range(split_length): + gt_bboxes_split = gt_bboxes[i * self.cpu_assign_thr : (i + 1) * self.cpu_assign_thr] + if overlaps is None: + overlaps = self.iou_calculator(gt_bboxes_split, bboxes) + else: + overlaps = torch.concat((overlaps, self.iou_calculator(gt_bboxes_split, bboxes)), dim=0) + + else: + overlaps = self.iou_calculator(gt_bboxes, bboxes) + + if ( + self.ignore_iof_thr > 0 + and gt_bboxes_ignore is not None + and gt_bboxes_ignore.numel() > 0 + and bboxes.numel() > 0 + ): + if self.ignore_wrt_candidates: + ignore_overlaps = self.iou_calculator(bboxes, gt_bboxes_ignore, mode="iof") + ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) + else: + ignore_overlaps = self.iou_calculator(gt_bboxes_ignore, bboxes, mode="iof") + ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) + overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 + + assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) + if assign_on_cpu: + assign_result.gt_inds = assign_result.gt_inds.to(device) + assign_result.max_overlaps = assign_result.max_overlaps.to(device) + if assign_result.labels is not None: + assign_result.labels = assign_result.labels.to(device) + return assign_result diff --git a/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py b/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py index 52e568dc32e..614b3d91950 100644 --- a/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py +++ b/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py @@ -75,7 +75,7 @@ train_cfg=dict( rpn=dict( assigner=dict( - type="MaxIoUAssigner", + type="CustomMaxIoUAssigner", pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, @@ -97,7 +97,7 @@ ), rcnn=dict( assigner=dict( - type="MaxIoUAssigner", + type="CustomMaxIoUAssigner", pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, diff --git a/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py b/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py index c24df3a76cb..54f54d9c81c 100644 --- a/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py +++ b/otx/algorithms/detection/configs/instance_segmentation/resnet50_maskrcnn/model.py @@ -91,7 +91,7 @@ train_cfg=dict( rpn=dict( assigner=dict( - type="MaxIoUAssigner", + type="CustomMaxIoUAssigner", pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3, @@ -119,7 +119,7 @@ ), rcnn=dict( assigner=dict( - type="MaxIoUAssigner", + type="CustomMaxIoUAssigner", pos_iou_thr=0.5, neg_iou_thr=0.5, min_pos_iou=0.5, diff --git a/tests/unit/algorithms/detection/adapters/mmdet/models/assigners/test_custom_max_iou_assigner.py b/tests/unit/algorithms/detection/adapters/mmdet/models/assigners/test_custom_max_iou_assigner.py new file mode 100644 index 00000000000..d873190b04f --- /dev/null +++ b/tests/unit/algorithms/detection/adapters/mmdet/models/assigners/test_custom_max_iou_assigner.py @@ -0,0 +1,57 @@ +"""Unit test for cusom max iou assigner.""" +# Copyright (C) 2021-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest +import torch + +from otx.algorithms.detection.adapters.mmdet.models.assigners import CustomMaxIoUAssigner +from tests.test_suite.e2e_test_system import e2e_pytest_unit + + +class TestCustomMaxIoUAssigner: + @pytest.fixture(autouse=True) + def setup(self): + """Initial setup for unit tests.""" + self.assigner = CustomMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1, + gpu_assign_thr=300, + ) + self.assigner.cpu_assign_thr = 400 + + @e2e_pytest_unit + def test_assign_gpu(self): + """Test custom assign function on gpu.""" + gt_bboxes = torch.randn(200, 4) + bboxes = torch.randn(20000, 4) + assign_result = self.assigner.assign(bboxes, gt_bboxes) + assert assign_result.gt_inds.shape == torch.Size([20000]) + assert assign_result.max_overlaps.shape == torch.Size([20000]) + + @e2e_pytest_unit + def test_assign_cpu(self): + """Test custom assign function on cpu.""" + gt_bboxes = torch.randn(350, 4) + bboxes = torch.randn(20000, 4) + assign_result = self.assigner.assign(bboxes, gt_bboxes) + assert assign_result.gt_inds.shape == torch.Size([20000]) + assert assign_result.max_overlaps.shape == torch.Size([20000]) + + @e2e_pytest_unit + def test_assign_cpu_oom(self): + """Test custom assign function on cpu in case of cpu oom.""" + gt_bboxes = torch.randn(450, 4) + bboxes = torch.randn(20000, 4) + assign_result = self.assigner.assign(bboxes, gt_bboxes) + assert assign_result.gt_inds.shape == torch.Size([20000]) + assert assign_result.max_overlaps.shape == torch.Size([20000]) + + self.assigner_cpu_assign_thr = 500 + new_assign_result = self.assigner.assign(bboxes, gt_bboxes) + assert torch.all(new_assign_result.gt_inds == assign_result.gt_inds) + assert torch.all(new_assign_result.max_overlaps == assign_result.max_overlaps) From d4b8109104ad9ff6bedbcc4c418f444d810b55ee Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 8 Jun 2023 11:03:51 +0900 Subject: [PATCH 2/4] Modify CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 372971ab72f..e722e1545d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. ### New features - Support encrypted dataset training () +- Add custom max iou assigner () ### Enhancements From 6afbd0347a82a7c9bca5546547dcf5c9a0532102 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Thu, 8 Jun 2023 13:08:13 +0900 Subject: [PATCH 3/4] Update CHANGELOG.md Co-authored-by: Sungman Cho --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e722e1545d7..2d18cc57422 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ All notable changes to this project will be documented in this file. ### New features - Support encrypted dataset training () -- Add custom max iou assigner () +- Add custom max iou assigner to prevent CPU OOM when large annotations are used () ### Enhancements From be39736ba062446e995c0fd5f385b08fb1c5bd88 Mon Sep 17 00:00:00 2001 From: Jaeguk Hyun Date: Thu, 8 Jun 2023 13:21:29 +0900 Subject: [PATCH 4/4] Update otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py Co-authored-by: Songki Choi --- .../mmdet/models/assigners/custom_max_iou_assigner.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py b/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py index eea406c7204..31e4a1674e2 100644 --- a/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py +++ b/otx/algorithms/detection/adapters/mmdet/models/assigners/custom_max_iou_assigner.py @@ -73,14 +73,11 @@ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): if assign_on_cpu and gt_bboxes.shape[0] > self.cpu_assign_thr: split_length = gt_bboxes.shape[0] // self.cpu_assign_thr + 1 - overlaps = None + overlaps = [] for i in range(split_length): gt_bboxes_split = gt_bboxes[i * self.cpu_assign_thr : (i + 1) * self.cpu_assign_thr] - if overlaps is None: - overlaps = self.iou_calculator(gt_bboxes_split, bboxes) - else: - overlaps = torch.concat((overlaps, self.iou_calculator(gt_bboxes_split, bboxes)), dim=0) - + overlaps.append(self.iou_calculator(gt_bboxes_split, bboxes)) + overlaps = torch.concat(overlaps, dim=0) else: overlaps = self.iou_calculator(gt_bboxes, bboxes)