diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 61de769d885..4e8595d2185 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -927,6 +927,29 @@ def test_randaug(self, inpt, interpolation, mocker):
 
             assert_close(expected_output, output, atol=1, rtol=0.1)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_randaug_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -979,6 +1002,29 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
 
             assert_close(expected_output, output, atol=1, rtol=0.1)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_trivial_aug_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
+        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -1032,6 +1078,30 @@ def test_augmix(self, inpt, interpolation, mocker):
 
             assert_equal(expected_output, output)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_augmix_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+
+        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
     @pytest.mark.parametrize(
         "inpt",
         [
@@ -1061,6 +1131,30 @@ def test_aa(self, inpt, interpolation):
 
         assert_equal(expected_output, output)
 
+    @pytest.mark.parametrize(
+        "interpolation",
+        [
+            v2_transforms.InterpolationMode.NEAREST,
+            v2_transforms.InterpolationMode.BILINEAR,
+        ],
+    )
+    def test_aa_jit(self, interpolation):
+        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
+        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
+        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
+
+        tt_ref = torch.jit.script(t_ref)
+        tt = torch.jit.script(t)
+
+        torch.manual_seed(12)
+        expected_output = tt_ref(inpt)
+
+        torch.manual_seed(12)
+        scripted_output = tt(inpt)
+
+        assert_equal(scripted_output, expected_output)
+
 
 def import_transforms_from_references(reference):
     HERE = Path(__file__).parent
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index 097e90fc4ab..2c82d092ec2 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -28,7 +28,16 @@ def __init__(
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
-        self.fill = _setup_fill_arg(fill)
+        self.fill = fill
+        self._fill = _setup_fill_arg(fill)
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
+
+        return params
 
     def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
         keys = tuple(dct.keys())
@@ -335,7 +344,7 @@ def forward(self, *inputs: Any) -> Any:
                 magnitude = 0.0
 
             image_or_video = self._apply_image_or_video_transform(
-                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
             )
 
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -419,7 +428,7 @@ def forward(self, *inputs: Any) -> Any:
             else:
                 magnitude = 0.0
             image_or_video = self._apply_image_or_video_transform(
-                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
             )
 
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -491,7 +500,7 @@ def forward(self, *inputs: Any) -> Any:
             magnitude = 0.0
 
         image_or_video = self._apply_image_or_video_transform(
-            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+            image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
         )
 
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -614,7 +623,7 @@ def forward(self, *inputs: Any) -> Any:
                     magnitude = 0.0
 
                 aug = self._apply_image_or_video_transform(
-                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
+                    aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
                 )
             mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
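
Note (not part of the diff): the new _extract_params_for_v1_transform override implies that the v2 AutoAugment family can only be passed to torch.jit.script when fill is None or a scalar; a non-scalar fill is rejected rather than silently mishandled when falling back to the v1 implementation. A minimal, illustrative sketch of the expected behavior, assuming the public torchvision.transforms.v2 namespace exposes these classes as in the tests above:

import torch
from torchvision.transforms import v2

# Scalar fill: scripting is expected to succeed (this is what the new *_jit tests exercise).
t = v2.RandAugment(interpolation=v2.InterpolationMode.NEAREST, fill=0)
scripted = torch.jit.script(t)
out = scripted(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8))

# Non-scalar fill: per the new guard, extracting params for the v1 transform raises,
# so scripting should fail loudly instead of producing a transform with the wrong fill.
# (Exact error propagation through torch.jit.script is an assumption here.)
t_bad = v2.RandAugment(fill=[0, 0, 0])
try:
    torch.jit.script(t_bad)
except ValueError as e:
    print(e)  # e.g. "RandAugment() can only be scripted for a scalar `fill`, but got [0, 0, 0]."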