Skip to content

Commit

Permalink
test: add all padding options
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico Pozzi committed Apr 24, 2022
1 parent d01e74f commit 8b42851
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def pad_segmentation_mask():
for mask, padding, padding_mode in itertools.product(
make_segmentation_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge"], # padding mode,
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
continue
Expand Down Expand Up @@ -969,15 +969,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")])
def test_correctness_pad_segmentation_mask(padding, padding_mode):
def compute_expected_mask():
h, w = mask.shape[-2], mask.shape[-1]
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]])
def test_correctness_pad_segmentation_mask(padding):
def _parse_padding():
if isinstance(padding, int):
return [padding] * 4
if isinstance(padding, float):
return [int(padding)] * 4
if isinstance(padding, list):
if len(padding) == 1:
return padding * 4
if len(padding) == 2:
return padding * 2 # [left, up, right, down]

return padding

pad_left = padding[0]
pad_up = padding[1]
pad_right = padding[2]
pad_down = padding[3]
def _compute_expected_mask(padding):
h, w = mask.shape[-2], mask.shape[-1]
pad_left, pad_up, pad_right, pad_down = padding

new_h = h + pad_up + pad_down
new_w = w + pad_left + pad_right
Expand All @@ -988,8 +997,10 @@ def compute_expected_mask():

return expected_mask

padding = _parse_padding()

for mask in make_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode)
out_mask = F.pad_segmentation_mask(mask, padding, "constant")

expected_mask = compute_expected_mask()
expected_mask = _compute_expected_mask(padding)
torch.testing.assert_close(out_mask, expected_mask)

0 comments on commit 8b42851

Please sign in to comment.