Enable v1 vs. v2 consistency in refactored tests #7882

Merged (12 commits) on Aug 28, 2023
142 changes: 0 additions & 142 deletions test/test_transforms_v2_consistency.py
@@ -1,4 +1,3 @@
-import enum
import importlib.machinery
import importlib.util
import inspect
@@ -83,35 +82,6 @@ def __init__(
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
-    ConsistencyConfig(

Review comment (collaborator author): All these configs are now obsolete. Going forward, whenever we port a new test to test_*_refactored.py, we can also remove the configs here.

-        v2_transforms.Resize,
-        legacy_transforms.Resize,
-        [
-            NotScriptableArgsKwargs(32),
-            ArgsKwargs([32]),
-            ArgsKwargs((32, 29)),
-            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
-            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
-            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
-            NotScriptableArgsKwargs(31, max_size=32),
-            ArgsKwargs([31], max_size=32),
-            NotScriptableArgsKwargs(30, max_size=100),
-            ArgsKwargs([31], max_size=32),
-            ArgsKwargs((29, 32), antialias=False),
-            ArgsKwargs((28, 31), antialias=True),
-        ],
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        closeness_kwargs=dict(rtol=0, atol=1),
-    ),
-    ConsistencyConfig(
-        v2_transforms.Resize,
-        legacy_transforms.Resize,
-        [
-            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
-            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
-        ],
-        closeness_kwargs=dict(rtol=0, atol=21),
-    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
@@ -187,20 +157,6 @@ def __init__(
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
-    ConsistencyConfig(
-        v2_transforms.ConvertImageDtype,
-        legacy_transforms.ConvertImageDtype,
-        [
-            ArgsKwargs(torch.float16),
-            ArgsKwargs(torch.bfloat16),
-            ArgsKwargs(torch.float32),
-            ArgsKwargs(torch.float64),
-            ArgsKwargs(torch.uint8),
-        ],
-        supports_pil=False,
-        # Use default tolerances of `torch.testing.assert_close`
-        closeness_kwargs=dict(rtol=None, atol=None),
-    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
Expand All @@ -226,22 +182,6 @@ def __init__(
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
-    ConsistencyConfig(
-        v2_transforms.RandomHorizontalFlip,
-        legacy_transforms.RandomHorizontalFlip,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomVerticalFlip,
-        legacy_transforms.RandomVerticalFlip,
-        [
-            ArgsKwargs(p=0),
-            ArgsKwargs(p=1),
-        ],
-    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
@@ -367,30 +307,6 @@ def __init__(
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
-    *[
-        ConsistencyConfig(
-            v2_transforms.ElasticTransform,
-            legacy_transforms.ElasticTransform,
-            [
-                ArgsKwargs(),
-                ArgsKwargs(alpha=20.0),
-                ArgsKwargs(alpha=(15.3, 27.2)),
-                ArgsKwargs(sigma=3.0),
-                ArgsKwargs(sigma=(2.5, 3.9)),
-                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
-                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
-                ArgsKwargs(interpolation=PIL.Image.NEAREST),
-                ArgsKwargs(interpolation=PIL.Image.BICUBIC),
-                ArgsKwargs(fill=1),
-            ],
-            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
-            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
-            # We updated gaussian blur kernel generation with a faster and numerically more stable version
-            # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
-            closeness_kwargs=ckw,
-        )
-        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
-    ],
    ConsistencyConfig(
        v2_transforms.GaussianBlur,
        legacy_transforms.GaussianBlur,
@@ -402,26 +318,6 @@ def __init__(
        ],
        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
    ),
-    ConsistencyConfig(
-        v2_transforms.RandomAffine,
-        legacy_transforms.RandomAffine,
-        [
-            ArgsKwargs(degrees=30.0),
-            ArgsKwargs(degrees=(-20.0, 10.0)),
-            ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
-            ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
-            ArgsKwargs(degrees=0.0, shear=13),
-            ArgsKwargs(degrees=0.0, shear=(8, 17)),
-            ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
-            ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
-            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
-            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
-            ArgsKwargs(degrees=30.0, fill=1),
-            ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
-            ArgsKwargs(degrees=30.0, center=(0, 0)),
-        ],
-        removed_params=["fillcolor", "resample"],
-    ),
    ConsistencyConfig(
        v2_transforms.RandomCrop,
        legacy_transforms.RandomCrop,
@@ -456,21 +352,6 @@ def __init__(
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
-    ConsistencyConfig(
-        v2_transforms.RandomRotation,
-        legacy_transforms.RandomRotation,
-        [
-            ArgsKwargs(degrees=30.0),
-            ArgsKwargs(degrees=(-20.0, 10.0)),
-            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
-            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
-            ArgsKwargs(degrees=30.0, expand=True),
-            ArgsKwargs(degrees=30.0, center=(0, 0)),
-            ArgsKwargs(degrees=30.0, fill=1),
-            ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
-        ],
-        removed_params=["resample"],
-    ),
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
@@ -514,23 +395,6 @@ def __init__(
]


-def test_automatic_coverage():

Review comment (collaborator author): This was useful in the beginning to ensure we were not missing anything. But now that we have started removing configs here, we need to remove this check as well.

-    available = {
-        name
-        for name, obj in legacy_transforms.__dict__.items()
-        if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum)
-    }
-
-    checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}
-
-    missing = available - checked
-    if missing:
-        raise AssertionError(
-            f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
-            f"are not checked for consistency although a legacy counterpart exists."
-        )
-
-
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
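
test_signature_consistency is truncated here; its gist is to compare constructor signatures across the two APIs, with removed_params naming the v1-only arguments (such as fillcolor and resample above) that v2 intentionally drops. A rough sketch of that idea (an illustration, not the actual test body):

import inspect


def signatures_consistent(v2_cls, legacy_cls, removed_params=()):
    # Legacy parameters that v2 deliberately removed do not count.
    legacy_params = set(inspect.signature(legacy_cls).parameters) - set(removed_params)
    v2_params = set(inspect.signature(v2_cls).parameters)
    # Every remaining legacy parameter must still exist on the v2 transform.
    return legacy_params <= v2_params
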
@@ -708,15 +572,9 @@ def test_call_consistency(config, args_kwargs):
            (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
            (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
-            (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
            (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
-            (
-                v2_transforms.RandomAffine,
-                ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
-            ),
            (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
-            (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
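
For reference, each ConsistencyConfig removed above drove a generic check: build the v1 and the v2 transform from the same ArgsKwargs, run both on the same input, and compare the outputs under the config's closeness_kwargs. A minimal sketch of that pattern, assuming a fixed-seed comparison via torch.testing.assert_close (the helper below is illustrative, not the actual test_call_consistency implementation):

import torch
from torchvision import transforms as legacy_transforms
from torchvision.transforms import v2 as v2_transforms


def check_call_consistency(v2_cls, legacy_cls, image, *, rtol=None, atol=None, **args_kwargs):
    # Build both versions of the transform from the same arguments, as a
    # ConsistencyConfig entry would.
    transform_v2 = v2_cls(**args_kwargs)
    transform_legacy = legacy_cls(**args_kwargs)

    # Seed before each call so random transforms draw the same parameters.
    torch.manual_seed(0)
    output_v2 = transform_v2(image)

    torch.manual_seed(0)
    output_legacy = transform_legacy(image)

    torch.testing.assert_close(output_v2, output_legacy, rtol=rtol, atol=atol)


image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
# Mirrors the removed Resize config: closeness_kwargs=dict(rtol=0, atol=1).
check_call_consistency(
    v2_transforms.Resize, legacy_transforms.Resize, image, rtol=0, atol=1, size=(32, 29), antialias=True
)
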
62 changes: 43 additions & 19 deletions test/test_transforms_v2_refactored.py
@@ -228,26 +228,39 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
        assert functional_param == kernel_param


-def _check_transform_v1_compatibility(transform, input):
+def _check_transform_v1_compatibility(v2_transform_eager, input, *, rtol, atol):
    """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
-    ``get_params`` method, is scriptable, and the scripted version can be called without error."""
-    if transform._v1_transform_cls is None:
+    ``get_params`` method that is the v1 equivalent, and the output is close to v1 in eager and scripted mode."""
+    if not (type(input) is torch.Tensor or isinstance(input, PIL.Image.Image)):
        return

-    if type(input) is not torch.Tensor:
+    v1_transform_cls = v2_transform_eager._v1_transform_cls
+    if v1_transform_cls is None:
        return

-    if hasattr(transform._v1_transform_cls, "get_params"):
-        assert type(transform).get_params is transform._v1_transform_cls.get_params
+    if hasattr(v1_transform_cls, "get_params"):
+        assert type(v2_transform_eager).get_params is v1_transform_cls.get_params

-    scripted_transform = _script(transform)
-    with ignore_jit_no_profile_information_warning():
-        scripted_transform(input)
+    v1_transform_eager = v1_transform_cls(**v2_transform_eager._extract_params_for_v1_transform())
+
+    def check_close(transform_v2, transform_v1):
+        torch.manual_seed(0)
+        output_v2 = transform_v2(input)
+
+        torch.manual_seed(0)
+        output_v1 = transform_v1(input)
+
+        assert_close(output_v2, output_v1, rtol=rtol, atol=atol)
+
+    check_close(v2_transform_eager, v1_transform_eager)
+
+    if isinstance(input, PIL.Image.Image):
+        return
+
+    check_close(_script(v2_transform_eager), _script(v1_transform_eager))


-def check_transform(transform_cls, input, *args, **kwargs):
-    transform = transform_cls(*args, **kwargs)
-
+def check_transform(transform, input, *, check_v1_compatibility=True):
    pickle.loads(pickle.dumps(transform))

    output = transform(input)
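
In other words, the rewritten helper derives the v1 transform from the v2 instance via _extract_params_for_v1_transform(), then compares seeded outputs eagerly and, for pure tensors, in scripted mode as well. A hypothetical direct invocation (assuming this test module's imports and helpers):

_check_transform_v1_compatibility(
    transforms.RandomHorizontalFlip(p=1),
    torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8),
    # rtol=None/atol=None defers to the default tolerances of torch.testing.assert_close
    rtol=None,
    atol=None,
)
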
@@ -256,7 +269,8 @@ def check_transform(transform_cls, input, *args, **kwargs):
    if isinstance(input, datapoints.BoundingBoxes):
        assert output.format == input.format

-    _check_transform_v1_compatibility(transform, input)
+    if check_v1_compatibility:
+        _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))


def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
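
_to_tolerances is not shown in this diff. Judging from the call site above — check_v1_compatibility is either True (use defaults) or a dict of rtol/atol overrides — it plausibly reads like the sketch below (an assumption, not code from this PR):

def _to_tolerances(check_v1_compatibility):
    # A dict supplies explicit rtol/atol overrides; anything truthy but
    # non-dict (i.e. True) defers to assert_close's default tolerances.
    if isinstance(check_v1_compatibility, dict):
        return dict({"rtol": None, "atol": None}, **check_v1_compatibility)
    return dict(rtol=None, atol=None)
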
@@ -524,7 +538,12 @@ def test_functional_signature(self, kernel, input_type):
        ],
    )
    def test_transform(self, size, device, make_input):
-        check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True)
+        check_transform(
+            transforms.Resize(size=size, antialias=True),
+            make_input(self.INPUT_SIZE, device=device),
+            # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+            check_v1_compatibility=dict(rtol=0, atol=1),
+        )

    def _check_output_size(self, input, output, *, size, max_size):
        assert tuple(F.get_size(output)) == self._compute_output_size(
@@ -848,7 +867,7 @@ def test_functional_signature(self, kernel, input_type):
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
-        check_transform(transforms.RandomHorizontalFlip, make_input(device=device), p=1)
+        check_transform(transforms.RandomHorizontalFlip(p=1), make_input(device=device))

    @pytest.mark.parametrize(
        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
@@ -1026,7 +1045,7 @@ def test_functional_signature(self, kernel, input_type):
    def test_transform(self, make_input, device):
        input = make_input(device=device)

-        check_transform(transforms.RandomAffine, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES)
+        check_transform(transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), input)

    @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
    @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
@@ -1313,7 +1332,7 @@ def test_functional_signature(self, kernel, input_type):
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform(self, make_input, device):
-        check_transform(transforms.RandomVerticalFlip, make_input(device=device), p=1)
+        check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))

    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
    def test_image_correctness(self, fn):
@@ -1464,7 +1483,7 @@ def test_functional_signature(self, kernel, input_type):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, device):
check_transform(
transforms.RandomRotation, make_input(device=device), **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES
transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), make_input(device=device)
)

@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@@ -1726,7 +1745,7 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a
        input = make_input(dtype=input_dtype, device=device)
        if as_dict:
            output_dtype = {type(input): output_dtype}
-        check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale)
+        check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input)

    def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
        input_dtype = image.dtype
@@ -2415,7 +2434,12 @@ def test_displacement_error(self, make_input):
@pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, size, device):
check_transform(transforms.ElasticTransform, make_input(size, device=device))
check_transform(
transforms.ElasticTransform(),
make_input(size, device=device),
# We updated gaussian blur kernel generation with a faster and numerically more stable version
check_v1_compatibility=dict(rtol=0, atol=1),
)


class TestToPureTensor:
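
Taken together, the call-site changes above all follow one pattern: tests now pass a constructed transform instance to check_transform, with optional tolerance overrides (or an opt-out) riding along in check_v1_compatibility. A sketch of the resulting usage, with make_image standing in for the suite's input helpers (hypothetical, for illustration):

import torch
from torchvision.transforms import v2 as transforms


def make_image(size=(17, 11), device="cpu"):
    # Stand-in for the test suite's make_image helper.
    return torch.randint(0, 256, (3, *size), dtype=torch.uint8, device=device)


# Default: also run the v1 compatibility check with assert_close's default tolerances.
check_transform(transforms.RandomHorizontalFlip(p=1), make_image())

# Override tolerances where v1 and v2 legitimately differ by up to one uint8 step.
check_transform(
    transforms.Resize(size=(32, 29), antialias=True),
    make_image(),
    check_v1_compatibility=dict(rtol=0, atol=1),
)

# Skip the v1 comparison entirely, e.g. for transforms without a v1 counterpart.
check_transform(transforms.RandomHorizontalFlip(p=1), make_image(), check_v1_compatibility=False)
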