diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 06c27bfa8c3..fb51922886f 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -2426,9 +2426,8 @@ def _mixup_transform(self, results): keep_list = self._filter_box_candidates(retrieve_gt_bboxes.T, cp_retrieve_gt_bboxes.T) - if keep_list.sum() >= 1.0: - retrieve_gt_labels = retrieve_gt_labels[keep_list] - cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list] + retrieve_gt_labels = retrieve_gt_labels[keep_list] + cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list] mixup_gt_bboxes = np.concatenate( (results['gt_bboxes'], cp_retrieve_gt_bboxes), axis=0) diff --git a/tests/test_data/test_pipelines/test_transform/test_transform.py b/tests/test_data/test_pipelines/test_transform/test_transform.py index ba848dad9a8..d256ef1bb68 100644 --- a/tests/test_data/test_pipelines/test_transform/test_transform.py +++ b/tests/test_data/test_pipelines/test_transform/test_transform.py @@ -967,6 +967,33 @@ def test_mixup(): assert results['gt_bboxes'].dtype == np.float32 assert results['gt_bboxes_ignore'].dtype == np.float32 + # test filter bbox : + # 2 boxes with sides 1 and 3 are filtered as min_bbox_size=5 + gt_bboxes = np.array([[0, 0, 1, 1], [0, 0, 3, 3]], dtype=np.float32) + results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64) + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_ignore'] = np.array([], dtype=np.float32) + mixresults = results['mix_results'][0] + mixresults['gt_labels'] = copy.deepcopy(results['gt_labels']) + mixresults['gt_bboxes'] = copy.deepcopy(results['gt_bboxes']) + mixresults['gt_bboxes_ignore'] = copy.deepcopy(results['gt_bboxes_ignore']) + transform = dict( + type='MixUp', + img_scale=(10, 12), + ratio_range=(1.5, 1.5), + min_bbox_size=5, + skip_filter=False) + mixup_module = build_from_cfg(transform, PIPELINES) + + results = mixup_module(results) + + assert results['gt_bboxes'].shape[0] == 2 + assert results['gt_labels'].shape[0] == 2 + assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0] + assert results['gt_labels'].dtype == np.int64 + assert results['gt_bboxes'].dtype == np.float32 + assert results['gt_bboxes_ignore'].dtype == np.float32 + def test_photo_metric_distortion(): img = mmcv.imread(