diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 363642175a7..8839d842b85 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from common_utils import assert_equal
+from common_utils import assert_equal, cpu_and_gpu
 from test_prototype_transforms_functional import (
     make_bounding_box,
     make_bounding_boxes,
@@ -15,6 +15,7 @@
     make_one_hot_labels,
     make_segmentation_mask,
 )
+from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
@@ -1127,6 +1128,124 @@ def test_ctor(self, trfms):
         assert isinstance(output, torch.Tensor)
 
 
+class TestRandomIoUCrop:
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
+    def test__get_params(self, device, options, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        bboxes = features.BoundingBox(
+            torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
+            format="XYXY",
+            image_size=image.image_size,
+            device=device,
+        )
+        sample = [image, bboxes]
+
+        transform = transforms.RandomIoUCrop(sampler_options=options)
+
+        n_samples = 5
+        for _ in range(n_samples):
+
+            params = transform._get_params(sample)
+
+            if options == [2.0]:
+                assert len(params) == 0
+                return
+
+            assert len(params["is_within_crop_area"]) > 0
+            assert params["is_within_crop_area"].dtype == torch.bool
+
+            orig_h = image.image_size[0]
+            orig_w = image.image_size[1]
+            assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
+            assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
+
+            left, top = params["left"], params["top"]
+            new_h, new_w = params["height"], params["width"]
+            ious = box_iou(
+                bboxes,
+                torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device),
+            )
+            assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}"
+
+    def test__transform_empty_params(self, mocker):
+        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
+        image = features.Image(torch.rand(1, 3, 4, 4))
+        bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4))
+        label = features.Label(torch.tensor([1]))
+        sample = [image, bboxes, label]
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock(return_value={})
+        output = transform(sample)
+        torch.testing.assert_close(output, sample)
+
+    def test_forward_assertion(self):
+        transform = transforms.RandomIoUCrop()
+        with pytest.raises(
+            TypeError,
+            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
+        ):
+            transform(torch.tensor(0))
+
+    def test__transform(self, mocker):
+        transform = transforms.RandomIoUCrop()
+
+        image = features.Image(torch.rand(3, 32, 24))
+        bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
+        label = features.Label(torch.randint(0, 10, size=(6,)))
+        ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
+        masks = make_segmentation_mask((32, 24))
+        ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24)))
+        sample = [image, bboxes, label, ohe_label, masks, ohe_masks]
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
+        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
+
+        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
+        transform._get_params = mocker.MagicMock(return_value=params)
+        output = transform(sample)
+
+        assert fn.call_count == 4
+
+        expected_calls = [
+            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
+            mocker.call(
+                ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            ),
+        ]
+
+        fn.assert_has_calls(expected_calls)
+
+        expected_within_targets = sum(is_within_crop_area)
+
+        # check number of bboxes vs number of labels:
+        output_bboxes = output[1]
+        assert isinstance(output_bboxes, features.BoundingBox)
+        assert len(output_bboxes) == expected_within_targets
+
+        # check labels
+        output_label = output[2]
+        assert isinstance(output_label, features.Label)
+        assert len(output_label) == expected_within_targets
+        torch.testing.assert_close(output_label, label[is_within_crop_area])
+
+        output_ohe_label = output[3]
+        assert isinstance(output_ohe_label, features.OneHotLabel)
+        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
+
+        output_masks = output[4]
+        assert isinstance(output_masks, features.SegmentationMask)
+        assert output_masks.shape[:-2] == masks.shape[:-2]
+
+        output_ohe_masks = output[5]
+        assert isinstance(output_ohe_masks, features.SegmentationMask)
+        assert len(output_ohe_masks) == expected_within_targets
+
+
 class TestScaleJitter:
     def test__get_params(self, mocker):
         image_size = (24, 32)
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index ca89fee918a..e1ba20904fe 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -24,6 +24,7 @@
     RandomAffine,
     RandomCrop,
     RandomHorizontalFlip,
+    RandomIoUCrop,
     RandomPerspective,
     RandomResizedCrop,
     RandomRotation,
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 303f4502b04..e0215caaf87 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -5,15 +5,17 @@
 
 import PIL.Image
 import torch
+from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
 from torchvision.transforms.functional_tensor import _parse_pad_padding
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size
+
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -620,6 +622,116 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )
 
 
+class RandomIoUCrop(Transform):
+    def __init__(
+        self,
+        min_scale: float = 0.3,
+        max_scale: float = 1.0,
+        min_aspect_ratio: float = 0.5,
+        max_aspect_ratio: float = 2.0,
+        sampler_options: Optional[List[float]] = None,
+        trials: int = 40,
+    ):
+        super().__init__()
+        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.min_aspect_ratio = min_aspect_ratio
+        self.max_aspect_ratio = max_aspect_ratio
+        if sampler_options is None:
+            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
+        self.options = sampler_options
+        self.trials = trials
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+
+        image = query_image(sample)
+        _, orig_h, orig_w = get_image_dimensions(image)
+        bboxes = query_bounding_box(sample)
+
+        while True:
+            # sample an option
+            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
+            min_jaccard_overlap = self.options[idx]
+            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
+                return dict()
+
+            for _ in range(self.trials):
+                # check the aspect ratio limitations
+                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
+                new_w = int(orig_w * r[0])
+                new_h = int(orig_h * r[1])
+                aspect_ratio = new_w / new_h
+                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
+                    continue
+
+                # check for 0 area crops
+                r = torch.rand(2)
+                left = int((orig_w - new_w) * r[0])
+                top = int((orig_h - new_h) * r[1])
+                right = left + new_w
+                bottom = top + new_h
+                if left == right or top == bottom:
+                    continue
+
+                # check for any valid boxes with centers within the crop area
+                xyxy_bboxes = F.convert_bounding_box_format(
+                    bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
+                )
+                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
+                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
+                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
+                if not is_within_crop_area.any():
+                    continue
+
+                # check at least 1 box with jaccard limitations
+                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
+                ious = box_iou(
+                    xyxy_bboxes,
+                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
+                )
+                if ious.max() < min_jaccard_overlap:
+                    continue
+
+                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        if len(params) < 1:
+            return inpt
+
+        is_within_crop_area = params["is_within_crop_area"]
+
+        if isinstance(inpt, (features.Label, features.OneHotLabel)):
+            return inpt.new_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]
+
+        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+
+        if isinstance(output, features.BoundingBox):
+            bboxes = output[is_within_crop_area]
+            bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
+            output = features.BoundingBox.new_like(output, bboxes)
+        elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
+            # apply is_within_crop_area if mask is one-hot encoded
+            masks = output[is_within_crop_area]
+            output = features.SegmentationMask.new_like(output, masks)
+
+        return output
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        # TODO: Allow image to be a torch.Tensor
+        if not (
+            has_all(sample, features.BoundingBox)
+            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, features.Label, features.OneHotLabel)
+        ):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+            )
+        return super().forward(sample)
+
+
 class ScaleJitter(Transform):
     def __init__(
         self,
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index 9f2ef84ced5..4cfe1da3649 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
     raise TypeError("No image was found in the sample")
 
 
+def query_bounding_box(sample: Any) -> features.BoundingBox:
+    flat_sample, _ = tree_flatten(sample)
+    for i in flat_sample:
+        if isinstance(i, features.BoundingBox):
+            return i
+
+    raise TypeError("No bounding box was found in the sample")
+
+
 def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index fee0c4dd1e3..5883cc9119a 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -1,5 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 from ._meta import (
+    clamp_bounding_box,
     convert_bounding_box_format,
     convert_color_space_image_tensor,
     convert_color_space_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py
index f1aea2018bc..168a6dfe1b4 100644
--- a/torchvision/prototype/transforms/functional/_meta.py
+++ b/torchvision/prototype/transforms/functional/_meta.py
@@ -61,6 +61,15 @@ def convert_bounding_box_format(
     return bounding_box
 
 
+def clamp_bounding_box(
+    bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY)
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
+    return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+
+
 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return image[..., :-1, :, :], image[..., -1:, :, :]
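
Usage sketch (not part of the patch): a minimal example of how the RandomIoUCrop transform added above could be applied to a detection sample, assuming the prototype features/transforms API shown in this diff. The image size, box coordinates, and label values below are illustrative only.

import torch

from torchvision.prototype import features, transforms

# Build a detection sample: an image, its bounding boxes, and the matching labels.
image = features.Image(torch.rand(3, 256, 256))
bboxes = features.BoundingBox(
    torch.tensor([[10, 10, 60, 60], [100, 120, 180, 200]]),
    format="XYXY",
    image_size=(256, 256),
)
labels = features.Label(torch.tensor([1, 2]))

# RandomIoUCrop samples a crop whose IoU with at least one surviving box meets a
# randomly chosen threshold (or leaves the sample unchanged). Boxes whose centers
# fall outside the crop are dropped, and the labels are filtered to match.
transform = transforms.RandomIoUCrop()
new_image, new_bboxes, new_labels = transform([image, bboxes, labels])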