diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 91623854330..de49d8a8bef 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -332,6 +332,20 @@ def crop_bounding_box():
         )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def crop_segmentation_mask():
+    for mask, top, left, height, width in itertools.product(
+        make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
+    ):
+        yield SampleInput(
+            mask,
+            top=top,
+            left=left,
+            height=height,
+            width=width,
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -860,3 +874,44 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
     )
 
     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "top, left, height, width",
+    [
+        [4, 6, 30, 40],
+        [-8, 6, 70, 40],
+        [-8, -6, 70, 8],
+    ],
+)
+def test_correctness_crop_segmentation_mask(device, top, left, height, width):
+    def _compute_expected_mask(mask, top_, left_, height_, width_):
+        h, w = mask.shape[-2], mask.shape[-1]
+        if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
+            expected = mask[..., top_ : top_ + height_, left_ : left_ + width_]
+        else:
+            # Create output mask
+            expected_shape = mask.shape[:-2] + (height_, width_)
+            expected = torch.zeros(expected_shape, device=mask.device, dtype=mask.dtype)
+
+            out_y1 = abs(top_) if top_ < 0 else 0
+            out_y2 = h - top_ if top_ + height_ >= h else height_
+            out_x1 = abs(left_) if left_ < 0 else 0
+            out_x2 = w - left_ if left_ + width_ >= w else width_
+
+            in_y1 = 0 if top_ < 0 else top_
+            in_y2 = h if top_ + height_ >= h else top_ + height_
+            in_x1 = 0 if left_ < 0 else left_
+            in_x2 = w if left_ + width_ >= w else left_ + width_
+            # Paste input mask into output
+            expected[..., out_y1:out_y2, out_x1:out_x2] = mask[..., in_y1:in_y2, in_x1:in_x2]
+
+        return expected
+
+    for mask in make_segmentation_masks():
+        if mask.device != torch.device(device):
+            mask = mask.to(device)
+        output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
+        expected_mask = _compute_expected_mask(mask, top, left, height, width)
+        torch.testing.assert_close(output_mask, expected_mask)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index decf9e21020..bbfa9584d88 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -63,6 +63,7 @@
     crop_bounding_box,
     crop_image_tensor,
     crop_image_pil,
+    crop_segmentation_mask,
     perspective_image_tensor,
     perspective_image_pil,
     vertical_flip_image_tensor,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 71be0a22c00..d4f1fadb0bf 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -440,6 +440,10 @@ def crop_bounding_box(
     ).view(shape)
 
 
+def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
+    return crop_image_tensor(img, top, left, height, width)
+
+
 def perspective_image_tensor(
     img: torch.Tensor,
     perspective_coeffs: List[float],
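
For context, a minimal sketch of how the new kernel would be called, assuming the prototype functional namespace from the patch is importable; the mask shape and values below are illustrative assumptions, not taken from the patch:

    import torch
    from torchvision.prototype.transforms import functional as F

    # A (num_objects, H, W) segmentation mask; shape chosen for illustration only.
    mask = torch.randint(0, 2, (3, 32, 32), dtype=torch.uint8)

    # Crop a region extending past the top and bottom edges; as the correctness
    # test above verifies, the out-of-bounds area of the output is zero-filled.
    out = F.crop_segmentation_mask(mask, top=-8, left=6, height=70, width=40)
    assert out.shape == (3, 70, 40)

Since cropping is a purely spatial operation, the kernel simply forwards to crop_image_tensor, so masks inherit the same zero-padding semantics as images when the crop box falls outside the input.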