Skip to content

Commit

Permalink
[fbsync] port AA tests (#7927)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

Reviewed By: matteobettini

Differential Revision: D49600791

fbshipit-source-id: abf058e28a949717be7ad343e5417fca098d4078
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Sep 26, 2023
1 parent f584964 commit fcfd1b2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 278 deletions.
275 changes: 0 additions & 275 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,281 +705,6 @@ def test_to_tensor(self):
assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))


class TestAATransforms:
@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
# Stable API, if signed there is another random call
if t._AUGMENTATION_SPACE[keys[i]][1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for i in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_randaug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

for _ in range(le):
expected_output = t_ref(inpt)
output = t(inpt)

assert_close(expected_output, output, atol=1, rtol=0.1)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_trivial_aug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)

le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
randint_values = []
for i in range(le):
# Stable API, op_index random call
randint_values.append(i)
key = keys[i]
# Stable API, random magnitude
aug_op = t._AUGMENTATION_SPACE[key]
magnitudes = aug_op[0](2, 0, 0)
if magnitudes is not None:
randint_values.append(5)
# Stable API, if signed there is another random call
if aug_op[1]:
randint_values.append(0)
# New API, _get_random_item
randint_values.append(i)
# New API, random magnitude
if magnitudes is not None:
randint_values.append(5)

randint_values = iter(randint_values)

mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
mocker.patch("torch.rand", return_value=1.0)

expected_output = t_ref(inpt)
output = t(inpt)

assert_equal(expected_output, output)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_augmix_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)

@pytest.mark.parametrize(
"inpt",
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

torch.manual_seed(12)
expected_output = t_ref(inpt)

torch.manual_seed(12)
output = t(inpt)

assert_equal(expected_output, output)

@pytest.mark.parametrize(
"interpolation",
[
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_aa_jit(self, interpolation):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)

torch.manual_seed(12)
expected_output = tt_ref(inpt)

torch.manual_seed(12)
scripted_output = tt(inpt)

assert_equal(scripted_output, expected_output)


def import_transforms_from_references(reference):
HERE = Path(__file__).parent
PROJECT_ROOT = HERE.parent
Expand Down
Loading

0 comments on commit fcfd1b2

Please sign in to comment.