Skip to content

Commit

Permalink
test: add basic correctness test with random masks
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico Pozzi committed Apr 24, 2022
1 parent 08e13b2 commit d01e74f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
40 changes: 31 additions & 9 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,13 @@ def vertical_flip_segmentation_mask():

@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
for mask, padding, fill, padding_mode in itertools.product(
for mask, padding, padding_mode in itertools.product(
make_segmentation_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[0, 1], # fill
["constant", "symmetric", "edge"], # padding mode,
):
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
continue
if padding_mode == "edge" and fill != 0:
continue
if (
padding_mode == "edge"
and len(padding) == 2
Expand All @@ -375,7 +372,7 @@ def pad_segmentation_mask():
continue
if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]:
continue
yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode)
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -964,10 +961,35 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
mask[:, 1, 1] = 0

out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1)
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])

expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device)
expected_mask[:, 2, 2] = 0
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
expected_mask[:, 1:-1, 1:-1] = 1
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
def test_correctness_pad_segmentation_mask(padding, padding_mode):
def compute_expected_mask():
h, w = mask.shape[-2], mask.shape[-1]

pad_left = padding[0]
pad_up = padding[1]
pad_right = padding[2]
pad_down = padding[3]

new_h = h + pad_up + pad_down
new_w = w + pad_left + pad_right

new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
expected_mask = torch.zeros(new_shape, dtype=torch.long)
expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask

return expected_mask

for mask in make_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)

expected_mask = compute_expected_mask()
torch.testing.assert_close(out_mask, expected_mask)
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def rotate_segmentation_mask(


def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode)
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)


def pad_bounding_box(
Expand Down

0 comments on commit d01e74f

Please sign in to comment.