Added strides fix for 3D CL-like tensors in Resize
Added tests on mem format
vfdev-5 committed May 9, 2023
1 parent 791488b commit 427c4c1
Showing 3 changed files with 55 additions and 6 deletions.
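Note on terminology: a "3D CL-like" tensor here is a C x H x W image whose strides follow the channels-last layout of a 4D NCHW tensor with the batch dimension dropped. A minimal sketch of how such a tensor arises (the shapes are illustrative, not taken from the diff):

import torch

# A 4D NCHW tensor in channels_last layout stores pixels as NHWC in memory.
x = torch.rand(1, 3, 4, 5).contiguous(memory_format=torch.channels_last)
print(x.stride())   # (60, 1, 15, 3)

# Dropping the batch dimension leaves a 3D "channels-last-like" CHW tensor:
# its strides are (1, W * C, C) instead of the row-major (H * W, W, 1).
y = x[0]
print(y.stride())   # (1, 15, 3)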
test/common_utils.py (23 changes: 18 additions & 5 deletions)
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format

     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]

+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+

 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -530,11 +534,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)


-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]

         image_pil = (
@@ -550,19 +556,26 @@ def fn(shape, dtype, device):
             )
         )

-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor[None].to(device=device, memory_format=memory_format, copy=True)[0]
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+        assert image_tensor[None].is_contiguous(memory_format=memory_format)

         return datapoints.Image(image_tensor)

-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)


 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
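Because torch.channels_last is only defined for 4D NCHW tensors, the loader above adds a temporary batch dimension both when producing a channels-last-like 3D image and when asserting its layout. A rough standalone sketch of that pattern, assuming a plain uint8 CHW tensor (the tensor below is illustrative):

import torch

memory_format = torch.channels_last
image_tensor = torch.randint(0, 256, (3, 233, 147), dtype=torch.uint8)  # C x H x W

# channels_last requires rank 4, so unsqueeze to NCHW, convert, then drop the fake batch dim.
if memory_format != torch.contiguous_format:
    image_tensor = image_tensor[None].to(memory_format=memory_format, copy=True)[0]

# The same unsqueeze trick is used to verify the layout of the 3D result.
assert image_tensor[None].is_contiguous(memory_format=memory_format)
print(image_tensor.stride())  # (1, 441, 3), i.e. (1, W * C, C)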


test/test_transforms_v2_functional.py (25 changes: 25 additions & 0 deletions)
@@ -1365,3 +1365,28 @@ def test_correctness_uniform_temporal_subsample(device):

     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and related torchvision workaround
+# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
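The two branches above encode the only output layouts expected for a 3D C x H x W image: row-major (contiguous) inputs should produce row-major outputs, and channels-last-like inputs should stay channels-last-like, with H and W replaced by the output size. As a quick numeric illustration of the two stride patterns (shapes chosen arbitrarily):

import torch

C, H, W = 3, 4, 5

contiguous = torch.empty(C, H, W)
print(contiguous.stride())   # (H * W, W, 1) == (20, 5, 1), last stride is 1

cl_like = torch.empty(1, C, H, W).contiguous(memory_format=torch.channels_last)[0]
print(cl_like.stride())      # (1, W * C, C) == (1, 15, 3), first stride is 1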
torchvision/transforms/v2/functional/_geometry.py (13 changes: 12 additions & 1 deletion)
@@ -192,10 +192,21 @@ def resize_image_tensor(
         elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
             # uint8 dtype support for bilinear mode is limited to cpu and
             # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-            # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
             if "AVX2" in torch.backends.cpu.get_cpu_capability():
                 acceptable_dtypes.append(torch.uint8)

+        if image.is_contiguous(memory_format=torch.channels_last):
+            strides = image.stride()
+            numel = image.numel()
+            if image.shape[0] == 1 and numel != strides[0]:
+                # This is the case when channels last tensor was squeezed and unsqueezed such that
+                # stride[0] set as image.shape[1] * image.stride()[1] instead of being image.numel()
+                # Let's restride image such that it will be correctly treated as channels last.
+                # Related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+                new_strides = list(strides)
+                new_strides[0] = numel
+                image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
         need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)
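The comment in the added block describes an ambiguity for batch-size-1 tensors: after a squeeze/unsqueeze round trip, stride[0] equals shape[1] * stride()[1] rather than numel(), yet the tensor still reports channels_last contiguity because size-1 dimensions are skipped by the check. A small sketch of the situation the workaround targets and of the restriding it applies (values are illustrative):

import torch

x = torch.rand(1, 3, 4, 5).contiguous(memory_format=torch.channels_last)
print(x.stride())   # (60, 1, 15, 3): stride[0] == numel() == 60

y = x.squeeze(0).unsqueeze(0)
print(y.stride())   # (3, 1, 15, 3): stride[0] == shape[1] * stride()[1], not numel()
print(y.is_contiguous(memory_format=torch.channels_last))  # still True

# Restride so that stride[0] == numel(), making the layout unambiguous
# before the tensor reaches interpolate(), mirroring the workaround above.
new_strides = (y.numel(),) + y.stride()[1:]
z = y.as_strided(y.shape, new_strides)
print(z.stride())   # (60, 1, 15, 3)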
