Skip to content

Commit

Permalink
test: add functional vertical flip tests on segmentation mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico Pozzi committed Apr 21, 2022
1 parent a64c674 commit 3577009
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ def crop_bounding_box():
)


@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
yield SampleInput(mask)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -860,3 +866,26 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
)

torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
mask = torch.tensor(
[
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
],
device=device,
)

expected_mask = torch.tensor(
[
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
],
device=device,
)
out_mask = F.vertical_flip_segmentation_mask(mask)
torch.testing.assert_close(out_mask, expected_mask)

0 comments on commit 3577009

Please sign in to comment.