diff --git a/CHANGELOG.md b/CHANGELOG.md index b13f6dd5439..ac4c58da206 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978)) ### Changed diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 5542372d436..9a9aed7ba71 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -24,6 +24,7 @@ from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.imports import ( _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, @@ -238,6 +239,8 @@ class MeanAveragePrecision(Metric): groundtruth_crowds: List[Tensor] groundtruth_area: List[Tensor] + warn_on_many_detections: bool = True + def __init__( self, box_format: Literal["xyxy", "xywh", "cxcywh"] = "xyxy", @@ -327,7 +330,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] _input_validator(preds, target, iou_type=self.iou_type) for item in preds: - detections = self._get_safe_item_values(item) + detections = self._get_safe_item_values(item, warn=self.warn_on_many_detections) self.detections.append(detections) self.detection_labels.append(item["labels"]) @@ -540,11 +543,12 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None: with open(f"{name}_target.json", "w") as f: f.write(target_json) - def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: + def _get_safe_item_values(self, item: Dict[str, Any], warn: bool = False) -> Union[Tensor, Tuple]: """Convert and return the boxes or masks from the item depending on the iou_type. 
Args: item: input dictionary containing the boxes or masks + warn: whether to warn if the number of boxes or masks exceeds the max_detection_thresholds Returns: boxes or masks depending on the iou_type @@ -554,12 +558,16 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]: boxes = _fix_empty_tensors(item["boxes"]) if boxes.numel() > 0: boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xywh") + if warn and len(boxes) > self.max_detection_thresholds[-1]: + _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return boxes if self.iou_type == "segm": masks = [] for i in item["masks"].cpu().numpy(): rle = mask_utils.encode(np.asfortranarray(i)) masks.append((tuple(rle["size"]), rle["counts"])) + if warn and len(masks) > self.max_detection_thresholds[-1]: + _warning_on_too_many_detections(self.max_detection_thresholds[-1]) return tuple(masks) raise Exception(f"IOU type {self.iou_type} is not supported") @@ -741,3 +749,13 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] dist.all_gather_object(list_gathered, list_to_gather, group=process_group) return [list_gathered[rank][idx] for idx in range(len(list_gathered[0])) for rank in range(world_size)] + + +def _warning_on_too_many_detections(limit: int) -> None: + rank_zero_warn( + f"Encountered more than {limit} detections in a single image. This means that certain detections with the" + " lowest scores will be ignored, which may have an undesirable impact on performance. Please consider adjusting" + " the `max_detection_thresholds` to suit your use case. 
To disable this warning, set the class attribute" + " `warn_on_many_detections=False` after initializing the metric.", + UserWarning, + ) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index f1c002428a0..227b914d390 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -622,19 +622,19 @@ def test_error_on_wrong_input(): ) -def _generate_random_segm_input(device): +def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True): """Generate random inputs for mAP when iou_type=segm.""" preds = [] targets = [] - for _ in range(2): + for _ in range(batch_size): result = {} - num_preds = torch.randint(0, 10, (1,)).item() + num_preds = torch.randint(0, num_preds_size, (1,)).item() if random_size else num_preds_size result["scores"] = torch.rand((num_preds,), device=device) result["labels"] = torch.randint(0, 10, (num_preds,), device=device) result["masks"] = torch.randint(0, 2, (num_preds, 10, 10), device=device).bool() preds.append(result) gt = {} - num_gt = torch.randint(0, 10, (1,)).item() + num_gt = torch.randint(0, num_gt_size, (1,)).item() if random_size else num_gt_size gt["labels"] = torch.randint(0, 10, (num_gt,), device=device) gt["masks"] = torch.randint(0, 2, (num_gt, 10, 10), device=device).bool() targets.append(gt) @@ -683,3 +683,23 @@ def test_for_box_format(box_format, iou_val_expected, map_val_expected): result = metric.compute() assert result["map"].item() == map_val_expected assert round(float(metric.coco_eval.ious[(0, 0)]), 3) == iou_val_expected + + +@pytest.mark.parametrize("iou_type", ["bbox", "segm"]) +def test_warning_on_many_detections(iou_type): + """Test that a warning is raised when there are many detections.""" + if iou_type == "bbox": + preds = [ + { + "boxes": torch.tensor([[0.5, 0.5, 1, 1]]).repeat(101, 1), + "scores": torch.tensor([1.0]).repeat(101), + "labels": torch.tensor([0]).repeat(101), + } + ] + targets = 
[{"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}] + else: + preds, targets = _generate_random_segm_input("cpu", 1, 101, 10, False) + + metric = MeanAveragePrecision(iou_type=iou_type) + with pytest.warns(UserWarning, match="Encountered more than 100 detections in a single image.*"): + metric.update(preds, targets)