From 94cf81e8d475874c35f6db794d300f25c392a4ae Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 00:48:27 +0200 Subject: [PATCH 01/10] register tensor and PIL kernel the same way as datapoints --- test/test_transforms_v2_functional.py | 58 +-- test/test_transforms_v2_refactored.py | 62 +-- .../transforms/v2/functional/_augment.py | 25 +- .../transforms/v2/functional/_color.py | 280 +++++------- .../transforms/v2/functional/_geometry.py | 422 ++++++++---------- torchvision/transforms/v2/functional/_meta.py | 86 ++-- torchvision/transforms/v2/functional/_misc.py | 69 ++- .../transforms/v2/functional/_temporal.py | 20 +- .../transforms/v2/functional/_utils.py | 49 +- 9 files changed, 410 insertions(+), 661 deletions(-) diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 8d529732610..580a74d8055 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -2,7 +2,6 @@ import math import os import re -from unittest import mock import numpy as np import PIL.Image @@ -25,7 +24,6 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes -from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY from torchvision.transforms.v2.utils import is_simple_tensor from transforms_v2_dispatcher_infos import DISPATCHER_INFOS from transforms_v2_kernel_infos import KERNEL_INFOS @@ -359,18 +357,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): def test_scriptable(self, dispatcher): script(dispatcher) - @image_sample_inputs - def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on): - (image_datapoint, *other_args), kwargs = args_kwargs.load() - image_simple_tensor = torch.Tensor(image_datapoint) - - kernel_info = info.kernel_infos[datapoints.Image] - spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id) - - info.dispatcher(image_simple_tensor, *other_args, **kwargs) - - spy.assert_called_once() - @image_sample_inputs def test_simple_tensor_output_type(self, info, args_kwargs): (image_datapoint, *other_args), kwargs = args_kwargs.load() @@ -381,25 +367,6 @@ def test_simple_tensor_output_type(self, info, args_kwargs): # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well assert type(output) is torch.Tensor - @make_info_args_kwargs_parametrization( - [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], - args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), - ) - def test_dispatch_pil(self, info, args_kwargs, spy_on): - (image_datapoint, *other_args), kwargs = args_kwargs.load() - - if image_datapoint.ndim > 3: - pytest.skip("Input is batched") - - image_pil = F.to_image_pil(image_datapoint) - - pil_kernel_info = info.pil_kernel_info - spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id) - - info.dispatcher(image_pil, *other_args, **kwargs) - - spy.assert_called_once() - @make_info_args_kwargs_parametrization( [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None], args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), @@ -416,28 +383,6 @@ def test_pil_output_type(self, info, args_kwargs): assert isinstance(output, PIL.Image.Image) - @make_info_args_kwargs_parametrization( - DISPATCHER_INFOS, - args_kwargs_fn=lambda 
info: info.sample_inputs(), - ) - def test_dispatch_datapoint(self, info, args_kwargs, spy_on): - (datapoint, *other_args), kwargs = args_kwargs.load() - - input_type = type(datapoint) - - wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type] - - # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the - # proper kernel was wrapped - if hasattr(wrapped_kernel, "__wrapped__"): - assert wrapped_kernel.__wrapped__ is info.kernels[input_type] - - spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) - with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}): - info.dispatcher(datapoint, *other_args, **kwargs) - - spy.assert_called_once() - @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), @@ -449,6 +394,9 @@ def test_datapoint_output_type(self, info, args_kwargs): assert isinstance(output, type(datapoint)) + if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes: + assert output.format == datapoint.format + @pytest.mark.parametrize( ("dispatcher_info", "datapoint_type", "kernel_info"), [ diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 45668fda1ca..20b198796ca 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs): dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): - """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is - preserved in doing so. For bounding boxes also checks that the format is preserved. 
- """ - input_type = type(input) - - if isinstance(input, datapoints.Datapoint): - wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type] - - # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the - # proper kernel was wrapped - if hasattr(wrapped_kernel, "__wrapped__"): - assert wrapped_kernel.__wrapped__ is kernel - - spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) - with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}): - output = dispatcher(input, *args, **kwargs) - - spy.assert_called_once() - else: - with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy: - output = dispatcher(input, *args, **kwargs) - - spy.assert_called_once() - - assert isinstance(output, input_type) - - if isinstance(input, datapoints.BoundingBoxes): - assert output.format == input.format - - def check_dispatcher( dispatcher, + # TODO: remove this parameter kernel, input, *args, check_scripted_smoke=True, - check_dispatch=True, **kwargs, ): unknown_input = object() + with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): + dispatcher(unknown_input, *args, **kwargs) + with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy: - with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): - dispatcher(unknown_input, *args, **kwargs) + output = dispatcher(input, *args, **kwargs) spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}") + assert isinstance(output, type(input)) + + if isinstance(input, datapoints.BoundingBoxes): + assert output.format == input.format + if check_scripted_smoke: _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs) - if check_dispatch: - _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs) - def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): """Checks if the signature of the dispatcher matches the kernel signature.""" @@ -412,18 +385,20 @@ def transform(bbox): @pytest.mark.parametrize( - ("dispatcher", "registered_datapoint_clss"), + ("dispatcher", "registered_input_types"), [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], ) -def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): +def test_exhaustive_kernel_registration(dispatcher, registered_input_types): missing = { + torch.Tensor, + PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video, - } - registered_datapoint_clss + } - registered_input_types if missing: - names = sorted(f"datapoints.{cls.__name__}" for cls in missing) + names = sorted(str(t) for t in missing) raise AssertionError( "\n".join( [ @@ -1753,11 +1728,6 @@ def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, F.to_dtype, kernel, make_input(dtype=input_dtype, device=device), - # TODO: we could leave check_dispatch to True but it currently fails - # in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints. - # We should be able to put this back if we change the dispatch - # mechanism e.g. 
via https://github.com/pytorch/vision/pull/7733 - check_dispatch=False, dtype=output_dtype, scale=scale, ) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 95b4ed93786..89fa254374d 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -7,7 +7,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True) @@ -20,23 +20,16 @@ def erase( v: torch.Tensor, inplace: bool = False, ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(erase) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(erase, type(inpt)) - return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, PIL.Image.Image): - return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(erase) + + kernel = _get_kernel(erase, type(inpt)) + return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) +@_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, datapoints.Image) def erase_image_tensor( image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False @@ -48,7 +41,7 @@ def erase_image_tensor( return image -@torch.jit.unused +@_register_kernel_internal(erase, PIL.Image.Image) def erase_image_pil( image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> PIL.Image.Image: diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 99dc1936259..4d84e6d23a5 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,29 +10,20 @@ from torchvision.utils import _log_api_usage_once from ._misc import _num_value_bits, to_dtype_image_tensor -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video) def rgb_to_grayscale( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(rgb_to_grayscale) - if num_output_channels not in (1, 3): - raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(rgb_to_grayscale, type(inpt)) - return kernel(inpt, 
num_output_channels=num_output_channels) - elif isinstance(inpt, PIL.Image.Image): - return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(rgb_to_grayscale) + + kernel = _get_kernel(rgb_to_grayscale, type(inpt)) + return kernel(inpt, num_output_channels=num_output_channels) # `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a @@ -56,12 +47,19 @@ def _rgb_to_grayscale_image_tensor( return l_img +@_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, datapoints.Image) def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) -rgb_to_grayscale_image_pil = _FP.to_grayscale +@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image) +def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + return _FP.to_grayscale(image, num_output_channels=num_output_channels) def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -74,23 +72,16 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_brightness) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_brightness, type(inpt)) - return kernel(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + _log_api_usage_once(adjust_brightness) + kernel = _get_kernel(adjust_brightness, type(inpt)) + return kernel(inpt, brightness_factor=brightness_factor) + + +@_register_kernel_internal(adjust_brightness, torch.Tensor) @_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: @@ -106,6 +97,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float return output if fp else output.to(image.dtype) +@_register_kernel_internal(adjust_brightness, PIL.Image.Image) def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: return _FP.adjust_brightness(image, brightness_factor=brightness_factor) @@ -117,23 +109,16 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_saturation) - - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)): + if torch.jit.is_scripting(): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_saturation, type(inpt)) - return kernel(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_saturation) + + kernel = _get_kernel(adjust_saturation, type(inpt)) + return kernel(inpt, saturation_factor=saturation_factor) + +@_register_kernel_internal(adjust_saturation, torch.Tensor) @_register_kernel_internal(adjust_saturation, datapoints.Image) def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: @@ -154,6 +139,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float adjust_saturation_image_pil = _FP.adjust_saturation +_register_kernel_internal(adjust_saturation, PIL.Image.Image)(adjust_saturation_image_pil) @_register_kernel_internal(adjust_saturation, datapoints.Video) @@ -163,23 +149,16 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_contrast) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_contrast, type(inpt)) - return kernel(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(adjust_contrast) + + kernel = _get_kernel(adjust_contrast, type(inpt)) + return kernel(inpt, contrast_factor=contrast_factor) +@_register_kernel_internal(adjust_contrast, torch.Tensor) @_register_kernel_internal(adjust_contrast, datapoints.Image) def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: if contrast_factor < 0: @@ -200,6 +179,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> adjust_contrast_image_pil = _FP.adjust_contrast +_register_kernel_internal(adjust_contrast, PIL.Image.Image)(adjust_contrast_image_pil) @_register_kernel_internal(adjust_contrast, datapoints.Video) @@ -209,23 +189,16 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_sharpness) - - if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)): + if torch.jit.is_scripting(): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_sharpness, type(inpt)) - return kernel(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_sharpness) + + kernel = _get_kernel(adjust_sharpness, type(inpt)) + return kernel(inpt, sharpness_factor=sharpness_factor) + +@_register_kernel_internal(adjust_sharpness, torch.Tensor) @_register_kernel_internal(adjust_sharpness, datapoints.Image) def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: num_channels, height, width = image.shape[-3:] @@ -280,6 +253,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) adjust_sharpness_image_pil = _FP.adjust_sharpness +_register_kernel_internal(adjust_sharpness, PIL.Image.Image)(adjust_sharpness_image_pil) @_register_kernel_internal(adjust_sharpness, datapoints.Video) @@ -289,21 +263,13 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_hue) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_hue, type(inpt)) - return kernel(inpt, hue_factor=hue_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_hue_image_pil(inpt, hue_factor=hue_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(adjust_hue) + + kernel = _get_kernel(adjust_hue, type(inpt)) + return kernel(inpt, hue_factor=hue_factor) def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: @@ -370,6 +336,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) +@_register_kernel_internal(adjust_hue, torch.Tensor) @_register_kernel_internal(adjust_hue, datapoints.Image) def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): @@ -399,6 +366,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten adjust_hue_image_pil = _FP.adjust_hue +_register_kernel_internal(adjust_hue, PIL.Image.Image)(adjust_hue_image_pil) @_register_kernel_internal(adjust_hue, datapoints.Video) @@ -408,23 +376,16 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_gamma) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(adjust_gamma, type(inpt)) - return kernel(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, PIL.Image.Image): - return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(adjust_gamma) + kernel = _get_kernel(adjust_gamma, type(inpt)) + return kernel(inpt, gamma=gamma, gain=gain) + + +@_register_kernel_internal(adjust_gamma, torch.Tensor) @_register_kernel_internal(adjust_gamma, datapoints.Image) def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: if gamma < 0: @@ -446,6 +407,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 adjust_gamma_image_pil = _FP.adjust_gamma +_register_kernel_internal(adjust_gamma, PIL.Image.Image)(adjust_gamma_image_pil) @_register_kernel_internal(adjust_gamma, datapoints.Video) @@ -455,23 +417,16 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(posterize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return posterize_image_tensor(inpt, bits=bits) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(posterize, type(inpt)) - return kernel(inpt, bits=bits) - elif isinstance(inpt, PIL.Image.Image): - return posterize_image_pil(inpt, bits=bits) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(posterize) + + kernel = _get_kernel(posterize, type(inpt)) + return kernel(inpt, bits=bits) +@_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, datapoints.Image) def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: if image.is_floating_point(): @@ -487,6 +442,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: posterize_image_pil = _FP.posterize +_register_kernel_internal(posterize, PIL.Image.Image)(posterize_image_pil) @_register_kernel_internal(posterize, datapoints.Video) @@ -496,23 +452,16 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(solarize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return solarize_image_tensor(inpt, threshold=threshold) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(solarize, type(inpt)) - return kernel(inpt, threshold=threshold) - elif isinstance(inpt, PIL.Image.Image): - return solarize_image_pil(inpt, threshold=threshold) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(solarize) + + kernel = _get_kernel(solarize, type(inpt)) + return kernel(inpt, threshold=threshold) + +@_register_kernel_internal(solarize, torch.Tensor) @_register_kernel_internal(solarize, datapoints.Image) def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: if threshold > _max_value(image.dtype): @@ -522,6 +471,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor solarize_image_pil = _FP.solarize +_register_kernel_internal(solarize, PIL.Image.Image)(solarize_image_pil) @_register_kernel_internal(solarize, datapoints.Video) @@ -531,25 +481,16 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(autocontrast) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return autocontrast_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(autocontrast, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return autocontrast_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + _log_api_usage_once(autocontrast) + + kernel = _get_kernel(autocontrast, type(inpt)) + return kernel(inpt) + +@_register_kernel_internal(autocontrast, torch.Tensor) @_register_kernel_internal(autocontrast, datapoints.Image) def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: c = image.shape[-3] @@ -581,6 +522,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: autocontrast_image_pil = _FP.autocontrast +_register_kernel_internal(autocontrast, PIL.Image.Image)(autocontrast_image_pil) @_register_kernel_internal(autocontrast, datapoints.Video) @@ -590,25 +532,16 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(equalize) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return equalize_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(equalize, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return equalize_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + _log_api_usage_once(equalize) + kernel = _get_kernel(equalize, type(inpt)) + return kernel(inpt) + + +@_register_kernel_internal(equalize, torch.Tensor) @_register_kernel_internal(equalize, datapoints.Image) def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: @@ -680,6 +613,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: equalize_image_pil = _FP.equalize +_register_kernel_internal(equalize, PIL.Image.Image)(equalize_image_pil) @_register_kernel_internal(equalize, datapoints.Video) @@ -689,25 +623,16 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(invert) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return invert_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(invert, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return invert_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(invert) + + kernel = _get_kernel(invert, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(invert, torch.Tensor) @_register_kernel_internal(invert, datapoints.Image) def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.is_floating_point(): @@ -720,6 +645,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: invert_image_pil = _FP.invert +_register_kernel_internal(invert, PIL.Image.Image)(invert_image_pil) @_register_kernel_internal(invert, datapoints.Video) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 21f2aa8df0a..8afe37384a3 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,13 +25,7 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import ( - _get_kernel, - _register_explicit_noop, - _register_five_ten_crop_kernel, - _register_kernel_internal, - is_simple_tensor, -) +from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -46,30 +40,22 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(horizontal_flip) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return horizontal_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(horizontal_flip, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return horizontal_flip_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(horizontal_flip) + + kernel = _get_kernel(horizontal_flip, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(horizontal_flip, torch.Tensor) @_register_kernel_internal(horizontal_flip, datapoints.Image) def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-1) +@_register_kernel_internal(horizontal_flip, PIL.Image.Image) def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) @@ -110,30 +96,22 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(vertical_flip) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return vertical_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(vertical_flip, type(inpt)) - return kernel( - inpt, - ) - elif isinstance(inpt, PIL.Image.Image): - return vertical_flip_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(vertical_flip) + + kernel = _get_kernel(vertical_flip, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(vertical_flip, torch.Tensor) @_register_kernel_internal(vertical_flip, datapoints.Image) def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-2) +@_register_kernel_internal(vertical_flip, PIL.Image.Image) def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: return _FP.vflip(image) @@ -199,24 +177,16 @@ def resize( max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(resize, type(inpt)) - return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, PIL.Image.Image): - if antialias is False: - warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") - return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + _log_api_usage_once(resize) + + kernel = _get_kernel(resize, type(inpt)) + return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) +@_register_kernel_internal(resize, torch.Tensor) @_register_kernel_internal(resize, datapoints.Image) def resize_image_tensor( image: torch.Tensor, @@ -297,7 +267,6 @@ def resize_image_tensor( return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) -@torch.jit.unused def resize_image_pil( image: PIL.Image.Image, size: Union[Sequence[int], int], @@ -319,6 +288,19 @@ def resize_image_pil( return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) +@_register_kernel_internal(resize, PIL.Image.Image) +def _resize_image_pil_dispatch( + image: PIL.Image.Image, + size: Union[Sequence[int], int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[Union[str, bool]] = "warn", +) -> PIL.Image.Image: + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) + + def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -391,26 +373,10 @@ def affine( fill: datapoints._FillTypeJIT = None, center: Optional[List[float]] = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(affine) - - # TODO: consider deprecating integers from angle and shear on the future - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return affine_image_tensor( inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(affine, type(inpt)) - return kernel( - inpt, - angle, + angle=angle, translate=translate, scale=scale, shear=shear, @@ -418,22 +384,20 @@ def affine( fill=fill, center=center, ) - elif isinstance(inpt, PIL.Image.Image): - return affine_image_pil( - inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(affine) + + kernel = _get_kernel(affine, type(inpt)) + return kernel( + inpt, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) def _affine_parse_args( @@ -684,6 +648,7 @@ def _affine_grid( return output_grid.view(1, oh, ow, 2) +@_register_kernel_internal(affine, torch.Tensor) @_register_kernel_internal(affine, datapoints.Image) def affine_image_tensor( image: torch.Tensor, @@ -736,7 +701,7 @@ def affine_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(affine, PIL.Image.Image) def affine_image_pil( image: PIL.Image.Image, angle: Union[int, float], @@ -983,23 +948,18 @@ def rotate( center: Optional[List[float]] = None, fill: datapoints._FillTypeJIT = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(rotate) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(rotate, type(inpt)) - return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, PIL.Image.Image): - return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
+ if torch.jit.is_scripting(): + return rotate_image_tensor( + inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) + _log_api_usage_once(rotate) + kernel = _get_kernel(rotate, type(inpt)) + return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + + +@_register_kernel_internal(rotate, torch.Tensor) @_register_kernel_internal(rotate, datapoints.Image) def rotate_image_tensor( image: torch.Tensor, @@ -1045,7 +1005,7 @@ def rotate_image_tensor( return output.reshape(shape[:-3] + (num_channels, new_height, new_width)) -@torch.jit.unused +@_register_kernel_internal(rotate, PIL.Image.Image) def rotate_image_pil( image: PIL.Image.Image, angle: float, @@ -1162,22 +1122,13 @@ def pad( fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(pad) + if torch.jit.is_scripting(): + return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + _log_api_usage_once(pad) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(pad, type(inpt)) - return kernel(inpt, padding, fill=fill, padding_mode=padding_mode) - elif isinstance(inpt, PIL.Image.Image): - return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + kernel = _get_kernel(pad, type(inpt)) + return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode) def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: @@ -1204,6 +1155,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: return [pad_left, pad_right, pad_top, pad_bottom] +@_register_kernel_internal(pad, torch.Tensor) @_register_kernel_internal(pad, datapoints.Image) def pad_image_tensor( image: torch.Tensor, @@ -1304,6 +1256,7 @@ def _pad_with_vector_fill( pad_image_pil = _FP.pad +_register_kernel_internal(pad, PIL.Image.Image)(pad_image_pil) @_register_kernel_internal(pad, datapoints.Mask) @@ -1385,23 +1338,16 @@ def pad_video( def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return crop_image_tensor(inpt, top, left, height, width) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(crop, type(inpt)) - return kernel(inpt, top, left, height, width) - elif isinstance(inpt, PIL.Image.Image): - return crop_image_pil(inpt, top, left, height, width) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + if torch.jit.is_scripting(): + return crop_image_tensor(inpt, top=top, left=left, height=height, width=width) + + _log_api_usage_once(crop) + kernel = _get_kernel(crop, type(inpt)) + return kernel(inpt, top=top, left=left, height=height, width=width) + +@_register_kernel_internal(crop, torch.Tensor) @_register_kernel_internal(crop, datapoints.Image) def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: h, w = image.shape[-2:] @@ -1422,6 +1368,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid crop_image_pil = _FP.crop +_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil) def crop_bounding_boxes( @@ -1484,25 +1431,28 @@ def perspective( fill: datapoints._FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(perspective) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return perspective_image_tensor( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(perspective, type(inpt)) - return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients) - elif isinstance(inpt, PIL.Image.Image): - return perspective_image_pil( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, ) + _log_api_usage_once(perspective) + + kernel = _get_kernel(perspective, type(inpt)) + return kernel( + inpt, + startpoints=startpoints, + endpoints=endpoints, + interpolation=interpolation, + fill=fill, + coefficients=coefficients, + ) + def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/ @@ -1551,6 +1501,7 @@ def _perspective_coefficients( raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.") +@_register_kernel_internal(perspective, torch.Tensor) @_register_kernel_internal(perspective, datapoints.Image) def perspective_image_tensor( image: torch.Tensor, @@ -1598,7 +1549,7 @@ def perspective_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(perspective, PIL.Image.Image) def perspective_image_pil( image: PIL.Image.Image, startpoints: Optional[List[List[int]]], @@ -1787,29 +1738,19 @@ def elastic( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: datapoints._FillTypeJIT = None, ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(elastic) - - if not isinstance(displacement, torch.Tensor): - raise TypeError("Argument displacement should be a Tensor") - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(elastic, type(inpt)) - return kernel(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, PIL.Image.Image): - return 
elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + + _log_api_usage_once(elastic) + + kernel = _get_kernel(elastic, type(inpt)) + return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill) elastic_transform = elastic +@_register_kernel_internal(elastic, torch.Tensor) @_register_kernel_internal(elastic, datapoints.Image) def elastic_image_tensor( image: torch.Tensor, @@ -1867,7 +1808,7 @@ def elastic_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(elastic, PIL.Image.Image) def elastic_image_pil( image: PIL.Image.Image, displacement: torch.Tensor, @@ -1990,21 +1931,13 @@ def elastic_video( def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(center_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return center_crop_image_tensor(inpt, output_size) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(center_crop, type(inpt)) - return kernel(inpt, output_size) - elif isinstance(inpt, PIL.Image.Image): - return center_crop_image_pil(inpt, output_size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return center_crop_image_tensor(inpt, output_size=output_size) + + _log_api_usage_once(center_crop) + + kernel = _get_kernel(center_crop, type(inpt)) + return kernel(inpt, output_size=output_size) def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: @@ -2034,6 +1967,7 @@ def _center_crop_compute_crop_anchor( return crop_top, crop_left +@_register_kernel_internal(center_crop, torch.Tensor) @_register_kernel_internal(center_crop, datapoints.Image) def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) @@ -2054,7 +1988,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)] -@torch.jit.unused +@_register_kernel_internal(center_crop, PIL.Image.Image) def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) image_height, image_width = get_size_image_pil(image) @@ -2125,25 +2059,34 @@ def resized_crop( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resized_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return resized_crop_image_tensor( - inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation - ) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(resized_crop, type(inpt)) - return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) - elif isinstance(inpt, PIL.Image.Image): - return resized_crop_image_pil(inpt, top, 
left, height, width, size=size, interpolation=interpolation) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, ) + _log_api_usage_once(resized_crop) + + kernel = _get_kernel(resized_crop, type(inpt)) + return kernel( + inpt, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + antialias=antialias, + ) + +@_register_kernel_internal(resized_crop, torch.Tensor) @_register_kernel_internal(resized_crop, datapoints.Image) def resized_crop_image_tensor( image: torch.Tensor, @@ -2159,7 +2102,6 @@ def resized_crop_image_tensor( return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias) -@torch.jit.unused def resized_crop_image_pil( image: PIL.Image.Image, top: int, @@ -2173,6 +2115,30 @@ def resized_crop_image_pil( return resize_image_pil(image, size, interpolation=interpolation) +@_register_kernel_internal(resized_crop, PIL.Image.Image) +def resized_crop_image_pil_dispatch( + image: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> PIL.Image.Image: + if antialias is False: + pass + return resized_crop_image_pil( + image, + top=top, + left=left, + height=height, + width=width, + size=size, + interpolation=interpolation, + ) + + def resized_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, @@ -2244,21 +2210,13 @@ def five_crop( datapoints._InputTypeJIT, datapoints._InputTypeJIT, ]: - if not torch.jit.is_scripting(): - _log_api_usage_once(five_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return five_crop_image_tensor(inpt, size) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(five_crop, type(inpt)) - return kernel(inpt, size) - elif isinstance(inpt, PIL.Image.Image): - return five_crop_image_pil(inpt, size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + if torch.jit.is_scripting(): + return five_crop_image_tensor(inpt, size=size) + + _log_api_usage_once(five_crop) + + kernel = _get_kernel(five_crop, type(inpt)) + return kernel(inpt, size=size) def _parse_five_crop_size(size: List[int]) -> List[int]: @@ -2275,6 +2233,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size +@_register_five_ten_crop_kernel(five_crop, torch.Tensor) @_register_five_ten_crop_kernel(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] @@ -2294,7 +2253,7 @@ def five_crop_image_tensor( return tl, tr, bl, br, center -@torch.jit.unused +@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image) def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: @@ -2335,23 +2294,16 @@ def ten_crop( datapoints._InputTypeJIT, datapoints._InputTypeJIT, ]: - if not torch.jit.is_scripting(): - _log_api_usage_once(ten_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(ten_crop, type(inpt)) - return kernel(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, PIL.Image.Image): - return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + if torch.jit.is_scripting(): + return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip) + + _log_api_usage_once(ten_crop) + + kernel = _get_kernel(ten_crop, type(inpt)) + return kernel(inpt, size=size, vertical_flip=vertical_flip) +@_register_five_ten_crop_kernel(ten_crop, torch.Tensor) @_register_five_ten_crop_kernel(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False @@ -2379,7 +2331,7 @@ def ten_crop_image_tensor( return non_flipped + flipped -@torch.jit.unused +@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image) def ten_crop_image_pil( image: PIL.Image.Image, size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index a4bfe7df8e4..f434bfadb48 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -13,23 +13,16 @@ @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_dimensions) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_dimensions_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_dimensions, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_dimensions_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(get_dimensions) + + kernel = _get_kernel(get_dimensions, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_dimensions, torch.Tensor) @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False) def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) @@ -44,6 +37,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: get_dimensions_image_pil = _FP.get_dimensions +_register_kernel_internal(get_dimensions, PIL.Image.Image)(get_dimensions_image_pil) @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False) @@ -53,23 +47,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]: @_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_num_channels) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_num_channels_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_num_channels, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_num_channels_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(get_num_channels) + + kernel = _get_kernel(get_num_channels, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_num_channels, torch.Tensor) @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False) def get_num_channels_image_tensor(image: torch.Tensor) -> int: chw = image.shape[-3:] @@ -83,6 +70,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int: get_num_channels_image_pil = _FP.get_image_num_channels +_register_kernel_internal(get_num_channels, PIL.Image.Image)(get_num_channels_image_pil) @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False) @@ -96,23 +84,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_size) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_size_image_tensor(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_size, type(inpt)) - return kernel(inpt) - elif isinstance(inpt, PIL.Image.Image): - return get_size_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) + + _log_api_usage_once(get_size) + + kernel = _get_kernel(get_size, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_size, torch.Tensor) @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False) def get_size_image_tensor(image: torch.Tensor) -> List[int]: hw = list(image.shape[-2:]) @@ -123,7 +104,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -@torch.jit.unused +@_register_kernel_internal(get_size, PIL.Image.Image) def get_size_image_pil(image: PIL.Image.Image) -> List[int]: width, height = _FP.get_image_size(image) return [height, width] @@ -146,21 +127,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int] @_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_num_frames) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return get_num_frames_video(inpt) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(get_num_frames, type(inpt)) - return kernel(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(get_num_frames) + + kernel = _get_kernel(get_num_frames, type(inpt)) + return kernel(inpt) +@_register_kernel_internal(get_num_frames, torch.Tensor) @_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False) def get_num_frames_video(video: torch.Tensor) -> int: return video.shape[-4] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 90a3e44e9d3..e3a800ea79c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -11,13 +11,7 @@ from torchvision.utils import _log_api_usage_once -from ._utils import ( - _get_kernel, - _register_explicit_noop, - _register_kernel_internal, - _register_unsupported_type, - is_simple_tensor, -) +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) @@ -28,19 +22,16 @@ def normalize( std: List[float], inplace: bool = False, ) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(normalize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(normalize, type(inpt)) - return kernel(inpt, mean=mean, std=std, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." 
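
get_size and get_num_frames follow the same pattern; a quick sketch (again, shapes are just illustrative):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    video = datapoints.Video(torch.rand(4, 3, 32, 48))

    assert F.get_size(video) == [32, 48]   # [height, width]
    assert F.get_num_frames(video) == 4

    # plain video tensors go through the newly registered torch.Tensor kernel
    assert F.get_num_frames(torch.rand(4, 3, 32, 48)) == 4
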
- ) + + _log_api_usage_once(normalize) + + kernel = _get_kernel(normalize, type(inpt)) + return kernel(inpt, mean=mean, std=std, inplace=inplace) +@_register_kernel_internal(normalize, torch.Tensor) @_register_kernel_internal(normalize, datapoints.Image) def normalize_image_tensor( image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False @@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in def gaussian_blur( inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(gaussian_blur) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting(): return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(gaussian_blur, type(inpt)) - return kernel(inpt, kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, PIL.Image.Image): - return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + + _log_api_usage_once(gaussian_blur) + + kernel = _get_kernel(gaussian_blur, type(inpt)) + return kernel(inpt, kernel_size=kernel_size, sigma=sigma) def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: @@ -119,6 +102,7 @@ def _get_gaussian_kernel2d( return kernel2d +@_register_kernel_internal(gaussian_blur, torch.Tensor) @_register_kernel_internal(gaussian_blur, datapoints.Image) def gaussian_blur_image_tensor( image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None @@ -184,7 +168,7 @@ def gaussian_blur_image_tensor( return output -@torch.jit.unused +@_register_kernel_internal(gaussian_blur, PIL.Image.Image) def gaussian_blur_image_pil( image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> PIL.Image.Image: @@ -200,21 +184,17 @@ def gaussian_blur_video( return gaussian_blur_image_tensor(video, kernel_size, sigma) +@_register_unsupported_type(PIL.Image.Image) def to_dtype( inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False ) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(to_dtype) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return to_dtype_image_tensor(inpt, dtype, scale=scale) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(to_dtype, type(inpt)) - return kernel(inpt, dtype, scale=scale) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." 
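
normalize and gaussian_blur lose their type-dispatch branches in the same way; tensors and datapoints now go through the registry, and datapoint outputs keep their type thanks to the kernel wrapper. Roughly:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 8, 8))

    blurred = F.gaussian_blur(image, kernel_size=[3, 3])
    assert isinstance(blurred, datapoints.Image)

    normed = F.normalize(image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    assert normed.shape == image.shape
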
- ) + if torch.jit.is_scripting(): + return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale) + + _log_api_usage_once(to_dtype) + + kernel = _get_kernel(to_dtype, type(inpt)) + return kernel(inpt, dtype=dtype, scale=scale) def _num_value_bits(dtype: torch.dtype) -> int: @@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int: raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") +@_register_kernel_internal(to_dtype, torch.Tensor) @_register_kernel_internal(to_dtype, datapoints.Image) def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 52c745f9901..62d12cb4b4e 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -5,27 +5,23 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal @_register_explicit_noop( PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True ) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(uniform_temporal_subsample) + if torch.jit.is_scripting(): + return uniform_temporal_subsample_video(inpt, num_samples=num_samples) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return uniform_temporal_subsample_video(inpt, num_samples) - elif isinstance(inpt, datapoints.Datapoint): - kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) - return kernel(inpt, num_samples) - else: - raise TypeError( - f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." - ) + _log_api_usage_once(uniform_temporal_subsample) + kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) + return kernel(inpt, num_samples=num_samples) + +@_register_kernel_internal(uniform_temporal_subsample, torch.Tensor) @_register_kernel_internal(uniform_temporal_subsample, datapoints.Video) def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 63e029d6c77..0f0b72c08ad 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -2,6 +2,8 @@ import warnings from typing import Any, Callable, Dict, Type +import PIL.Image + import torch from torchvision import datapoints @@ -23,15 +25,17 @@ def wrapper(inpt, *args, **kwargs): return wrapper -def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True): +def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) - if datapoint_cls in registry: - raise TypeError( - f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
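
to_dtype now rejects PIL inputs explicitly (via _register_unsupported_type) instead of relying on the old else-branch, and uniform_temporal_subsample dispatches plain tensors through the newly registered torch.Tensor kernel. A sketch of the expected behaviour:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    u8 = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
    f32 = F.to_dtype(u8, dtype=torch.float32, scale=True)
    assert f32.dtype == torch.float32 and f32.max() <= 1.0

    video = datapoints.Video(torch.rand(8, 3, 4, 4))
    sub = F.uniform_temporal_subsample(video, num_samples=4)
    assert isinstance(sub, datapoints.Video) and sub.shape[-4] == 4
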
- ) + if input_type in registry: + raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") def decorator(kernel): - registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel + registry[input_type] = ( + _kernel_datapoint_wrapper(kernel) + if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper + else kernel + ) return kernel return decorator @@ -41,16 +45,22 @@ def register_kernel(dispatcher, datapoint_cls): return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) -def _get_kernel(dispatcher, datapoint_cls): +def _get_kernel(dispatcher, input_type): + if not issubclass(input_type, (torch.Tensor, PIL.Image.Image)): + raise TypeError( + f"Dispatcher '{dispatcher}' supports plain tensors, any TorchVision datapoint, or a PIL images, " + f"but got {input_type} instead." + ) + registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") - if datapoint_cls in registry: - return registry[datapoint_cls] + if input_type in registry: + return registry[input_type] - for registered_cls, kernel in registry.items(): - if issubclass(datapoint_cls, registered_cls): + for registered_type, kernel in registry.items(): + if issubclass(input_type, registered_type): return kernel return _noop @@ -99,13 +109,13 @@ def _noop(inpt, *args, __msg__=None, **kwargs): # TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that # to error later, this decorator can be removed, since the error will be raised by _get_kernel -def _register_unsupported_type(*datapoints_classes): +def _register_unsupported_type(*input_types): def kernel(inpt, *args, __dispatcher_name__, **kwargs): raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.") def decorator(dispatcher): - for cls in datapoints_classes: - register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)) + for input_type in input_types: + register_kernel(dispatcher, input_type)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)) return dispatcher return decorator @@ -113,13 +123,10 @@ def decorator(dispatcher): # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool -# TODO: decide if we want that -def _register_five_ten_crop_kernel(dispatcher, datapoint_cls): +def _register_five_ten_crop_kernel(dispatcher, input_type): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) - if datapoint_cls in registry: - raise TypeError( - f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
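
For downstream users the takeaway is that a kernel can be attached to a dispatcher for a custom datapoint class and the lookup above will find it. A sketch, assuming register_kernel is exposed as F.register_kernel (as the tests added later in this series exercise); the class and kernel names below are made up:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):
        pass

    @F.register_kernel(F.horizontal_flip, MyDatapoint)
    def horizontal_flip_my_datapoint(my_dp):
        # toy kernel: run the plain-tensor kernel and re-wrap by hand, since
        # register_kernel intentionally skips the automatic datapoint wrapper
        out = F.horizontal_flip_image_tensor(my_dp.as_subclass(torch.Tensor))
        return out.as_subclass(MyDatapoint)

    dp = torch.rand(3, 8, 8).as_subclass(MyDatapoint)
    assert type(F.horizontal_flip(dp)) is MyDatapoint
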
- ) + if input_type in registry: + raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") def wrap(kernel): @functools.wraps(kernel) @@ -131,7 +138,7 @@ def wrapper(inpt, *args, **kwargs): return wrapper def decorator(kernel): - registry[datapoint_cls] = wrap(kernel) + registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel return kernel return decorator From 1eaf82e7eaffd1d21bcd4f6b5b4a08b67452a119 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 22:23:41 +0200 Subject: [PATCH 02/10] refactor _get_kernel and add test --- test/test_transforms_v2_refactored.py | 79 ++++++++++++++++++- .../transforms/v2/functional/_utils.py | 58 ++++++++++---- 2 files changed, 120 insertions(+), 17 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 20b198796ca..4810622a4e6 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,7 +39,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY +from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal @pytest.fixture(autouse=True) @@ -2151,3 +2151,80 @@ def test_unsupported_types(self, dispatcher, make_input): with pytest.raises(TypeError, match=re.escape(str(type(input)))): dispatcher(input) + + +class TestGetKernel: + def make_and_register_kernel(self, dispatcher, input_type): + return _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(object()) + + @pytest.fixture + def dispatcher_and_kernels(self, mocker): + mocker.patch.dict(_KERNEL_REGISTRY, clear=True) + + dispatcher = object() + + kernels = { + cls: self.make_and_register_kernel(dispatcher, cls) + for cls in [ + torch.Tensor, + PIL.Image.Image, + datapoints.Image, + datapoints.BoundingBoxes, + datapoints.Mask, + datapoints.Video, + ] + } + + yield dispatcher, kernels + + def test_unsupported_types(self, dispatcher_and_kernels): + dispatcher, _ = dispatcher_and_kernels + + class MyTensor(torch.Tensor): + pass + + class MyPILImage(PIL.Image.Image): + pass + + for input_type in [str, int, object, MyTensor, MyPILImage]: + with pytest.raises(TypeError, match=re.escape(str(input_type))): + _get_kernel(dispatcher, input_type) + + def test_exact_match(self, dispatcher_and_kernels): + dispatcher, kernels = dispatcher_and_kernels + + for input_type, kernel in kernels.items(): + assert _get_kernel(dispatcher, input_type) is kernel + + def test_builtin_datapoint_subclass(self, dispatcher_and_kernels): + dispatcher, kernels = dispatcher_and_kernels + + class MyImage(datapoints.Image): + pass + + class MyBoundingBoxes(datapoints.BoundingBoxes): + pass + + class MyMask(datapoints.Mask): + pass + + class MyVideo(datapoints.Video): + pass + + assert _get_kernel(dispatcher, MyImage) is kernels[datapoints.Image] + assert _get_kernel(dispatcher, MyBoundingBoxes) is kernels[datapoints.BoundingBoxes] + assert _get_kernel(dispatcher, MyMask) is kernels[datapoints.Mask] + assert _get_kernel(dispatcher, MyVideo) is kernels[datapoints.Video] + + def test_datapoint_subclass(self, dispatcher_and_kernels): + dispatcher, _ = dispatcher_and_kernels + + class MyDatapoint(datapoints.Datapoint): + pass + + # Note that this will 
be an error in the future + assert _get_kernel(dispatcher, MyDatapoint) is _noop + + kernel = self.make_and_register_kernel(dispatcher, MyDatapoint) + + assert _get_kernel(dispatcher, MyDatapoint) is kernel diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 0f0b72c08ad..c1563c4df86 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -2,8 +2,6 @@ import warnings from typing import Any, Callable, Dict, Type -import PIL.Image - import torch from torchvision import datapoints @@ -28,7 +26,7 @@ def wrapper(inpt, *args, **kwargs): def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) if input_type in registry: - raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") + raise TypeError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.") def decorator(kernel): registry[input_type] = ( @@ -42,28 +40,52 @@ def decorator(kernel): def register_kernel(dispatcher, datapoint_cls): + if not ( + callable(dispatcher) + and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional") + ): + raise ValueError( + f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, " + f"but got {dispatcher}." + ) + elif not ( + isinstance(datapoint_cls, type) + and issubclass(datapoint_cls, datapoints.Datapoint) + and datapoint_cls is not datapoints.Datapoint + ): + raise ValueError( + f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, " + f"but got {datapoint_cls}." + ) return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) def _get_kernel(dispatcher, input_type): - if not issubclass(input_type, (torch.Tensor, PIL.Image.Image)): - raise TypeError( - f"Dispatcher '{dispatcher}' supports plain tensors, any TorchVision datapoint, or a PIL images, " - f"but got {input_type} instead." - ) - registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: - raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") + raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.") + # in case we have an exact type match, we take a shortcut if input_type in registry: return registry[input_type] - for registered_type, kernel in registry.items(): - if issubclass(input_type, registered_type): - return kernel + # in case of datapoints, we check if we have a kernel for a superclass registered + if issubclass(input_type, datapoints.Datapoint): + for cls in input_type.__mro__: + if cls is datapoints.Datapoint: + break + elif cls in registry: + return registry[cls] + + # Note that in the future we are not going to return a noop here, but rather raise the + # error below + return _noop - return _noop + raise TypeError( + f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, " + f"and subclasses of torchvision.datapoints.Datapoint, " + f"but got {input_type} instead." + ) # Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate @@ -95,7 +117,9 @@ def decorator(dispatcher): f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " f"This will likely change in the future." 
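
In user-facing terms, the MRO walk means a subclass of a built-in datapoint keeps working out of the box, while a direct Datapoint subclass with no registered kernel is, for now, silently passed through (to become an error later, as noted above). A sketch of both behaviours, with made-up class names:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyImage(datapoints.Image):
        pass

    class MyDatapoint(datapoints.Datapoint):
        pass

    img = MyImage(torch.rand(3, 10, 10))
    assert F.get_size(img) == [10, 10]      # resolved via the datapoints.Image kernel

    dp = torch.rand(3, 10, 10).as_subclass(MyDatapoint)
    assert F.get_size(dp) is dp             # no kernel registered -> current no-op passthrough
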
) - register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None)) + _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)( + functools.partial(_noop, __msg__=msg if warn_passthrough else None) + ) return dispatcher return decorator @@ -115,7 +139,9 @@ def kernel(inpt, *args, __dispatcher_name__, **kwargs): def decorator(dispatcher): for input_type in input_types: - register_kernel(dispatcher, input_type)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)) + _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)( + functools.partial(kernel, __dispatcher_name__=dispatcher.__name__) + ) return dispatcher return decorator From 390021fbcc8cce3443843ab1c91ec03674bff587 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 22:41:56 +0200 Subject: [PATCH 03/10] add test for register_kernel --- test/test_transforms_v2_refactored.py | 24 ++++++++++++++++++- .../transforms/v2/functional/_utils.py | 4 ++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 4810622a4e6..c9783366fa5 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,7 +39,13 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal +from torchvision.transforms.v2.functional._utils import ( + _get_kernel, + _KERNEL_REGISTRY, + _noop, + _register_kernel_internal, + register_kernel, +) @pytest.fixture(autouse=True) @@ -2228,3 +2234,19 @@ class MyDatapoint(datapoints.Datapoint): kernel = self.make_and_register_kernel(dispatcher, MyDatapoint) assert _get_kernel(dispatcher, MyDatapoint) is kernel + + +class TestRegisterKernel: + def test_errors(self, mocker): + mocker.patch.dict(_KERNEL_REGISTRY, clear=True) + + with pytest.raises(TypeError, match="Kernels can only be registered on dispatchers"): + register_kernel(datapoints.Image, F.resize) + + with pytest.raises(TypeError, match="Kernels can only be registered for subclasses"): + register_kernel(F.resize, object) + + register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + + with pytest.raises(TypeError, match="already has a kernel registered for type"): + register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index c1563c4df86..d0df229ad67 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -44,7 +44,7 @@ def register_kernel(dispatcher, datapoint_cls): callable(dispatcher) and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional") ): - raise ValueError( + raise TypeError( f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, " f"but got {dispatcher}." ) @@ -53,7 +53,7 @@ def register_kernel(dispatcher, datapoint_cls): and issubclass(datapoint_cls, datapoints.Datapoint) and datapoint_cls is not datapoints.Datapoint ): - raise ValueError( + raise TypeError( f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, " f"but got {datapoint_cls}." 
) From 8c0a8ea1db69d70973fc8cef15b12e3715d9dffb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 22:55:18 +0200 Subject: [PATCH 04/10] fix antialias warning for resized_crop for PIL --- torchvision/transforms/v2/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 8afe37384a3..5787decfb13 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2127,7 +2127,7 @@ def resized_crop_image_pil_dispatch( antialias: Optional[Union[str, bool]] = "warn", ) -> PIL.Image.Image: if antialias is False: - pass + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") return resized_crop_image_pil( image, top=top, From 55997293438a97199a47bfe2e461462cf1e25136 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 22:59:25 +0200 Subject: [PATCH 05/10] improve kernel registration for PIL alias kernels --- .../transforms/v2/functional/_color.py | 30 +++++++------------ .../transforms/v2/functional/_geometry.py | 3 +- torchvision/transforms/v2/functional/_meta.py | 6 ++-- 3 files changed, 13 insertions(+), 26 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 4d84e6d23a5..71797fd2500 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -138,8 +138,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float return _blend(image, grayscale_image, saturation_factor) -adjust_saturation_image_pil = _FP.adjust_saturation -_register_kernel_internal(adjust_saturation, PIL.Image.Image)(adjust_saturation_image_pil) +adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) @_register_kernel_internal(adjust_saturation, datapoints.Video) @@ -178,8 +177,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> return _blend(image, mean, contrast_factor) -adjust_contrast_image_pil = _FP.adjust_contrast -_register_kernel_internal(adjust_contrast, PIL.Image.Image)(adjust_contrast_image_pil) +adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) @_register_kernel_internal(adjust_contrast, datapoints.Video) @@ -252,8 +250,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) return output -adjust_sharpness_image_pil = _FP.adjust_sharpness -_register_kernel_internal(adjust_sharpness, PIL.Image.Image)(adjust_sharpness_image_pil) +adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) @_register_kernel_internal(adjust_sharpness, datapoints.Video) @@ -365,8 +362,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) -adjust_hue_image_pil = _FP.adjust_hue -_register_kernel_internal(adjust_hue, PIL.Image.Image)(adjust_hue_image_pil) +adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) @_register_kernel_internal(adjust_hue, datapoints.Video) @@ -406,8 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 return to_dtype_image_tensor(output, image.dtype, scale=True) -adjust_gamma_image_pil = 
_FP.adjust_gamma -_register_kernel_internal(adjust_gamma, PIL.Image.Image)(adjust_gamma_image_pil) +adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) @_register_kernel_internal(adjust_gamma, datapoints.Video) @@ -441,8 +436,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: return image & mask -posterize_image_pil = _FP.posterize -_register_kernel_internal(posterize, PIL.Image.Image)(posterize_image_pil) +posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) @_register_kernel_internal(posterize, datapoints.Video) @@ -470,8 +464,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor return torch.where(image >= threshold, invert_image_tensor(image), image) -solarize_image_pil = _FP.solarize -_register_kernel_internal(solarize, PIL.Image.Image)(solarize_image_pil) +solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) @_register_kernel_internal(solarize, datapoints.Video) @@ -521,8 +514,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) -autocontrast_image_pil = _FP.autocontrast -_register_kernel_internal(autocontrast, PIL.Image.Image)(autocontrast_image_pil) +autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) @_register_kernel_internal(autocontrast, datapoints.Video) @@ -612,8 +604,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: return to_dtype_image_tensor(output, output_dtype, scale=True) -equalize_image_pil = _FP.equalize -_register_kernel_internal(equalize, PIL.Image.Image)(equalize_image_pil) +equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) @_register_kernel_internal(equalize, datapoints.Video) @@ -644,8 +635,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) -invert_image_pil = _FP.invert -_register_kernel_internal(invert, PIL.Image.Image)(invert_image_pil) +invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) @_register_kernel_internal(invert, datapoints.Video) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 5787decfb13..bb19def2c93 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1255,8 +1255,7 @@ def _pad_with_vector_fill( return output -pad_image_pil = _FP.pad -_register_kernel_internal(pad, PIL.Image.Image)(pad_image_pil) +pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) @_register_kernel_internal(pad, datapoints.Mask) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index f434bfadb48..fc1aa05f319 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -36,8 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -get_dimensions_image_pil = _FP.get_dimensions -_register_kernel_internal(get_dimensions, PIL.Image.Image)(get_dimensions_image_pil) +get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions) @_register_kernel_internal(get_dimensions, datapoints.Video, 
datapoint_wrapper=False) @@ -69,8 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -get_num_channels_image_pil = _FP.get_image_num_channels -_register_kernel_internal(get_num_channels, PIL.Image.Image)(get_num_channels_image_pil) +get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels) @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False) From f61f5957a20b4f77595feaee9fe0ed36635d34c7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Aug 2023 23:13:24 +0200 Subject: [PATCH 06/10] cleanup --- test/test_transforms_v2_refactored.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 33bd9b268f8..4748c54ff74 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,13 +39,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import ( - _get_kernel, - _KERNEL_REGISTRY, - _noop, - _register_kernel_internal, - register_kernel, -) +from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal @pytest.fixture(autouse=True) @@ -2193,15 +2187,15 @@ def test_errors(self, mocker): F.register_kernel("bad_name", datapoints.Image) with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"): - register_kernel(datapoints.Image, F.resize) + F.register_kernel(datapoints.Image, F.resize) with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): - register_kernel(F.resize, object) + F.register_kernel(F.resize, object) - register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) with pytest.raises(ValueError, match="already has a kernel registered for type"): - register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) class TestGetKernel: From 7f67e521091ca3d3b357c5147f2690144e7bdf71 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Aug 2023 15:51:50 +0200 Subject: [PATCH 07/10] (hopefully) simplify tests --- test/test_transforms_v2_refactored.py | 107 ++++++++++-------- .../transforms/v2/functional/_utils.py | 2 +- 2 files changed, 58 insertions(+), 51 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 4748c54ff74..c910882f9fd 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2155,9 +2155,7 @@ def test_unsupported_types(self, dispatcher, make_input): class TestRegisterKernel: @pytest.mark.parametrize("dispatcher", (F.resize, "resize")) - def test_register_kernel(self, mocker, dispatcher): - mocker.patch.dict(_KERNEL_REGISTRY, values={F.resize: _KERNEL_REGISTRY[F.resize]}, clear=True) - + def test_register_kernel(self, dispatcher): class CustomDatapoint(datapoints.Datapoint): pass @@ -2180,9 +2178,7 @@ def new_resize(dp, *args, **kwargs): t(torch.rand(3, 10, 10)).shape == (3, 224, 224) t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224) - def 
test_errors(self, mocker): - mocker.patch.dict(_KERNEL_REGISTRY, clear=True) - + def test_errors(self): with pytest.raises(ValueError, match="Could not find dispatcher with name"): F.register_kernel("bad_name", datapoints.Image) @@ -2192,39 +2188,23 @@ def test_errors(self, mocker): with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): F.register_kernel(F.resize, object) - F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) - with pytest.raises(ValueError, match="already has a kernel registered for type"): F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) class TestGetKernel: - def make_and_register_kernel(self, dispatcher, input_type): - return _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(object()) - - @pytest.fixture - def dispatcher_and_kernels(self, mocker): - mocker.patch.dict(_KERNEL_REGISTRY, clear=True) - - dispatcher = object() - - kernels = { - cls: self.make_and_register_kernel(dispatcher, cls) - for cls in [ - torch.Tensor, - PIL.Image.Image, - datapoints.Image, - datapoints.BoundingBoxes, - datapoints.Mask, - datapoints.Video, - ] - } - - yield dispatcher, kernels - - def test_unsupported_types(self, dispatcher_and_kernels): - dispatcher, _ = dispatcher_and_kernels + # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination + # would also be fine + KERNELS = { + torch.Tensor: F.resize_image_tensor, + PIL.Image.Image: F.resize_image_pil, + datapoints.Image: F.resize_image_tensor, + datapoints.BoundingBoxes: F.resize_bounding_boxes, + datapoints.Mask: F.resize_mask, + datapoints.Video: F.resize_video, + } + def test_unsupported_types(self): class MyTensor(torch.Tensor): pass @@ -2232,17 +2212,34 @@ class MyPILImage(PIL.Image.Image): pass for input_type in [str, int, object, MyTensor, MyPILImage]: - with pytest.raises(TypeError, match=re.escape(str(input_type))): - _get_kernel(dispatcher, input_type) + with pytest.raises( + TypeError, + match=( + "supports inputs of type torch.Tensor, PIL.Image.Image, " + "and subclasses of torchvision.datapoints.Datapoint" + ), + ): + _get_kernel(F.resize, input_type) + + def test_exact_match(self): + # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the + # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher + # here, register the kernels without wrapper, and check the exact matching afterwards. + def resize_with_pure_kernels(): + pass - def test_exact_match(self, dispatcher_and_kernels): - dispatcher, kernels = dispatcher_and_kernels + for input_type, kernel in self.KERNELS.items(): + _register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel) - for input_type, kernel in kernels.items(): - assert _get_kernel(dispatcher, input_type) is kernel + assert _get_kernel(resize_with_pure_kernels, input_type) is kernel - def test_builtin_datapoint_subclass(self, dispatcher_and_kernels): - dispatcher, kernels = dispatcher_and_kernels + def test_builtin_datapoint_subclass(self): + # We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the + # ideal wrapping. Practically, we have an intermediate wrapper layer. 
Thus, we create a new resize dispatcher + # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched + # to the kernel of the corresponding superclass + def resize_with_pure_kernels(): + pass class MyImage(datapoints.Image): pass @@ -2256,20 +2253,30 @@ class MyMask(datapoints.Mask): class MyVideo(datapoints.Video): pass - assert _get_kernel(dispatcher, MyImage) is kernels[datapoints.Image] - assert _get_kernel(dispatcher, MyBoundingBoxes) is kernels[datapoints.BoundingBoxes] - assert _get_kernel(dispatcher, MyMask) is kernels[datapoints.Mask] - assert _get_kernel(dispatcher, MyVideo) is kernels[datapoints.Video] + for custom_datapoint_subclass in [ + MyImage, + MyBoundingBoxes, + MyMask, + MyVideo, + ]: + builtin_datapoint_class = custom_datapoint_subclass.__mro__[1] + builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class] + _register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)( + builtin_datapoint_kernel + ) - def test_datapoint_subclass(self, dispatcher_and_kernels): - dispatcher, _ = dispatcher_and_kernels + assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel + def test_datapoint_subclass(self): class MyDatapoint(datapoints.Datapoint): pass # Note that this will be an error in the future - assert _get_kernel(dispatcher, MyDatapoint) is _noop + assert _get_kernel(F.resize, MyDatapoint) is _noop + + def resize_my_datapoint(): + pass - kernel = self.make_and_register_kernel(dispatcher, MyDatapoint) + _register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint) - assert _get_kernel(dispatcher, MyDatapoint) is kernel + assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 5f591fe11f6..0a24b9f1f93 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -86,7 +86,7 @@ def _get_kernel(dispatcher, input_type): # in case of datapoints, we check if we have a kernel for a superclass registered if issubclass(input_type, datapoints.Datapoint): - for cls in input_type.__mro__: + for cls in input_type.__mro__[1:]: if cls is datapoints.Datapoint: break elif cls in registry: From dca577ea93658e6d04dbfe37eb70b674e76d62aa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Aug 2023 15:54:43 +0200 Subject: [PATCH 08/10] Update torchvision/transforms/v2/functional/_utils.py Co-authored-by: Nicolas Hug --- torchvision/transforms/v2/functional/_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 5f591fe11f6..05ce2b13b9b 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -88,6 +88,9 @@ def _get_kernel(dispatcher, input_type): if issubclass(input_type, datapoints.Datapoint): for cls in input_type.__mro__: if cls is datapoints.Datapoint: + # we don't want user-defined datapoints to dispatch to the pure Tensor kernels, + # so we expliclty stop the mro traversal before hitting Tensor. We can even stop at Datapoint + # since we don't allow kernels to be registered for Datapoints anyway. 
break elif cls in registry: return registry[cls] From 8557cdfaadb2bfed9d4a69a3b370974b9fe80f58 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Aug 2023 15:56:06 +0200 Subject: [PATCH 09/10] cleanup --- torchvision/transforms/v2/functional/_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index d8d5383553a..b9dc86e8205 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -80,17 +80,18 @@ def _get_kernel(dispatcher, input_type): if not registry: raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.") - # in case we have an exact type match, we take a shortcut + # In case we have an exact type match, we take a shortcut. if input_type in registry: return registry[input_type] - # in case of datapoints, we check if we have a kernel for a superclass registered + # In case of datapoints, we check if we have a kernel for a superclass registered if issubclass(input_type, datapoints.Datapoint): + # Since we have already checked for an exact match above, we can start the traversal at the superclass. for cls in input_type.__mro__[1:]: if cls is datapoints.Datapoint: - # we don't want user-defined datapoints to dispatch to the pure Tensor kernels, - # so we expliclty stop the mro traversal before hitting Tensor. We can even stop at Datapoint - # since we don't allow kernels to be registered for Datapoints anyway. + # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the + # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't + # allow kernels to be registered for datapoints.Datapoint anyway. break elif cls in registry: return registry[cls] From 88904dbbf4ee3ae286ea12a3e6200475af2446ce Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Aug 2023 16:00:27 +0200 Subject: [PATCH 10/10] more --- torchvision/transforms/v2/functional/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index b9dc86e8205..8371b9b23a8 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -96,8 +96,7 @@ def _get_kernel(dispatcher, input_type): elif cls in registry: return registry[cls] - # Note that in the future we are not going to return a noop here, but rather raise the - # error below + # Note that in the future we are not going to return a noop here, but rather raise the error below return _noop raise TypeError(