diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 139761804e2..bcd589328c5 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -354,16 +354,13 @@ def vertical_flip_segmentation_mask():
 
 @register_kernel_info_from_sample_inputs_fn
 def pad_segmentation_mask():
-    for mask, padding, fill, padding_mode in itertools.product(
+    for mask, padding, padding_mode in itertools.product(
         make_segmentation_masks(),
         [[1], [1, 1], [1, 1, 2, 2]],  # padding
-        [0, 1],  # fill
         ["constant", "symmetric", "edge"],  # padding mode,
     ):
         if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
             continue
-        if padding_mode == "edge" and fill != 0:
-            continue
         if (
             padding_mode == "edge"
             and len(padding) == 2
@@ -375,7 +372,7 @@ def pad_segmentation_mask():
             continue
         if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
             continue
-        yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)
+        yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
 
 
 @pytest.mark.parametrize(
@@ -964,10 +961,35 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
-    mask[:, 1, 1] = 0
 
-    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)
+    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
 
-    expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
-    expected_mask[:, 2, 2] = 0
+    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
+    expected_mask[:, 1:-1, 1:-1] = 1
     torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
+def test_correctness_pad_segmentation_mask(padding, padding_mode):
+    def compute_expected_mask():
+        h, w = mask.shape[-2], mask.shape[-1]
+
+        pad_left = padding[0]
+        pad_up = padding[1]
+        pad_right = padding[2]
+        pad_down = padding[3]
+
+        new_h = h + pad_up + pad_down
+        new_w = w + pad_left + pad_right
+
+        new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
+        expected_mask = torch.zeros(new_shape, dtype=torch.long)
+        expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask
+
+        return expected_mask
+
+    for mask in make_segmentation_masks():
+        out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)
+
+        expected_mask = compute_expected_mask()
+        torch.testing.assert_close(out_mask, expected_mask)
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 3cf79e86e4a..22666faea63 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -397,9 +397,9 @@ def rotate_segmentation_mask(
 
 
 def pad_segmentation_mask(
-    segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
+    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
 ) -> torch.Tensor:
-    return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)
+    return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
 
 
 def pad_bounding_box(