diff --git a/test/common_utils.py b/test/common_utils.py
index c5826a36ff5..bf324d691c0 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format
 
     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+
 
 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -530,11 +534,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]
 
         image_pil = (
@@ -550,19 +556,26 @@ def fn(shape, dtype, device):
             )
         )
 
-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+        assert image_tensor[None].is_contiguous(memory_format=memory_format)
 
         return datapoints.Image(image_tensor)
 
-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
 
 
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index ee9576b6487..a74c7d73d26 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1365,3 +1365,28 @@ def test_correctness_uniform_temporal_subsample(device):
     out_video = F.uniform_temporal_subsample(video, 8)
 
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and the related torchvision workaround
+# once the underlying pytorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn("")
+    input_stride = input.stride()
+    output_stride = output.stride()
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py
index e3f18e435c7..6656d2fed0d 100644
--- a/torchvision/transforms/v2/functional/_geometry.py
+++ b/torchvision/transforms/v2/functional/_geometry.py
@@ -192,10 +192,21 @@ def resize_image_tensor(
     elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
         # uint8 dtype support for bilinear mode is limited to cpu and
         # according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
-        # TODO: enable torchscript and torch.backends.cpu.get_cpu_capability
         if "AVX2" in torch.backends.cpu.get_cpu_capability():
             acceptable_dtypes.append(torch.uint8)
 
+    if image.is_contiguous(memory_format=torch.channels_last):
+        strides = image.stride()
+        numel = image.numel()
+        if image.shape[0] == 1 and numel != strides[0]:
+            # This is the case when a channels_last tensor was squeezed and unsqueezed such that
+            # stride[0] was set to image.shape[1] * image.stride()[1] instead of image.numel().
+            # Restride the image so that it is correctly treated as channels_last.
+            # Related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
     need_cast = dtype not in acceptable_dtypes
     if need_cast:
         image = image.to(dtype=torch.float32)
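
Note: as context for the _geometry.py workaround and the new test, here is a minimal
sketch (not part of this diff, illustrative only; it uses only standard torch APIs) of
the stride anomaly tracked in https://github.com/pytorch/pytorch/issues/68430:

    import torch

    # A single-image batch in channels_last: NCHW strides are (C*H*W, 1, W*C, C).
    x = torch.rand(1, 3, 4, 5).contiguous(memory_format=torch.channels_last)
    print(x.stride())  # (60, 1, 15, 3); note stride[0] == x.numel()

    # Squeezing away and re-adding the batch dim recomputes stride[0] as
    # shape[1] * stride[1] == 3 instead of numel() == 60.
    y = x.squeeze(0).unsqueeze(0)
    print(y.stride())  # (3, 1, 15, 3)

    # The tensor still reports as channels_last contiguous, because strides of
    # size-1 dims are ignored by the contiguity check, but downstream kernels
    # that expect stride[0] == numel() can mis-handle it. The as_strided() call
    # in resize_image_tensor rewrites stride[0] back to numel() beforehand.
    print(y.is_contiguous(memory_format=torch.channels_last))  # True

This is also why make_image_loader_for_interpolation only needs a plain device move in
the channels_last branch: tensors produced by to_image_tensor from PIL already carry
channels_last-like strides (stride 1 on the channel dim), whereas the contiguous_format
branch needs an explicit repacking copy.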