Skip to content

Commit

Permalink
break down massive parametrizations
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jun 29, 2023
1 parent 262e9c2 commit 39b91c8
Showing 1 changed file with 100 additions and 22 deletions.
122 changes: 100 additions & 22 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,57 +963,135 @@ def _adapt_fill(self, value, *, dtype):
k: next(v for v in vs if v is not None) for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items()
}

def _check_kernel(self, kernel, input, *args, **kwargs):
kwargs_ = self._MINIMAL_AFFINE_KWARGS.copy()
kwargs_.update(kwargs)
check_kernel(kernel, input, *args, **kwargs_)

@pytest.mark.parametrize("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor_angle(self, angle, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
angle=angle,
)

@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
@pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["scale"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor_translate(self, translate, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
translate=translate,
)

@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor_shear(self, shear, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
shear=shear,
check_scripted_vs_eager=not isinstance(shear, (int, float)),
)

@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor_center(self, center, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
center=center,
)

@pytest.mark.parametrize(
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
@pytest.mark.parametrize("fill", _EXHAUSTIVE_TYPE_FILLS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, angle, translate, scale, shear, center, interpolation, fill, dtype, device):
check_kernel(
def test_kernel_image_tensor_interpolation(self, interpolation, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
angle=angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
interpolation=interpolation,
fill=self._adapt_fill(fill, dtype=dtype),
check_scripted_vs_eager=not (isinstance(shear, (int, float)) or isinstance(fill, (int, float))),
check_cuda_vs_cpu=dict(atol=1, rtol=0)
if dtype is torch.uint8 and interpolation is transforms.InterpolationMode.BILINEAR
else True,
)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("fill", _EXHAUSTIVE_TYPE_FILLS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor_fill(self, fill, dtype, device):
self._check_kernel(
F.affine_image_tensor,
self._make_input(torch.Tensor, dtype=dtype, device=device),
fill=self._adapt_fill(fill, dtype=dtype),
check_scripted_vs_eager=not isinstance(fill, (int, float)),
)

@pytest.mark.parametrize("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
@pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["scale"])
@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box(self, format, angle, translate, scale, shear, center, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
check_kernel(
def test_kernel_bounding_box_angle(self, angle, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
self._check_kernel(
F.affine_bounding_box,
bounding_box,
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
format=format,
spatial_size=bounding_box.spatial_size,
angle=angle,
)

@pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"])
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box_translate(self, translate, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
self._check_kernel(
F.affine_bounding_box,
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
format=format,
spatial_size=bounding_box.spatial_size,
translate=translate,
scale=scale,
)

@pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"])
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box_shear(self, shear, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
self._check_kernel(
F.affine_bounding_box,
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
format=format,
spatial_size=bounding_box.spatial_size,
shear=shear,
center=center,
check_scripted_vs_eager=not isinstance(shear, (int, float)),
)

@pytest.mark.parametrize("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box_center(self, center, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device)
self._check_kernel(
F.affine_bounding_box,
self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device),
format=format,
spatial_size=bounding_box.spatial_size,
center=center,
)

@pytest.mark.parametrize("mask_type", ["segmentation", "detection"])
def test_kernel_mask(self, mask_type):
check_kernel(
Expand Down

0 comments on commit 39b91c8

Please sign in to comment.