Skip to content

Commit

Permalink
feat: add functional pad on segmentation mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico Pozzi committed Apr 24, 2022
1 parent e99278a commit 08e13b2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,32 @@ def vertical_flip_segmentation_mask():
yield SampleInput(mask)


@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
for mask, padding, fill, padding_mode in itertools.product(
make_segmentation_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[0, 1], # fill
["constant", "symmetric", "edge"], # padding mode,
):
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
continue
if padding_mode == "edge" and fill != 0:
continue
if (
padding_mode == "edge"
and len(padding) == 2
and mask.ndim not in [2, 3]
or len(padding) == 4
and mask.ndim not in [4, 3]
or len(padding) == 1
):
continue
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
continue
yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -933,3 +959,15 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
expected_mask[:, -1, :] = 1
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
mask[:, 1, 1] = 0

out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)

expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
expected_mask[:, 2, 2] = 0
torch.testing.assert_close(out_mask, expected_mask)
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_segmentation_mask,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
Expand Down
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ def rotate_segmentation_mask(
pad_image_pil = _FP.pad


def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
) -> torch.Tensor:
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)


def pad_bounding_box(
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
) -> torch.Tensor:
Expand Down

0 comments on commit 08e13b2

Please sign in to comment.