Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico Pozzi committed Apr 26, 2022
1 parent b425177 commit 33ec8c1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ def pad_segmentation_mask():
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]:
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
continue

if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]:
if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [3, 4]:
continue

yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
Expand Down Expand Up @@ -1049,6 +1049,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
torch.testing.assert_close(output_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

Expand Down

0 comments on commit 33ec8c1

Please sign in to comment.