From 08e13b2fe7c08c4c19293e5ea574e7b5354f6b67 Mon Sep 17 00:00:00 2001
From: Federico Pozzi
Date: Fri, 22 Apr 2022 22:54:12 +0200
Subject: [PATCH] feat: add functional pad on segmentation mask

---
 test/test_prototype_transforms_functional.py | 38 +++++++++++++++++++
 .../transforms/functional/__init__.py         |  1 +
 .../transforms/functional/_geometry.py        |  6 +++
 3 files changed, 45 insertions(+)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 6c99720114a..139761804e2 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -352,6 +352,32 @@ def vertical_flip_segmentation_mask():
     yield SampleInput(mask)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def pad_segmentation_mask():
+    for mask, padding, fill, padding_mode in itertools.product(
+        make_segmentation_masks(),
+        [[1], [1, 1], [1, 1, 2, 2]],  # padding
+        [0, 1],  # fill
+        ["constant", "symmetric", "edge"],  # padding mode,
+    ):
+        if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
+            continue
+        if padding_mode == "edge" and fill != 0:
+            continue
+        if (
+            padding_mode == "edge"
+            and len(padding) == 2
+            and mask.ndim not in [2, 3]
+            or len(padding) == 4
+            and mask.ndim not in [4, 3]
+            or len(padding) == 1
+        ):
+            continue
+        if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
+            continue
+        yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -933,3 +959,15 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, -1, :] = 1
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+def test_correctness_pad_segmentation_mask_on_fixed_input(device):
+    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
+    mask[:, 1, 1] = 0
+
+    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)
+
+    expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
+    expected_mask[:, 2, 2] = 0
+    torch.testing.assert_close(out_mask, expected_mask)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index bbfa9584d88..d990e346202 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -60,6 +60,7 @@
     pad_bounding_box,
     pad_image_tensor,
     pad_image_pil,
+    pad_segmentation_mask,
     crop_bounding_box,
     crop_image_tensor,
     crop_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 65673203941..3cf79e86e4a 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -396,6 +396,12 @@ def rotate_segmentation_mask(
 pad_image_pil = _FP.pad
 
 
+def pad_segmentation_mask(
+    segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
+) -> torch.Tensor:
+    return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)
+
+
 def pad_bounding_box(
     bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
 ) -> torch.Tensor:
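
A minimal usage sketch of the new kernel, assuming the prototype functional namespace is imported as F (as the tests above do) and that the helper forwards to pad_image_tensor exactly as shown in the _geometry.py hunk:

    import torch
    from torchvision.prototype.transforms import functional as F  # assumed import path

    # A 1x3x3 integer mask with a single foreground pixel in the center.
    mask = torch.zeros((1, 3, 3), dtype=torch.long)
    mask[:, 1, 1] = 1

    # Pad one pixel on every side; the border is filled with `fill` (0 here),
    # so the result is 1x5x5 and the original center pixel moves to [2, 2].
    padded = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=0, padding_mode="constant")

    assert padded.shape == (1, 5, 5)
    assert padded[0, 2, 2] == 1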