diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index bcd589328c5..4661f18ffa7 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -357,7 +357,7 @@ def pad_segmentation_mask(): for mask, padding, padding_mode in itertools.product( make_segmentation_masks(), [[1], [1, 1], [1, 1, 2, 2]], # padding - ["constant", "symmetric", "edge"], # padding mode, + ["constant", "symmetric", "edge", "reflect"], # padding mode, ): if padding_mode == "symmetric" and mask.ndim not in [3, 4]: continue @@ -969,15 +969,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")]) -def test_correctness_pad_segmentation_mask(padding, padding_mode): - def compute_expected_mask(): - h, w = mask.shape[-2], mask.shape[-1] +@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]]) +def test_correctness_pad_segmentation_mask(padding): + def _parse_padding(): + if isinstance(padding, int): + return [padding] * 4 + if isinstance(padding, float): + return [int(padding)] * 4 + if isinstance(padding, list): + if len(padding) == 1: + return padding * 4 + if len(padding) == 2: + return padding * 2 # [left, up, right, down] + + return padding - pad_left = padding[0] - pad_up = padding[1] - pad_right = padding[2] - pad_down = padding[3] + def _compute_expected_mask(padding): + h, w = mask.shape[-2], mask.shape[-1] + pad_left, pad_up, pad_right, pad_down = padding new_h = h + pad_up + pad_down new_w = w + pad_left + pad_right @@ -988,8 +997,10 @@ def compute_expected_mask(): return expected_mask + padding = _parse_padding() + for mask in make_segmentation_masks(): - out_mask = F.pad_segmentation_mask(mask, padding, padding_mode) + out_mask = F.pad_segmentation_mask(mask, padding, "constant") - expected_mask = compute_expected_mask() + expected_mask = _compute_expected_mask(padding) torch.testing.assert_close(out_mask, expected_mask)