From fcfd1b28e6b74385474f5b619a68dd83925aa26c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 26 Sep 2023 04:55:23 -0700 Subject: [PATCH] [fbsync] port AA tests (#7927) Summary: Co-authored-by: Nicolas Hug Reviewed By: matteobettini Differential Revision: D49600791 fbshipit-source-id: abf058e28a949717be7ad343e5417fca098d4078 --- test/test_transforms_v2_consistency.py | 275 ------------------------- test/test_transforms_v2_refactored.py | 117 ++++++++++- 2 files changed, 114 insertions(+), 278 deletions(-) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 1f47eb2117f..9badd8dbe1b 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -705,281 +705,6 @@ def test_to_tensor(self): assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy)) -class TestAATransforms: - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_randaug(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) - t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - # Stable API, if signed there is another random call - if t._AUGMENTATION_SPACE[keys[i]][1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for i in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_close(expected_output, output, atol=1, rtol=0.1) - - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_randaug_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill) - t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_trivial_aug(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) - t = 
v2_transforms.TrivialAugmentWide(interpolation=interpolation) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - for _ in range(le): - expected_output = t_ref(inpt) - output = t(inpt) - - assert_close(expected_output, output, atol=1, rtol=0.1) - - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_trivial_aug_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill) - t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_augmix(self, inpt, interpolation, mocker): - t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1) - t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) - t._sample_dirichlet = lambda t: t.softmax(dim=-1) - - le = len(t._AUGMENTATION_SPACE) - keys = list(t._AUGMENTATION_SPACE.keys()) - randint_values = [] - for i in range(le): - # Stable API, op_index random call - randint_values.append(i) - key = keys[i] - # Stable API, random magnitude - aug_op = t._AUGMENTATION_SPACE[key] - magnitudes = aug_op[0](2, 0, 0) - if magnitudes is not None: - randint_values.append(5) - # Stable API, if signed there is another random call - if aug_op[1]: - randint_values.append(0) - # New API, _get_random_item - randint_values.append(i) - # New API, random magnitude - if magnitudes is not None: - randint_values.append(5) - - randint_values = iter(randint_values) - - mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values))) - mocker.patch("torch.rand", return_value=1.0) - - expected_output = t_ref(inpt) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - 
v2_transforms.InterpolationMode.BILINEAR, - ], - ) - @pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1]) - def test_augmix_jit(self, interpolation, fill): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - - t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill) - t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - - @pytest.mark.parametrize( - "inpt", - [ - torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8), - PIL.Image.new("RGB", (256, 256), 123), - tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)), - ], - ) - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - PIL.Image.NEAREST, - ], - ) - def test_aa(self, inpt, interpolation): - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") - t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - - torch.manual_seed(12) - expected_output = t_ref(inpt) - - torch.manual_seed(12) - output = t(inpt) - - assert_equal(expected_output, output) - - @pytest.mark.parametrize( - "interpolation", - [ - v2_transforms.InterpolationMode.NEAREST, - v2_transforms.InterpolationMode.BILINEAR, - ], - ) - def test_aa_jit(self, interpolation): - inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8) - aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") - t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) - t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation) - - tt_ref = torch.jit.script(t_ref) - tt = torch.jit.script(t) - - torch.manual_seed(12) - expected_output = tt_ref(inpt) - - torch.manual_seed(12) - scripted_output = tt(inpt) - - assert_equal(scripted_output, expected_output) - - def import_transforms_from_references(reference): HERE = Path(__file__).parent PROJECT_ROOT = HERE.parent diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 052cc4291ad..e978f57f257 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -232,7 +232,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version can be called without error.""" - if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): + if not (type(input) is torch.Tensor or isinstance(input, PIL.Image.Image)): return v1_transform_cls = transform._v1_transform_cls @@ -250,7 +250,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol): with freeze_rng_state(): output_v1 = v1_transform(input) - assert_close(output_v2, output_v1, rtol=rtol, atol=atol) + assert_close(F.to_image(output_v2), F.to_image(output_v1), rtol=rtol, atol=atol) if isinstance(input, PIL.Image.Image): return @@ -2772,7 +2772,10 @@ def test_functional_signature(self, kernel, input_type): ) 
@pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): - check_transform(transforms.RandomErasing(p=1), make_input(device=device)) + input = make_input(device=device) + check_transform( + transforms.RandomErasing(p=1), input, check_v1_compatibility=not isinstance(input, PIL.Image.Image) + ) def _reference_erase_image(self, image, *, i, j, h, w, v): mask = torch.zeros_like(image, dtype=torch.bool) @@ -2898,3 +2901,111 @@ def test__get_params(self, sigma): else: assert sigma[0] <= params["sigma"][0] <= sigma[1] assert sigma[0] <= params["sigma"][1] <= sigma[1] + + +class TestAutoAugmentTransforms: + # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling. + # It's typically very hard to test the effect on some parameters without heavy mocking logic. + # This class adds correctness tests for the kernels that are specific to those transforms. The rest of kernels, e.g. + # rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests. + + def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill): + if isinstance(image, PIL.Image.Image): + input = image + else: + input = F.to_pil_image(image) + + matrix = { + "ShearX": (1, magnitude, 0, 0, 1, 0), + "ShearY": (1, 0, 0, magnitude, 1, 0), + "TranslateX": (1, 0, -int(magnitude), 0, 1, 0), + "TranslateY": (1, 0, 0, 0, 1, -int(magnitude)), + }[transform_id] + + output = input.transform( + input.size, PIL.Image.AFFINE, matrix, resample=pil_modes_mapping[interpolation], fill=fill + ) + + if isinstance(image, PIL.Image.Image): + return output + else: + return F.to_image(output) + + @pytest.mark.parametrize("transform_id", ["ShearX", "ShearY", "TranslateX", "TranslateY"]) + @pytest.mark.parametrize("magnitude", [0.3, -0.2, 0.0]) + @pytest.mark.parametrize( + "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] + ) + @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) + @pytest.mark.parametrize("input_type", ["Tensor", "PIL"]) + def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type): + # ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. They are modeled after the + # reference implementation: + # https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L362 + # All other ops are checked in their respective dedicated tests. 
+
+        image = make_image(dtype=torch.uint8, device="cpu")
+        if input_type == "PIL":
+            image = F.to_pil_image(image)
+
+        if "Translate" in transform_id:
+            # For TranslateX/Y, the magnitude is a value in pixels
+            magnitude *= min(F.get_size(image))
+
+        actual = transforms.AutoAugment()._apply_image_or_video_transform(
+            image,
+            transform_id=transform_id,
+            magnitude=magnitude,
+            interpolation=interpolation,
+            fill={type(image): fill},
+        )
+        expected = self._reference_shear_translate(
+            image, transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill
+        )
+
+        if input_type == "PIL":
+            actual, expected = F.to_image(actual), F.to_image(expected)
+
+        if "Shear" in transform_id and input_type == "Tensor":
+            mae = (actual.float() - expected.float()).abs().mean()
+            assert mae < (12 if interpolation is transforms.InterpolationMode.NEAREST else 5)
+        else:
+            assert_close(actual, expected, rtol=0, atol=1)
+
+    @pytest.mark.parametrize(
+        "transform",
+        [transforms.AutoAugment(), transforms.RandAugment(), transforms.TrivialAugmentWide(), transforms.AugMix()],
+    )
+    @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_transform_smoke(self, transform, make_input, dtype, device):
+        if make_input is make_image_pil and not (dtype is torch.uint8 and device == "cpu"):
+            pytest.skip(
+                "PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
+                "will degenerate to that anyway."
+            )
+        input = make_input(dtype=dtype, device=device)
+
+        with freeze_rng_state():
+            # By default, every test starts from the same random seed. That would give minimal coverage of the
+            # sampling that happens inside forward(). To get broader coverage without calling the transform multiple
+            # times, we build a reproducible random seed from the input type, dtype, and device.
+            torch.manual_seed(hash((make_input, dtype, device)))
+
+            # For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
+            # and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
+            # here and only check that we can script the v2 transform and subsequently call the result.
+            check_transform(transform, input, check_v1_compatibility=False)
+
+            if type(input) is torch.Tensor and dtype is torch.uint8:
+                _script(transform)(input)
+
+    def test_auto_augment_policy_error(self):
+        with pytest.raises(ValueError, match="provided policy"):
+            transforms.AutoAugment(policy=None)
+
+    @pytest.mark.parametrize("severity", [0, 11])
+    def test_aug_mix_severity_error(self, severity):
+        with pytest.raises(ValueError, match="severity must be between"):
+            transforms.AugMix(severity=severity)
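
Editor's note on the PIL affine convention used by `_reference_shear_translate` above: `PIL.Image.Image.transform` with `PIL.Image.AFFINE` expects the *inverse* mapping, i.e. the matrix sends each output pixel back to the input pixel it samples from. That is why the shear magnitude appears directly in the ShearX/ShearY matrices while TranslateX/TranslateY negate it. Below is a minimal standalone sketch of that convention; the helper names `shear_x` and `translate_x` are illustrative only and are not part of this patch or of torchvision.

import PIL.Image

def shear_x(image, magnitude):
    # PIL's AFFINE data (a, b, c, d, e, f) maps each output pixel (x, y) to
    # the input pixel (a*x + b*y + c, d*x + e*y + f).
    return image.transform(image.size, PIL.Image.AFFINE, (1, magnitude, 0, 0, 1, 0))

def translate_x(image, pixels):
    # Shifting content right by `pixels` means sampling the input at x - pixels,
    # hence the negated magnitude in the TranslateX/TranslateY matrices above.
    return image.transform(image.size, PIL.Image.AFFINE, (1, 0, -pixels, 0, 1, 0))

img = PIL.Image.new("RGB", (64, 64), 123)
sheared = shear_x(img, 0.3)     # rows shift left in proportion to their y coordinate
shifted = translate_x(img, 10)  # content moves 10 pixels to the right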