From 3577009ef97d554c8551817bf051595eed247277 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Thu, 21 Apr 2022 21:38:50 +0200 Subject: [PATCH 1/3] test: add functional vertical flip tests on segmentation mask --- test/test_prototype_transforms_functional.py | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 91623854330..777ae065f59 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -332,6 +332,12 @@ def crop_bounding_box(): ) +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_segmentation_mask(): + for mask in make_segmentation_masks(extra_dims=((), (4,))): + yield SampleInput(mask) + + @pytest.mark.parametrize( "kernel", [ @@ -860,3 +866,26 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte ) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): + mask = torch.tensor( + [ + [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + ], + device=device, + ) + + expected_mask = torch.tensor( + [ + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], + ], + device=device, + ) + out_mask = F.vertical_flip_segmentation_mask(mask) + torch.testing.assert_close(out_mask, expected_mask) From abcffc29dd4396300a76b6ffe818a9ca238cd8a1 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Fri, 22 Apr 2022 15:04:37 +0200 Subject: [PATCH 2/3] refactor: improve test readibility --- test/test_prototype_transforms_functional.py | 23 +++++--------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 777ae065f59..9f1e2c16676 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -334,7 +334,7 @@ def crop_bounding_box(): @register_kernel_info_from_sample_inputs_fn def vertical_flip_segmentation_mask(): - for mask in make_segmentation_masks(extra_dims=((), (4,))): + for mask in make_segmentation_masks(): yield SampleInput(mask) @@ -870,22 +870,11 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte @pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): - mask = torch.tensor( - [ - [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], - [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], - ], - device=device, - ) + mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + mask[:, 0, :] = 1 - expected_mask = torch.tensor( - [ - [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], - [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], - ], - device=device, - ) out_mask = F.vertical_flip_segmentation_mask(mask) + + expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + expected_mask[:, -1, :] = 1 torch.testing.assert_close(out_mask, expected_mask) From 30ad0759bd3a981168bd1fec16cab9b73b3f09d2 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 22 Apr 2022 16:31:30 +0200 Subject: [PATCH 3/3] Update test_prototype_transforms_functional.py --- test/test_prototype_transforms_functional.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 09cb334af90..6c99720114a 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -345,7 +345,7 @@ def crop_segmentation_mask(): width=width, ) - + @register_kernel_info_from_sample_inputs_fn def vertical_flip_segmentation_mask(): for mask in make_segmentation_masks(): @@ -933,4 +933,3 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) expected_mask[:, -1, :] = 1 torch.testing.assert_close(out_mask, expected_mask) -