Skip to content

Commit

Permalink
[fbsync] Fixed issue with jitted AA transforms in v2 and added tests (#7839)
Browse files Browse the repository at this point in the history

Summary: (Note: this ignores all push blocking failures!)

Reviewed By: matteobettini

Differential Revision: D48900412

fbshipit-source-id: 516d3744e6c6115394abd80918ba4ad87eb5c5d5
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Sep 6, 2023
1 parent c5630e9 commit 0d4aa66
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 5 deletions.
94 changes: 94 additions & 0 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,29 @@ def test_randaug(self, inpt, interpolation, mocker):

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
    "interpolation",
    [
        v2_transforms.InterpolationMode.NEAREST,
        v2_transforms.InterpolationMode.BILINEAR,
    ],
)
def test_randaug_jit(self, interpolation):
    """Check that scripted v2 ``RandAugment`` matches the scripted legacy transform.

    Both transforms are built with the same ``interpolation`` and ``num_ops=1``,
    scripted with ``torch.jit.script``, and applied to the same random uint8
    batch of shape ``(1, 3, 256, 256)``. The outputs must be equal.
    """
    image = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

    legacy_op = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
    v2_op = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
    scripted_legacy, scripted_v2 = torch.jit.script(legacy_op), torch.jit.script(v2_op)

    # Re-seed right before each call so both scripted modules start from the
    # same global RNG state.
    torch.manual_seed(12)
    expected_output = scripted_legacy(image)
    torch.manual_seed(12)
    scripted_output = scripted_v2(image)

    assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
Expand Down Expand Up @@ -979,6 +1002,29 @@ def test_trivial_aug(self, inpt, interpolation, mocker):

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
    "interpolation",
    [
        v2_transforms.InterpolationMode.NEAREST,
        v2_transforms.InterpolationMode.BILINEAR,
    ],
)
def test_trivial_aug_jit(self, interpolation):
    """Check that scripted v2 ``TrivialAugmentWide`` matches the scripted legacy one.

    The legacy and v2 transforms are created with the same ``interpolation``,
    scripted, and run on one shared random uint8 input of shape
    ``(1, 3, 256, 256)``. The two outputs must be equal.
    """
    image = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

    legacy_op = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
    v2_op = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
    scripted_legacy = torch.jit.script(legacy_op)
    scripted_v2 = torch.jit.script(v2_op)

    # Legacy first, then v2; the seed is reset before each call so both draw
    # from the same global RNG state.
    outputs = []
    for scripted in (scripted_legacy, scripted_v2):
        torch.manual_seed(12)
        outputs.append(scripted(image))
    expected_output, scripted_output = outputs

    assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
Expand Down Expand Up @@ -1032,6 +1078,30 @@ def test_augmix(self, inpt, interpolation, mocker):

assert_equal(expected_output, output)

@pytest.mark.parametrize(
    "interpolation",
    [
        v2_transforms.InterpolationMode.NEAREST,
        v2_transforms.InterpolationMode.BILINEAR,
    ],
)
def test_augmix_jit(self, interpolation):
    """Check that scripted v2 ``AugMix`` matches the scripted legacy transform.

    Both transforms use the given ``interpolation`` with ``mixture_width=1``
    and ``chain_depth=1``. They are scripted and applied to the same random
    uint8 input of shape ``(1, 3, 256, 256)``; the outputs must be equal.
    """
    image = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

    legacy_op = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
    v2_op = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)

    scripted_legacy = torch.jit.script(legacy_op)
    scripted_v2 = torch.jit.script(v2_op)

    # Identical seed ahead of each forward pass so both scripted modules start
    # from the same global RNG state.
    torch.manual_seed(12)
    expected_output = scripted_legacy(image)

    torch.manual_seed(12)
    scripted_output = scripted_v2(image)

    assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
Expand Down Expand Up @@ -1061,6 +1131,30 @@ def test_aa(self, inpt, interpolation):

assert_equal(expected_output, output)

@pytest.mark.parametrize(
    "interpolation",
    [
        v2_transforms.InterpolationMode.NEAREST,
        v2_transforms.InterpolationMode.BILINEAR,
    ],
)
def test_aa_jit(self, interpolation):
    """Check that scripted v2 ``AutoAugment`` matches the scripted legacy transform.

    A single ``AutoAugmentPolicy("imagenet")`` object is passed to both the
    legacy and the v2 transform. Both are scripted and applied to the same
    random uint8 input of shape ``(1, 3, 256, 256)``; the outputs must be equal.
    """
    image = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
    policy = legacy_transforms.AutoAugmentPolicy("imagenet")

    legacy_op = legacy_transforms.AutoAugment(policy, interpolation=interpolation)
    v2_op = v2_transforms.AutoAugment(policy, interpolation=interpolation)
    scripted_legacy, scripted_v2 = torch.jit.script(legacy_op), torch.jit.script(v2_op)

    # Legacy first, then v2; reseeding before each call gives both the same
    # global RNG state.
    outputs = []
    for scripted in (scripted_legacy, scripted_v2):
        torch.manual_seed(12)
        outputs.append(scripted(image))
    expected_output, scripted_output = outputs

    assert_equal(scripted_output, expected_output)


def import_transforms_from_references(reference):
HERE = Path(__file__).parent
Expand Down
19 changes: 14 additions & 5 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ def __init__(
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
    """Return the parent's v1 parameters after validating the ``fill`` entry.

    Returns:
        The dictionary produced by the parent class's
        ``_extract_params_for_v1_transform``, unchanged.

    Raises:
        ValueError: If the ``"fill"`` entry is neither ``None`` nor an
            ``int``/``float``.
    """
    v1_params = super()._extract_params_for_v1_transform()

    fill_value = v1_params["fill"]
    # The check inspects the extracted "fill" entry, while the error message
    # reports ``self.fill``.
    if fill_value is not None and not isinstance(fill_value, (int, float)):
        raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

    return v1_params

def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
keys = tuple(dct.keys())
Expand Down Expand Up @@ -335,7 +344,7 @@ def forward(self, *inputs: Any) -> Any:
magnitude = 0.0

image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)

return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
Expand Down Expand Up @@ -419,7 +428,7 @@ def forward(self, *inputs: Any) -> Any:
else:
magnitude = 0.0
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)

return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
Expand Down Expand Up @@ -491,7 +500,7 @@ def forward(self, *inputs: Any) -> Any:
magnitude = 0.0

image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)

Expand Down Expand Up @@ -614,7 +623,7 @@ def forward(self, *inputs: Any) -> Any:
magnitude = 0.0

aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
Expand Down

0 comments on commit 0d4aa66

Please sign in to comment.