From d9e13795ce6828e49ed8d35221664b9421194182 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 19 Jul 2023 13:24:06 +0200 Subject: [PATCH 01/22] [PoC] refactor Datapoint dispatch mechanism --- test/test_transforms_v2_refactored.py | 27 +++++--- torchvision/datapoints/_bounding_box.py | 15 ---- torchvision/datapoints/_datapoint.py | 11 --- torchvision/datapoints/_image.py | 12 ---- torchvision/datapoints/_mask.py | 10 --- torchvision/datapoints/_video.py | 16 ----- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_geometry.py | 68 ++++++++++++------- .../transforms/v2/functional/_utils.py | 23 +++++++ 9 files changed, 83 insertions(+), 101 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 69180b99dbc..fc96175cc71 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -165,17 +165,20 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): preserved in doing so. For bounding boxes also checks that the format is preserved. """ if isinstance(input, datapoints._datapoint.Datapoint): - # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly, - # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel. - spy = mock.MagicMock(wraps=kernel, name=kernel.__name__) - with mock.patch.object(F, kernel.__name__, spy): - # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class. - # Since that is not the case here, we need to prefix f"_{cls.__name__}" - # See https://docs.python.org/3/tutorial/classes.html#private-variables for details - with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F): - output = dispatcher(input, *args, **kwargs) - - spy.assert_called_once() + if dispatcher is F.resize: + output = dispatcher(input, *args, **kwargs) + else: + # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly, + # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel. + spy = mock.MagicMock(wraps=kernel, name=kernel.__name__) + with mock.patch.object(F, kernel.__name__, spy): + # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class. 
+ # Since that is not the case here, we need to prefix f"_{cls.__name__}" + # See https://docs.python.org/3/tutorial/classes.html#private-variables for details + with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F): + output = dispatcher(input, *args, **kwargs) + + spy.assert_called_once() else: with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy: output = dispatcher(input, *args, **kwargs) @@ -249,6 +252,8 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): def _check_dispatcher_datapoint_signature_match(dispatcher): """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class.""" + if dispatcher is F.resize: + return dispatcher_signature = inspect.signature(dispatcher) dispatcher_params = list(dispatcher_signature.parameters.values())[1:] diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 11d42f171e4..ca7b8661a85 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -110,21 +110,6 @@ def vertical_flip(self) -> BoundingBox: ) return BoundingBox.wrap_like(self, output) - def resize( # type: ignore[override] - self, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> BoundingBox: - output, spatial_size = self._F.resize_bounding_box( - self.as_subclass(torch.Tensor), - spatial_size=self.spatial_size, - size=size, - max_size=max_size, - ) - return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) - def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: output, spatial_size = self._F.crop_bounding_box( self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 0dabec58f25..998b606888c 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -148,17 +148,6 @@ def horizontal_flip(self) -> Datapoint: def vertical_flip(self) -> Datapoint: return self - # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize - # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593 - def resize( # type: ignore[override] - self, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Datapoint: - return self - def crop(self, top: int, left: int, height: int, width: int) -> Datapoint: return self diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index e47a6c10fc3..ea6435a9811 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -72,18 +72,6 @@ def vertical_flip(self) -> Image: output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor)) return Image.wrap_like(self, output) - def resize( # type: ignore[override] - self, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Image: - output = self._F.resize_image_tensor( - self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, 
antialias=antialias - ) - return Image.wrap_like(self, output) - def crop(self, top: int, left: int, height: int, width: int) -> Image: output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width) return Image.wrap_like(self, output) diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 0135d793d32..3d88a34764b 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -63,16 +63,6 @@ def vertical_flip(self) -> Mask: output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor)) return Mask.wrap_like(self, output) - def resize( # type: ignore[override] - self, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Mask: - output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size) - return Mask.wrap_like(self, output) - def crop(self, top: int, left: int, height: int, width: int) -> Mask: output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width) return Mask.wrap_like(self, output) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index a6fbe2bd473..592d85938cd 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -66,22 +66,6 @@ def vertical_flip(self) -> Video: output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor)) return Video.wrap_like(self, output) - def resize( # type: ignore[override] - self, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Video: - output = self._F.resize_video( - self.as_subclass(torch.Tensor), - size, - interpolation=interpolation, - max_size=max_size, - antialias=antialias, - ) - return Video.wrap_like(self, output) - def crop(self, top: int, left: int, height: int, width: int) -> Video: output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width) return Video.wrap_like(self, output) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index b4803f4f1b9..27acdda0a2c 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_simple_tensor # usort: skip +from ._utils import is_simple_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_box, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index e1dd2866bc5..6c0250be9ec 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,7 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil -from ._utils import is_simple_tensor +from ._utils import _get_kernel, is_simple_tensor, register_kernel def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -158,6 +158,32 @@ def _compute_resized_output_size( return __compute_resized_output_size(spatial_size, size=size, max_size=max_size) +def resize( + inpt: datapoints._InputTypeJIT, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, 
+ antialias: Optional[Union[str, bool]] = "warn", +) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(resize) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(resize, type(inpt)) + return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, PIL.Image.Image): + if antialias is False: + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") + return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@register_kernel(resize, datapoints.Image) def resize_image_tensor( image: torch.Tensor, size: List[int], @@ -274,6 +300,11 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N return output +@register_kernel(resize, datapoints.Mask) +def _resize_mask_dispatch(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None, **kwargs): + return resize_mask(mask, size, max_size=max_size) + + def resize_bounding_box( bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -292,6 +323,17 @@ def resize_bounding_box( ) +@register_kernel(resize, datapoints.BoundingBox, datapoint_wrapping=False) +def _resize_bounding_box_dispatch( + bounding_box: datapoints.BoundingBox, size: List[int], max_size: Optional[int] = None, **kwargs +): + output, spatial_size = resize_bounding_box( + bounding_box.as_subclass(torch.Tensor), bounding_box.spatial_size, size, max_size=max_size + ) + return datapoints.BoundingBox.wrap_like(bounding_box, output, spatial_size=spatial_size) + + +@register_kernel(resize, datapoints.Video) def resize_video( video: torch.Tensor, size: List[int], @@ -302,30 +344,6 @@ def resize_video( return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) -def resize( - inpt: datapoints._InputTypeJIT, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[Union[str, bool]] = "warn", -) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, PIL.Image.Image): - if antialias is False: - warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") - return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) - - def _affine_parse_args( angle: Union[int, float], translate: List[float], diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index f31ccb939a5..2164fe9ead1 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,3 +1,4 @@ +import functools from typing import Any import torch @@ -6,3 +7,25 @@ def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) + + +_KERNEL_REGISTRY = {} + + +def register_kernel(dispatcher, datapoint_cls, *, datapoint_wrapping=True): + def datapoint_wrapper(kernel): + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + return type(inpt).wrap_like(inpt, kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)) + + return wrapper + + def decorator(kernel): + _KERNEL_REGISTRY[(dispatcher, datapoint_cls)] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel + return kernel + + return decorator + + +def _get_kernel(dispatcher, datapoint_cls): + return _KERNEL_REGISTRY[(dispatcher, datapoint_cls)] From 36b9d36cae5179d9961dc380b2517c3c848a7683 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 19 Jul 2023 14:22:07 +0200 Subject: [PATCH 02/22] fix test --- test/test_transforms_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 3743581794f..e2f8aff5815 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1419,8 +1419,6 @@ def test_antialias_warning(): with pytest.warns(UserWarning, match=match): datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20)) - with pytest.warns(UserWarning, match=match): - datapoints.Video(tensor_video).resize((20, 20)) with pytest.warns(UserWarning, match=match): datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20)) From bbaa35c915fb385978ffbbc6d4b82e9d3d0bf925 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 08:51:02 +0200 Subject: [PATCH 03/22] add dispatch to adjust_brightness --- test/test_transforms_v2_refactored.py | 4 +- .../transforms/v2/functional/_color.py | 39 ++++++++++--------- .../transforms/v2/functional/_utils.py | 15 ++++++- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 87a85f42f45..04ace4c9d3f 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -167,7 +167,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): preserved in doing so. For bounding boxes also checks that the format is preserved. 
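
For orientation, a minimal sketch of how the kernel registry introduced in patch 01 is meant to be used from the outside. MyDatapoint, my_resize_kernel and my_datapoint are hypothetical names, and the sketch assumes register_kernel and resize behave exactly as defined in _utils.py and _geometry.py above (at this point register_kernel wraps the kernel by default via datapoint_wrapping=True):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints._datapoint.Datapoint):
        # hypothetical user-defined datapoint; wrap_like is needed because the
        # wrapping path calls type(inpt).wrap_like(inpt, output)
        @classmethod
        def wrap_like(cls, other, tensor):
            return tensor.as_subclass(cls)

    @F.register_kernel(F.resize, MyDatapoint)
    def my_resize_kernel(image, size, interpolation=None, max_size=None, antialias="warn"):
        # receives a plain torch.Tensor; the registry wrapper re-wraps the result into MyDatapoint
        ...

    # F.resize(my_datapoint, size=[16, 16]) now routes through _get_kernel to my_resize_kernel
    # instead of falling back to the Datapoint.resize method that this series removes.
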
""" if isinstance(input, datapoints._datapoint.Datapoint): - if dispatcher is F.resize: + if dispatcher in {F.resize, F.adjust_brightness}: output = dispatcher(input, *args, **kwargs) else: # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly, @@ -254,7 +254,7 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): def _check_dispatcher_datapoint_signature_match(dispatcher): """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class.""" - if dispatcher is F.resize: + if dispatcher in {F.resize, F.adjust_brightness}: return dispatcher_signature = inspect.signature(dispatcher) dispatcher_params = list(dispatcher_signature.parameters.values())[1:] diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 13417e4a990..154c7b63141 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, convert_dtype_image_tensor -from ._utils import is_simple_tensor +from ._utils import _get_kernel, is_simple_tensor, register_kernel def _rgb_to_grayscale_image_tensor( @@ -69,6 +69,25 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te return output if fp else output.to(image1.dtype) +def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_brightness) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + kernel = _get_kernel(adjust_brightness, type(inpt)) + return kernel(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.adjust_brightness(brightness_factor=brightness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +@register_kernel(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -86,27 +105,11 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float adjust_brightness_image_pil = _FP.adjust_brightness +@register_kernel(adjust_brightness, datapoints.Video) def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) -def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_brightness) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_brightness(brightness_factor=brightness_factor) - elif isinstance(inpt, PIL.Image.Image): - return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) - - def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 2164fe9ead1..326826b56d6 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -20,12 +20,23 @@ def wrapper(inpt, *args, **kwargs): return wrapper + registry = _KERNEL_REGISTRY.get(dispatcher, {}) + def decorator(kernel): - _KERNEL_REGISTRY[(dispatcher, datapoint_cls)] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel + registry[datapoint_cls] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel return kernel return decorator +def _noop(inpt, *args, **kwargs): + return inpt + + def _get_kernel(dispatcher, datapoint_cls): - return _KERNEL_REGISTRY[(dispatcher, datapoint_cls)] + registry = _KERNEL_REGISTRY.get(dispatcher, {}) + + if datapoint_cls in registry: + return registry[datapoint_cls] + + return _noop From ca4ad32e9aa74cb9207a8a55dde727b6681717a0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 09:06:35 +0200 Subject: [PATCH 04/22] enforce no register overwrite --- torchvision/transforms/v2/functional/_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 326826b56d6..4474600ccff 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -20,7 +20,11 @@ def wrapper(inpt, *args, **kwargs): return wrapper - registry = _KERNEL_REGISTRY.get(dispatcher, {}) + registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) + if datapoint_cls in registry: + raise TypeError( + f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
+ ) def decorator(kernel): registry[datapoint_cls] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel @@ -34,9 +38,15 @@ def _noop(inpt, *args, **kwargs): def _get_kernel(dispatcher, datapoint_cls): - registry = _KERNEL_REGISTRY.get(dispatcher, {}) + registry = _KERNEL_REGISTRY.get(dispatcher) + if not registry: + raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") if datapoint_cls in registry: return registry[datapoint_cls] + for registered_cls, kernel in registry.items(): + if issubclass(datapoint_cls, registered_cls): + return kernel + return _noop From d23a80ea900879a701f55f66b5ed867b76f77de9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 10:39:56 +0200 Subject: [PATCH 05/22] [PoC] make wrapping interal kernel more convenient --- .../transforms/v2/functional/_color.py | 6 +-- .../transforms/v2/functional/_geometry.py | 23 ++------ .../transforms/v2/functional/_utils.py | 54 ++++++++++++++++--- 3 files changed, 55 insertions(+), 28 deletions(-) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 154c7b63141..fd441e7b9ff 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, convert_dtype_image_tensor -from ._utils import _get_kernel, is_simple_tensor, register_kernel +from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor def _rgb_to_grayscale_image_tensor( @@ -87,7 +87,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ) -@register_kernel(adjust_brightness, datapoints.Image) +@_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -105,7 +105,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float adjust_brightness_image_pil = _FP.adjust_brightness -@register_kernel(adjust_brightness, datapoints.Video) +@_register_kernel_internal(adjust_brightness, datapoints.Video) def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 6c0250be9ec..b7ac48c40ec 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,7 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil -from ._utils import _get_kernel, is_simple_tensor, register_kernel +from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -183,7 +183,7 @@ def resize( ) -@register_kernel(resize, datapoints.Image) +@_register_kernel_internal(resize, datapoints.Image) def resize_image_tensor( image: torch.Tensor, size: List[int], @@ -285,6 +285,7 @@ def resize_image_pil( return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) +@_register_kernel_internal(resize, datapoints.Mask) def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> 
torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -300,11 +301,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N return output -@register_kernel(resize, datapoints.Mask) -def _resize_mask_dispatch(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None, **kwargs): - return resize_mask(mask, size, max_size=max_size) - - +@_register_kernel_internal(resize, datapoints.BoundingBox) def resize_bounding_box( bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -323,17 +320,7 @@ def resize_bounding_box( ) -@register_kernel(resize, datapoints.BoundingBox, datapoint_wrapping=False) -def _resize_bounding_box_dispatch( - bounding_box: datapoints.BoundingBox, size: List[int], max_size: Optional[int] = None, **kwargs -): - output, spatial_size = resize_bounding_box( - bounding_box.as_subclass(torch.Tensor), bounding_box.spatial_size, size, max_size=max_size - ) - return datapoints.BoundingBox.wrap_like(bounding_box, output, spatial_size=spatial_size) - - -@register_kernel(resize, datapoints.Video) +@_register_kernel_internal(resize, datapoints.Video) def resize_video( video: torch.Tensor, size: List[int], diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 4474600ccff..a1cd3b7891c 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,7 +1,9 @@ import functools +import inspect from typing import Any import torch +from torchvision import datapoints from torchvision.datapoints._datapoint import Datapoint @@ -12,14 +14,48 @@ def is_simple_tensor(inpt: Any) -> bool: _KERNEL_REGISTRY = {} -def register_kernel(dispatcher, datapoint_cls, *, datapoint_wrapping=True): - def datapoint_wrapper(kernel): - @functools.wraps(kernel) - def wrapper(inpt, *args, **kwargs): - return type(inpt).wrap_like(inpt, kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)) +def _kernel_wrapper_internal(dispatcher, kernel): + dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:] + kernel_params = list(inspect.signature(kernel).parameters)[1:] - return wrapper + needs_args_kwargs_handling = kernel_params != dispatcher_params + # this avoids converting list -> set at runtime below + kernel_params = set(kernel_params) + + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + input_type = type(inpt) + + if needs_args_kwargs_handling: + # Convert args to kwargs to simplify further processing + kwargs.update(dict(zip(dispatcher_params, args))) + args = () + + # drop parameters that are not relevant for the kernel, but have a default value + # in the dispatcher + for kwarg in kwargs.keys() - kernel_params: + del kwargs[kwarg] + + # add parameters that are passed implicitly to the dispatcher as metadata, + # but have to be explicit for the kernel + for kwarg in input_type.__annotations__.keys() & kernel_params: + kwargs[kwarg] = getattr(inpt, kwarg) + + output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) + + if isinstance(inpt, datapoints.BoundingBox) and isinstance(output, tuple): + output, spatial_size = output + metadata = dict(spatial_size=spatial_size) + else: + metadata = dict() + + return input_type.wrap_like(inpt, output, **metadata) + + return wrapper + + +def _register_kernel_internal(dispatcher, datapoint_cls, *, wrap_kernel=True): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) if datapoint_cls in 
registry: raise TypeError( @@ -27,12 +63,16 @@ def wrapper(inpt, *args, **kwargs): ) def decorator(kernel): - registry[datapoint_cls] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel + registry[datapoint_cls] = _kernel_wrapper_internal(dispatcher, kernel) if wrap_kernel else kernel return kernel return decorator +def register_kernel(dispatcher, datapoint_cls): + return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False) + + def _noop(inpt, *args, **kwargs): return inpt From bf471887130033cf0e14254aa6e0fa07541b264d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 13:44:09 +0200 Subject: [PATCH 06/22] [PoC] enforce explicit no-ops --- test/test_transforms_v2_refactored.py | 28 +++++++++++++++++++ .../transforms/v2/functional/_color.py | 5 +++- .../transforms/v2/functional/_utils.py | 18 ++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 04ace4c9d3f..66231010a36 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -34,6 +34,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY @pytest.fixture(autouse=True) @@ -428,6 +429,33 @@ def transform(bbox): return torch.stack([transform(b) for b in bounding_box.reshape(-1, 4).unbind()]).reshape(bounding_box.shape) +@pytest.mark.parametrize( + ("dispatcher", "registered_datapoint_clss"), + [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], +) +def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): + missing = { + datapoints.Image, + datapoints.BoundingBox, + datapoints.Mask, + datapoints.Video, + } - registered_datapoint_clss + if missing: + names = sorted(f"datapoints.{cls.__name__}" for cls in missing) + raise AssertionError( + "\n".join( + [ + f"The dispatcher '{dispatcher.__name__}' hs no kernels registered for", + "", + *[f"- {name}" for name in names], + "", + f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).", + f"If not, register explicit no-ops with _register_explicit_noops({dispatcher.__name__}, {', '.join(names)})", + ] + ) + ) + + class TestResize: INPUT_SIZE = (17, 11) OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)] diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index fd441e7b9ff..c39b108cd20 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, convert_dtype_image_tensor -from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noops, _register_kernel_internal, is_simple_tensor def _rgb_to_grayscale_image_tensor( @@ -87,6 +87,9 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ) +_register_explicit_noops(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) + + @_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: diff --git 
a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index a1cd3b7891c..666d5d85e54 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -77,6 +77,24 @@ def _noop(inpt, *args, **kwargs): return inpt +def _register_explicit_noops(dispatcher, *datapoints_clss): + """ + Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users + from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. + + For example, without explicit no-op registration the following would be valid user code: + + .. code:: + from torchvision.transforms.v2 import functional as F + + @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox) + def lol(...): + ... + """ + for datapoint_cls in datapoints_clss: + register_kernel(dispatcher, datapoint_cls)(_noop) + + def _get_kernel(dispatcher, datapoint_cls): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: From 74d50549e593bc8eb1d5c615d4782d8056e0017b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 15:14:41 +0200 Subject: [PATCH 07/22] fix adjust_brightness tests and remove methods --- test/test_transforms_v2_refactored.py | 50 +++++++++++++++++++ test/transforms_v2_dispatcher_infos.py | 8 --- test/transforms_v2_kernel_infos.py | 40 --------------- torchvision/datapoints/_datapoint.py | 3 -- torchvision/datapoints/_image.py | 6 --- torchvision/datapoints/_video.py | 4 -- .../transforms/v2/functional/_color.py | 7 +-- 7 files changed, 54 insertions(+), 64 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 66231010a36..7f1f4e98998 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1730,3 +1730,53 @@ def call_transform(): assert isinstance(output, tuple) and len(output) == 2 assert output[0] is image assert output[1] is label + + +class TestAdjustBrightness: + _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0] + _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0] + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.adjust_brightness_image_tensor, make_image), + (F.adjust_brightness_video, make_video), + ], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel(self, kernel, make_input, dtype, device): + check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.adjust_brightness_image_tensor, make_image_tensor), + (F.adjust_brightness_image_pil, make_image_pil), + (F.adjust_brightness_image_tensor, make_image), + (F.adjust_brightness_video, make_video), + ], + ) + def test_dispatcher(self, kernel, make_input): + check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR) + + @pytest.mark.parametrize( + ("kernel", "input_type"), + [ + (F.adjust_brightness_image_tensor, torch.Tensor), + (F.adjust_brightness_image_pil, PIL.Image.Image), + (F.adjust_brightness_image_tensor, datapoints.Image), + (F.adjust_brightness_video, datapoints.Video), + ], + ) + def test_dispatcher_signature(self, kernel, input_type): + check_dispatcher_signatures_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize("brightness_factor", 
_CORRECTNESS_BRIGHTNESS_FACTORS) + def test_image_correctness(self, brightness_factor): + image = make_image(dtype=torch.uint8, device="cpu") + + actual = F.adjust_brightness(image, brightness_factor=brightness_factor) + expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor)) + + torch.testing.assert_close(actual, expected) diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 6f61526f382..ccd00977e93 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -289,14 +289,6 @@ def fill_sequence_needs_broadcast(args_kwargs): skip_dispatch_datapoint, ], ), - DispatcherInfo( - F.adjust_brightness, - kernels={ - datapoints.Image: F.adjust_brightness_image_tensor, - datapoints.Video: F.adjust_brightness_video, - }, - pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"), - ), DispatcherInfo( F.adjust_contrast, kernels={ diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index dc04fbfc7a9..556cf49cf70 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -1261,46 +1261,6 @@ def sample_inputs_erase_video(): ] ) -_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] - - -def sample_inputs_adjust_brightness_image_tensor(): - for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")): - yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) - - -def reference_inputs_adjust_brightness_image_tensor(): - for image_loader, brightness_factor in itertools.product( - make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]), - _ADJUST_BRIGHTNESS_FACTORS, - ): - yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) - - -def sample_inputs_adjust_brightness_video(): - for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]): - yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.adjust_brightness_image_tensor, - kernel_name="adjust_brightness_image_tensor", - sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil), - reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor, - float32_vs_uint8=True, - closeness_kwargs=float32_vs_uint8_pixel_difference(), - ), - KernelInfo( - F.adjust_brightness_video, - sample_inputs_fn=sample_inputs_adjust_brightness_video, - ), - ] -) - - _ADJUST_CONTRAST_FACTORS = [0.1, 0.5] diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 998b606888c..4a9a8a27c1c 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -217,9 +217,6 @@ def elastic( def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint: return self - def adjust_brightness(self, brightness_factor: float) -> Datapoint: - return self - def adjust_saturation(self, saturation_factor: float) -> Datapoint: return self diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index ea6435a9811..32a7836d8f4 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -181,12 +181,6 @@ def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image: ) return Image.wrap_like(self, output) - def adjust_brightness(self, 
brightness_factor: float) -> Image: - output = self._F.adjust_brightness_image_tensor( - self.as_subclass(torch.Tensor), brightness_factor=brightness_factor - ) - return Image.wrap_like(self, output) - def adjust_saturation(self, saturation_factor: float) -> Image: output = self._F.adjust_saturation_image_tensor( self.as_subclass(torch.Tensor), saturation_factor=saturation_factor diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index 592d85938cd..99d6d0e2edb 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -175,10 +175,6 @@ def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video: ) return Video.wrap_like(self, output) - def adjust_brightness(self, brightness_factor: float) -> Video: - output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor) - return Video.wrap_like(self, output) - def adjust_saturation(self, saturation_factor: float) -> Video: output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor) return Video.wrap_like(self, output) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index c39b108cd20..0916bcf61ae 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -74,10 +74,10 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) _log_api_usage_once(adjust_brightness) if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, datapoints._datapoint.Datapoint): kernel = _get_kernel(adjust_brightness, type(inpt)) return kernel(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_brightness(brightness_factor=brightness_factor) elif isinstance(inpt, PIL.Image.Image): return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) else: @@ -105,7 +105,8 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float return output if fp else output.to(image.dtype) -adjust_brightness_image_pil = _FP.adjust_brightness +def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: + return _FP.adjust_brightness(image, brightness_factor=brightness_factor) @_register_kernel_internal(adjust_brightness, datapoints.Video) From f178373ae9b488c35286641fb30033edc575692e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 27 Jul 2023 16:03:19 +0200 Subject: [PATCH 08/22] address minor comments --- test/test_transforms_v2_refactored.py | 2 +- torchvision/transforms/v2/functional/_color.py | 4 ++-- torchvision/transforms/v2/functional/_utils.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 5d5fe1c3c21..d0aa4540c3f 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -450,7 +450,7 @@ def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): raise AssertionError( "\n".join( [ - f"The dispatcher '{dispatcher.__name__}' hs no kernels registered for", + f"The dispatcher '{dispatcher.__name__}' has no kernel registered for", "", *[f"- {name}" for name in names], "", diff --git a/torchvision/transforms/v2/functional/_color.py 
b/torchvision/transforms/v2/functional/_color.py index 3d44e472010..5a7008f8f0a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,7 +10,7 @@ from torchvision.utils import _log_api_usage_once from ._meta import _num_value_bits, to_dtype_image_tensor -from ._utils import _get_kernel, _register_explicit_noops, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor def _rgb_to_grayscale_image_tensor( @@ -87,7 +87,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ) -_register_explicit_noops(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) +_register_explicit_noop(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) @_register_kernel_internal(adjust_brightness, datapoints.Image) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 666d5d85e54..23eeb626b9b 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -77,7 +77,7 @@ def _noop(inpt, *args, **kwargs): return inpt -def _register_explicit_noops(dispatcher, *datapoints_clss): +def _register_explicit_noop(dispatcher, *datapoints_classes): """ Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. @@ -91,8 +91,8 @@ def _register_explicit_noops(dispatcher, *datapoints_clss): def lol(...): ... """ - for datapoint_cls in datapoints_clss: - register_kernel(dispatcher, datapoint_cls)(_noop) + for cls in datapoints_classes: + register_kernel(dispatcher, cls)(_noop) def _get_kernel(dispatcher, datapoint_cls): From 65e80d01a1c4fa19ab9a8d36f2873d35017b976d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 28 Jul 2023 09:25:47 +0200 Subject: [PATCH 09/22] make no-op registration a decorator --- test/test_transforms_v2_refactored.py | 2 +- torchvision/transforms/v2/functional/_color.py | 4 +--- torchvision/transforms/v2/functional/_utils.py | 11 ++++++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index d0aa4540c3f..7581dbf4cd4 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -455,7 +455,7 @@ def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): *[f"- {name}" for name in names], "", f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).", - f"If not, register explicit no-ops with _register_explicit_noops({dispatcher.__name__}, {', '.join(names)})", + f"If not, register explicit no-ops with @_register_explicit_noops({', '.join(names)})", ] ) ) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5a7008f8f0a..9798156ea49 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -69,6 +69,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te return output if fp else output.to(image1.dtype) +@_register_explicit_noop(datapoints.BoundingBox, datapoints.Mask) def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): 
_log_api_usage_once(adjust_brightness) @@ -87,9 +88,6 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ) -_register_explicit_noop(adjust_brightness, datapoints.BoundingBox, datapoints.Mask) - - @_register_kernel_internal(adjust_brightness, datapoints.Image) def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 23eeb626b9b..4d34c3624e9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -77,7 +77,7 @@ def _noop(inpt, *args, **kwargs): return inpt -def _register_explicit_noop(dispatcher, *datapoints_classes): +def _register_explicit_noop(*datapoints_classes): """ Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. @@ -91,8 +91,13 @@ def _register_explicit_noop(dispatcher, *datapoints_classes): def lol(...): ... """ - for cls in datapoints_classes: - register_kernel(dispatcher, cls)(_noop) + + def decorator(dispatcher): + for cls in datapoints_classes: + register_kernel(dispatcher, cls)(_noop) + return dispatcher + + return decorator def _get_kernel(dispatcher, datapoint_cls): From 6ac08e45815dbb5bb07e5d7b06c3b77857f79bd2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Aug 2023 10:15:39 +0200 Subject: [PATCH 10/22] explicit metadata --- torchvision/transforms/v2/functional/_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 4f044cd2654..5c90a1ba1fb 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -20,8 +20,11 @@ def _kernel_wrapper_internal(dispatcher, kernel): needs_args_kwargs_handling = kernel_params != dispatcher_params - # this avoids converting list -> set at runtime below kernel_params = set(kernel_params) + explicit_metadata = { + input_type: available_metadata & kernel_params + for input_type, available_metadata in [(datapoints.BoundingBoxes, {"format", "canvas_size"})] + } @functools.wraps(kernel) def wrapper(inpt, *args, **kwargs): @@ -39,7 +42,7 @@ def wrapper(inpt, *args, **kwargs): # add parameters that are passed implicitly to the dispatcher as metadata, # but have to be explicit for the kernel - for kwarg in input_type.__annotations__.keys() & kernel_params: + for kwarg in explicit_metadata.get(input_type, set()): kwargs[kwarg] = getattr(inpt, kwarg) output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) From cac079ba2cc97f658810ddd93ec9d2a958bbda1f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Aug 2023 10:40:20 +0200 Subject: [PATCH 11/22] implement dispatchers for erase five/ten_crop and temporal_subsample --- torchvision/transforms/v2/_augment.py | 2 - torchvision/transforms/v2/_geometry.py | 14 --- torchvision/transforms/v2/_temporal.py | 5 +- .../transforms/v2/functional/_augment.py | 60 +++++------ .../transforms/v2/functional/_geometry.py | 102 +++++++++--------- .../transforms/v2/functional/_temporal.py | 29 ++--- .../transforms/v2/functional/_utils.py | 15 ++- 7 files changed, 112 insertions(+), 115 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py 
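
At this stage of the series, datapoint inputs to F.adjust_brightness are resolved purely through the registry. A short usage sketch of the resulting behavior (illustrative inputs, not code from the patch):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 32, 32))
    mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))

    # routed via _get_kernel to adjust_brightness_image_tensor and re-wrapped as an Image
    out_image = F.adjust_brightness(image, brightness_factor=0.5)

    # bounding boxes and masks are registered as explicit no-ops, so the input comes back unchanged
    out_mask = F.adjust_brightness(mask, brightness_factor=0.5)

For kernels that do need datapoint metadata, the explicit metadata mapping added just above makes the internal wrapper read the declared attributes (format and canvas_size for bounding boxes) off the input, pass them to the kernel, and re-wrap tuple outputs together with the updated metadata.
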
index 3291c2f5004..7779a68afda 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -56,8 +56,6 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: value="random" if self.value is None else self.value, ) - _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video) - def __init__( self, p: float = 0.5, diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 9e7ca64d41c..f9485f887f2 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -354,13 +354,6 @@ class FiveCrop(Transform): _v1_transform_cls = _transforms.FiveCrop - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -401,13 +394,6 @@ class TenCrop(Transform): _v1_transform_cls = _transforms.TenCrop - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index df4ad66643a..868314e9e33 100644 --- a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -1,10 +1,9 @@ from typing import Any, Dict +import torch from torchvision import datapoints from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2.utils import is_simple_tensor - class UniformTemporalSubsample(Transform): """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video. 
@@ -20,7 +19,7 @@ class UniformTemporalSubsample(Transform): num_samples (int): The number of equispaced samples to be selected """ - _transformed_types = (is_simple_tensor, datapoints.Video) + _transformed_types = (torch.Tensor,) def __init__(self, num_samples: int): super().__init__() diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 9aedae814bd..35003a3e809 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -7,9 +7,37 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, future_warning=True) +def erase( + inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], + i: int, + j: int, + h: int, + w: int, + v: torch.Tensor, + inplace: bool = False, +) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(erase) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(erase, type(inpt)) + return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + elif isinstance(inpt, PIL.Image.Image): + return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@_register_kernel_internal(erase, datapoints.Image) def erase_image_tensor( image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> torch.Tensor: @@ -29,36 +57,8 @@ def erase_image_pil( return to_pil_image(output, mode=image.mode) +@_register_kernel_internal(erase, datapoints.Video) def erase_video( video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> torch.Tensor: return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - - -def erase( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], - i: int, - j: int, - h: int, - w: int, - v: torch.Tensor, - inplace: bool = False, -) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: - if not torch.jit.is_scripting(): - _log_api_usage_once(erase) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, datapoints.Image): - output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) - return datapoints.Image.wrap_like(inpt, output) - elif isinstance(inpt, datapoints.Video): - output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace) - return datapoints.Video.wrap_like(inpt, output) - elif isinstance(inpt, PIL.Image.Image): - return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index aebfc180a99..b63c4080ec8 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,7 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -1968,6 +1968,30 @@ def resized_crop( ) +ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT] + + +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) +def five_crop( + inpt: ImageOrVideoTypeJIT, size: List[int] +) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(five_crop) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return five_crop_image_tensor(inpt, size) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(five_crop, type(inpt)) + return kernel(inpt, size) + elif isinstance(inpt, PIL.Image.Image): + return five_crop_image_pil(inpt, size) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): s = int(size) @@ -1982,6 +2006,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size +@_register_kernel_internal(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2019,38 +2044,46 @@ def five_crop_image_pil( return tl, tr, bl, br, center +@_register_kernel_internal(five_crop, datapoints.Video) def five_crop_video( video: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return five_crop_image_tensor(video, size) -ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT] - - -def five_crop( - inpt: ImageOrVideoTypeJIT, size: List[int] -) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) +def ten_crop( + inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False +) -> Tuple[ + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, + ImageOrVideoTypeJIT, +]: if not torch.jit.is_scripting(): - _log_api_usage_once(five_crop) + _log_api_usage_once(ten_crop) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return five_crop_image_tensor(inpt, size) - elif isinstance(inpt, datapoints.Image): - output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size) - return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value] - elif isinstance(inpt, datapoints.Video): - output = 
five_crop_video(inpt.as_subclass(torch.Tensor), size) - return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value] + return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(ten_crop, type(inpt)) + return kernel(inpt, size, vertical_flip=vertical_flip) elif isinstance(inpt, PIL.Image.Image): - return five_crop_image_pil(inpt, size) + return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) else: raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"but got {type(inpt)} instead." ) +@_register_kernel_internal(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2104,6 +2137,7 @@ def ten_crop_image_pil( return non_flipped + flipped +@_register_kernel_internal(ten_crop, datapoints.Video) def ten_crop_video( video: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2119,37 +2153,3 @@ def ten_crop_video( torch.Tensor, ]: return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip) - - -def ten_crop( - inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False -) -> Tuple[ - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, -]: - if not torch.jit.is_scripting(): - _log_api_usage_once(ten_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, datapoints.Image): - output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) - return tuple(datapoints.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value] - elif isinstance(inpt, datapoints.Video): - output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip) - return tuple(datapoints.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value] - elif isinstance(inpt, PIL.Image.Image): - return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip) - else: - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 5612a38779e..37bd5e810e8 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -4,24 +4,29 @@ from torchvision.utils import _log_api_usage_once -from ._utils import is_simple_tensor - - -def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: - # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 - t_max = video.shape[-4] - 1 - indices = torch.linspace(0, t_max, num_samples, device=video.device).long() - return torch.index_select(video, -4, indices) +from ._utils import _get_kernel, _register_explicit_noop, is_simple_tensor, register_kernel +@_register_explicit_noop(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(uniform_temporal_subsample) if torch.jit.is_scripting() or is_simple_tensor(inpt): return uniform_temporal_subsample_video(inpt, num_samples) - elif isinstance(inpt, datapoints.Video): - output = uniform_temporal_subsample_video(inpt.as_subclass(torch.Tensor), num_samples) - return datapoints.Video.wrap_like(inpt, output) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) + return kernel(inpt, num_samples) else: - raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@register_kernel(uniform_temporal_subsample, datapoints.Video) +def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: + # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 + t_max = video.shape[-4] - 1 + indices = torch.linspace(0, t_max, num_samples, device=video.device).long() + return torch.index_select(video, -4, indices) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 5c90a1ba1fb..2967cc428cf 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,5 +1,6 @@ import functools import inspect +import warnings from typing import Any import torch @@ -76,11 +77,13 @@ def register_kernel(dispatcher, datapoint_cls): return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False) -def _noop(inpt, *args, **kwargs): +def _noop(inpt, *args, __future_warning__=None, **kwargs): + if __future_warning__: + warnings.warn(__future_warning__, FutureWarning, stacklevel=2) return inpt -def _register_explicit_noop(*datapoints_classes): +def _register_explicit_noop(*datapoints_classes, future_warning=False): """ Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. 
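For orientation, the explicit no-op registration above is easiest to read from the call site: a dispatcher whose input type is registered as a no-op hands the input back unchanged, optionally with a warning. A minimal sketch of that behaviour, assuming the helpers land as written in this patch (the warning category is adjusted again later in the series):

    import warnings

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    mask = datapoints.Mask(torch.zeros(1, 16, 16, dtype=torch.uint8))
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Mask is registered as an explicit no-op for erase (future_warning=True),
        # so the input passes through untouched and a warning is emitted.
        out = F.erase(mask, i=0, j=0, h=8, w=8, v=torch.tensor(0.0))
    assert out is mask
    assert caught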
@@ -97,7 +100,13 @@ def lol(...): def decorator(dispatcher): for cls in datapoints_classes: - register_kernel(dispatcher, cls)(_noop) + msg = ( + f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " + f"This will likely change in the future." + ) + register_kernel(dispatcher, cls)( + functools.partial(_noop, __future_warning__=msg if future_warning else None) + ) return dispatcher return decorator From c7256b4ec195a5aaee03706c46a5a749413bfe96 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Aug 2023 11:21:28 +0200 Subject: [PATCH 12/22] make shape getters proper dispatchers --- test/common_utils.py | 4 + test/test_transforms_v2_refactored.py | 90 +++++++++++ .../transforms/v2/functional/__init__.py | 1 + torchvision/transforms/v2/functional/_meta.py | 145 +++++++++--------- .../transforms/v2/functional/_utils.py | 14 ++ 5 files changed, 182 insertions(+), 72 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index b5edda3edb2..493233404ae 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -818,6 +818,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs): return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs)) +def make_video_tensor(*args, **kwargs): + return make_video(*args, **kwargs).as_subclass(torch.Tensor) + + def make_video_loader( size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c42018885f0..7d4dfdcdabb 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -26,6 +26,7 @@ make_image_tensor, make_segmentation_mask, make_video, + make_video_tensor, needs_cuda, set_rng_seed, ) @@ -2119,3 +2120,92 @@ def test_labels_getter_default_heuristic(key, sample_type): # it takes precedence over other keys which would otherwise be a match d = {key: "something_else", "labels": labels} assert transforms._utils._find_labels_default_heuristic(d) is labels + + +class TestShapeGetters: + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.get_dimensions_image_tensor, make_image_tensor), + (F.get_dimensions_image_pil, make_image_pil), + (F.get_dimensions_image_tensor, make_image), + (F.get_dimensions_video, make_video), + ], + ) + def test_get_dimensions(self, kernel, make_input): + size = (10, 10) + color_space, num_channels = "RGB", 3 + + input = make_input(size, color_space=color_space) + + assert kernel(input) == F.get_dimensions(input) == [num_channels, *size] + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.get_num_channels_image_tensor, make_image_tensor), + (F.get_num_channels_image_pil, make_image_pil), + (F.get_num_channels_image_tensor, make_image), + (F.get_num_channels_video, make_video), + ], + ) + def test_get_num_channels(self, kernel, make_input): + color_space, num_channels = "RGB", 3 + + input = make_input(color_space=color_space) + + assert kernel(input) == F.get_num_channels(input) == num_channels + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.get_size_image_tensor, make_image_tensor), + (F.get_size_image_pil, make_image_pil), + (F.get_size_image_tensor, make_image), + (F.get_size_bounding_boxes, make_bounding_box), + (F.get_size_mask, make_detection_mask), + (F.get_size_mask, make_segmentation_mask), + (F.get_size_video, make_video), + ], + ) + def test_get_size(self, kernel, make_input): + size = (10, 10) + + input = make_input(size) + + assert kernel(input) 
== F.get_size(input) == list(size) + + @pytest.mark.parametrize( + ("kernel", "make_input"), + [ + (F.get_num_frames_video, make_video_tensor), + (F.get_num_frames_video, make_video), + ], + ) + def test_get_num_frames(self, kernel, make_input): + num_frames = 4 + + input = make_input(num_frames=num_frames) + + assert kernel(input) == F.get_num_frames(input) == num_frames + + @pytest.mark.parametrize( + ("dispatcher", "make_input"), + [ + (F.get_dimensions, make_bounding_box), + (F.get_dimensions, make_detection_mask), + (F.get_dimensions, make_segmentation_mask), + (F.get_num_channels, make_bounding_box), + (F.get_num_channels, make_detection_mask), + (F.get_num_channels, make_segmentation_mask), + (F.get_num_frames, make_image), + (F.get_num_frames, make_bounding_box), + (F.get_num_frames, make_detection_mask), + (F.get_num_frames, make_segmentation_mask), + ], + ) + def test_unsupported_types(self, dispatcher, make_input): + input = make_input() + + with pytest.raises(TypeError, match=re.escape(str(type(input)))): + dispatcher(input) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index ddd1d57f353..163a55fad38 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -7,6 +7,7 @@ convert_format_bounding_boxes, get_dimensions_image_tensor, get_dimensions_image_pil, + get_dimensions_video, get_dimensions, get_num_frames_video, get_num_frames, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 91b370675b9..5e0777fc852 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -8,9 +8,29 @@ from torchvision.utils import _log_api_usage_once -from ._utils import is_simple_tensor +from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor +@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) +def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_dimensions) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_dimensions_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(get_dimensions, type(inpt)) + return kernel(inpt) + elif isinstance(inpt, PIL.Image.Image): + return get_dimensions_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +@_register_kernel_internal(get_dimensions, datapoints.Image, wrap_kernel=False) def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) ndims = len(chw) @@ -26,31 +46,31 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: get_dimensions_image_pil = _FP.get_dimensions +@_register_kernel_internal(get_dimensions, datapoints.Video, wrap_kernel=False) def get_dimensions_video(video: torch.Tensor) -> List[int]: return get_dimensions_image_tensor(video) -def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: +@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) +def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: if not torch.jit.is_scripting(): - _log_api_usage_once(get_dimensions) + _log_api_usage_once(get_num_channels) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return get_dimensions_image_tensor(inpt) - - for typ, get_size_fn in { - datapoints.Image: get_dimensions_image_tensor, - datapoints.Video: get_dimensions_video, - PIL.Image.Image: get_dimensions_image_pil, - }.items(): - if isinstance(inpt, typ): - return get_size_fn(inpt) - - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) + return get_num_channels_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(get_num_channels, type(inpt)) + return kernel(inpt) + elif isinstance(inpt, PIL.Image.Image): + return get_num_channels_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) +@_register_kernel_internal(get_num_channels, datapoints.Image, wrap_kernel=False) def get_num_channels_image_tensor(image: torch.Tensor) -> int: chw = image.shape[-3:] ndims = len(chw) @@ -65,36 +85,35 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int: get_num_channels_image_pil = _FP.get_image_num_channels +@_register_kernel_internal(get_num_channels, datapoints.Video, wrap_kernel=False) def get_num_channels_video(video: torch.Tensor) -> int: return get_num_channels_image_tensor(video) -def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_num_channels) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return get_num_channels_image_tensor(inpt) - - for typ, get_size_fn in { - datapoints.Image: get_num_channels_image_tensor, - datapoints.Video: get_num_channels_video, - PIL.Image.Image: get_num_channels_image_pil, - }.items(): - if isinstance(inpt, typ): - return get_size_fn(inpt) - - raise TypeError( - f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) - - # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without # deprecating the old names. 
get_image_num_channels = get_num_channels +def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: + if not torch.jit.is_scripting(): + _log_api_usage_once(get_size) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return get_size_image_tensor(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(get_size, type(inpt)) + return kernel(inpt) + elif isinstance(inpt, PIL.Image.Image): + return get_size_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@_register_kernel_internal(get_size, datapoints.Image, wrap_kernel=False) def get_size_image_tensor(image: torch.Tensor) -> List[int]: hw = list(image.shape[-2:]) ndims = len(hw) @@ -110,59 +129,41 @@ def get_size_image_pil(image: PIL.Image.Image) -> List[int]: return [height, width] +@_register_kernel_internal(get_size, datapoints.Video, wrap_kernel=False) def get_size_video(video: torch.Tensor) -> List[int]: return get_size_image_tensor(video) +@_register_kernel_internal(get_size, datapoints.Mask, wrap_kernel=False) def get_size_mask(mask: torch.Tensor) -> List[int]: return get_size_image_tensor(mask) -@torch.jit.unused +@_register_kernel_internal(get_size, datapoints.BoundingBoxes, wrap_kernel=False) def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]: return list(bounding_box.canvas_size) -def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: - if not torch.jit.is_scripting(): - _log_api_usage_once(get_size) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return get_size_image_tensor(inpt) - - # TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with - # https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have - # a method on the datapoint class - for typ, get_size_fn in { - datapoints.Image: get_size_image_tensor, - datapoints.BoundingBoxes: get_size_bounding_boxes, - datapoints.Mask: get_size_mask, - datapoints.Video: get_size_video, - PIL.Image.Image: get_size_image_pil, - }.items(): - if isinstance(inpt, typ): - return get_size_fn(inpt) - - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) - - -def get_num_frames_video(video: torch.Tensor) -> int: - return video.shape[-4] - - +@_register_unsupported_type(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: if not torch.jit.is_scripting(): _log_api_usage_once(get_num_frames) if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_frames_video(inpt) - elif isinstance(inpt, datapoints.Video): - return get_num_frames_video(inpt) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(get_num_frames, type(inpt)) + return kernel(inpt) else: - raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.") + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + +@_register_kernel_internal(get_num_frames, datapoints.Video, wrap_kernel=False) +def get_num_frames_video(video: torch.Tensor) -> int: + return video.shape[-4] def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 2967cc428cf..2dda82e69a6 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -112,6 +112,20 @@ def decorator(dispatcher): return decorator +# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that +# to error later, this decorator can be removed, since the error will be raised by _get_kernel +def _register_unsupported_type(*datapoints_classes): + def kernel(inpt, *args, __dispatcher_name__, **kwargs): + raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.") + + def decorator(dispatcher): + for cls in datapoints_classes: + register_kernel(dispatcher, cls)(functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)) + return dispatcher + + return decorator + + def _get_kernel(dispatcher, datapoint_cls): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: From bf78cd6c5fab6569519f74649dd012942f320730 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Aug 2023 14:08:28 +0200 Subject: [PATCH 13/22] fix --- .../transforms/v2/functional/_geometry.py | 16 ++++++++---- .../transforms/v2/functional/_temporal.py | 4 +-- .../transforms/v2/functional/_utils.py | 26 +++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b63c4080ec8..9c923af118f 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,13 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor +from ._utils import ( + _get_kernel, + _register_explicit_noop, + _register_five_ten_crop_kernel, + _register_kernel_internal, + is_simple_tensor, +) def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -2006,7 +2012,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size -@_register_kernel_internal(five_crop, datapoints.Image) +@_register_five_ten_crop_kernel(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2044,7 +2050,7 @@ def five_crop_image_pil( return tl, tr, bl, br, center -@_register_kernel_internal(five_crop, datapoints.Video) +@_register_five_ten_crop_kernel(five_crop, datapoints.Video) def five_crop_video( video: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2083,7 +2089,7 @@ def ten_crop( ) -@_register_kernel_internal(ten_crop, datapoints.Image) +@_register_five_ten_crop_kernel(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2137,7 +2143,7 @@ def ten_crop_image_pil( return non_flipped + flipped -@_register_kernel_internal(ten_crop, datapoints.Video) +@_register_five_ten_crop_kernel(ten_crop, datapoints.Video) def 
ten_crop_video( video: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 37bd5e810e8..2e4a6f7eadd 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -4,7 +4,7 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, is_simple_tensor, register_kernel +from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor @_register_explicit_noop(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) @@ -24,7 +24,7 @@ def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) ) -@register_kernel(uniform_temporal_subsample, datapoints.Video) +@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video) def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor: # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 t_max = video.shape[-4] - 1 diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 2dda82e69a6..ff11e2108ec 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -139,3 +139,29 @@ def _get_kernel(dispatcher, datapoint_cls): return kernel return _noop + + +# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop +# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool +# TODO: decide if we want that +def _register_five_ten_crop_kernel(dispatcher, datapoint_cls): + registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) + if datapoint_cls in registry: + raise TypeError( + f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'." 
+ ) + + def wrap(kernel): + @functools.wraps(kernel) + def wrapper(inpt, *args, **kwargs): + output = kernel(inpt, *args, **kwargs) + container_type = type(output) + return container_type(type(inpt).wrap_like(inpt, o) for o in output) + + return wrapper + + def decorator(kernel): + registry[datapoint_cls] = wrap(kernel) + return kernel + + return decorator From f86f89b5696279492ecf7972986129535f60823a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 09:33:23 +0200 Subject: [PATCH 14/22] port normalize and to_dtype --- test/test_transforms_v2_refactored.py | 1 + torchvision/transforms/v2/functional/_meta.py | 2 +- torchvision/transforms/v2/functional/_misc.py | 91 ++++++++++++------- .../transforms/v2/functional/_temporal.py | 8 +- 4 files changed, 63 insertions(+), 39 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 7d4dfdcdabb..12adbcd26c8 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2198,6 +2198,7 @@ def test_get_num_frames(self, kernel, make_input): (F.get_num_channels, make_bounding_box), (F.get_num_channels, make_detection_mask), (F.get_num_channels, make_segmentation_mask), + (F.get_num_frames, make_image_pil), (F.get_num_frames, make_image), (F.get_num_frames, make_bounding_box), (F.get_num_frames, make_detection_mask), diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 5e0777fc852..1b12af3b401 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -144,7 +144,7 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int] return list(bounding_box.canvas_size) -@_register_unsupported_type(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) +@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: if not torch.jit.is_scripting(): _log_api_usage_once(get_num_frames) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index cda85ba906e..6a100341b2e 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -11,9 +11,37 @@ from torchvision.utils import _log_api_usage_once -from ._utils import is_simple_tensor +from ._utils import ( + _get_kernel, + _register_explicit_noop, + _register_kernel_internal, + _register_unsupported_type, + is_simple_tensor, +) +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +@_register_unsupported_type(PIL.Image.Image) +def normalize( + inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT], + mean: List[float], + std: List[float], + inplace: bool = False, +) -> torch.Tensor: + if not torch.jit.is_scripting(): + _log_api_usage_once(normalize) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(normalize, type(inpt)) + return kernel(inpt, mean=mean, std=std, inplace=inplace) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." 
+ ) + + +@_register_kernel_internal(normalize, datapoints.Image) def normalize_image_tensor( image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False ) -> torch.Tensor: @@ -49,28 +77,11 @@ def normalize_image_tensor( return image.div_(std) +@_register_kernel_internal(normalize, datapoints.Video) def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: return normalize_image_tensor(video, mean, std, inplace=inplace) -def normalize( - inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT], - mean: List[float], - std: List[float], - inplace: bool = False, -) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(normalize) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) - elif isinstance(inpt, (datapoints.Image, datapoints.Video)): - return inpt.normalize(mean=mean, std=std, inplace=inplace) - else: - raise TypeError( - f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead." - ) - - def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma) x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) @@ -185,6 +196,23 @@ def gaussian_blur( ) +def to_dtype( + inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False +) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(to_dtype) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return to_dtype_image_tensor(inpt, dtype, scale=scale) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + kernel = _get_kernel(to_dtype, type(inpt)) + return kernel(inpt, dtype, scale=scale) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." 
+ ) + + def _num_value_bits(dtype: torch.dtype) -> int: if dtype == torch.uint8: return 8 @@ -200,6 +228,7 @@ def _num_value_bits(dtype: torch.dtype) -> int: raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") +@_register_kernel_internal(to_dtype, datapoints.Image) def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: if image.dtype == dtype: @@ -257,23 +286,15 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) return to_dtype_image_tensor(image, dtype=dtype, scale=True) +@_register_kernel_internal(to_dtype, datapoints.Video) def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: return to_dtype_image_tensor(video, dtype, scale=scale) -def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor: - if not torch.jit.is_scripting(): - _log_api_usage_once(to_dtype) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return to_dtype_image_tensor(inpt, dtype, scale=scale) - elif isinstance(inpt, datapoints.Image): - output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale) - return datapoints.Image.wrap_like(inpt, output) - elif isinstance(inpt, datapoints.Video): - output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale) - return datapoints.Video.wrap_like(inpt, output) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.to(dtype) - else: - raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.") +@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, wrap_kernel=False) +@_register_kernel_internal(to_dtype, datapoints.Mask, wrap_kernel=False) +def _to_dtype_tensor_dispatch( + inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False +) -> datapoints._InputTypeJIT: + # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type + return inpt.to(dtype) diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 2e4a6f7eadd..803d2a2e250 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -1,3 +1,4 @@ +import PIL.Image import torch from torchvision import datapoints @@ -7,7 +8,9 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor -@_register_explicit_noop(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) +@_register_explicit_noop( + PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, future_warning=True +) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: if not torch.jit.is_scripting(): _log_api_usage_once(uniform_temporal_subsample) @@ -19,8 +22,7 @@ def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) return kernel(inpt, num_samples) else: raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." + f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." 
) From d90daf6c55f2375bc3311900e8e53da731412355 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 09:36:44 +0200 Subject: [PATCH 15/22] address comments --- torchvision/transforms/v2/functional/_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ff11e2108ec..77291666ca9 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -77,9 +77,9 @@ def register_kernel(dispatcher, datapoint_cls): return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False) -def _noop(inpt, *args, __future_warning__=None, **kwargs): - if __future_warning__: - warnings.warn(__future_warning__, FutureWarning, stacklevel=2) +def _noop(inpt, *args, __msg__=None, **kwargs): + if __msg__: + warnings.warn(__msg__, UserWarning, stacklevel=2) return inpt @@ -104,9 +104,7 @@ def decorator(dispatcher): f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " f"This will likely change in the future." ) - register_kernel(dispatcher, cls)( - functools.partial(_noop, __future_warning__=msg if future_warning else None) - ) + register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if future_warning else None)) return dispatcher return decorator From 09eec9adddd4e0c4d7050542627567bb7a0aa62e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 10:06:45 +0200 Subject: [PATCH 16/22] address comments and cleanup --- .../transforms/v2/functional/_augment.py | 2 +- .../transforms/v2/functional/_geometry.py | 26 ++++- torchvision/transforms/v2/functional/_meta.py | 18 ++-- torchvision/transforms/v2/functional/_misc.py | 4 +- .../transforms/v2/functional/_temporal.py | 2 +- .../transforms/v2/functional/_utils.py | 101 +++++++----------- 6 files changed, 70 insertions(+), 83 deletions(-) diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 35003a3e809..d70be15985d 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -10,7 +10,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor -@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, future_warning=True) +@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True) def erase( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], i: int, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 9c923af118f..b0a3cfd354a 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -291,7 +291,6 @@ def resize_image_pil( return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation]) -@_register_kernel_internal(resize, datapoints.Mask) def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -307,7 +306,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N return output 
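The hunk just below swaps the direct kernel registrations for Mask and BoundingBoxes with small hand-written dispatch adapters registered via datapoint_wrapper=False: resize_mask and resize_bounding_boxes do not share the dispatcher's full signature, and the box kernel additionally returns an updated canvas size that has to be written back as metadata. A minimal sketch of the call-site behaviour these adapters preserve, assuming they land as written below:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 16.0, 16.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    out = F.resize(boxes, size=[64, 64])
    assert isinstance(out, datapoints.BoundingBoxes)
    assert out.canvas_size == (64, 64)  # metadata rewritten by the dispatch adapter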
-@_register_kernel_internal(resize, datapoints.BoundingBoxes) +@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False) +def _resize_mask_dispatch( + inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any +) -> datapoints.Mask: + output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size) + return datapoints.Mask.wrap_like(inpt, output) + + def resize_bounding_boxes( bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -326,6 +332,16 @@ def resize_bounding_boxes( ) +@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _resize_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any +) -> datapoints.BoundingBoxes: + output, canvas_size = resize_bounding_boxes( + inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + + @_register_kernel_internal(resize, datapoints.Video) def resize_video( video: torch.Tensor, @@ -1977,7 +1993,7 @@ def resized_crop( ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT] -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def five_crop( inpt: ImageOrVideoTypeJIT, size: List[int] ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: @@ -2057,7 +2073,7 @@ def five_crop_video( return five_crop_image_tensor(video, size) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, future_warning=True) +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def ten_crop( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 1b12af3b401..7148c003b28 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -30,7 +30,7 @@ def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJI ) -@_register_kernel_internal(get_dimensions, datapoints.Image, wrap_kernel=False) +@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False) def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) ndims = len(chw) @@ -46,7 +46,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: get_dimensions_image_pil = _FP.get_dimensions -@_register_kernel_internal(get_dimensions, datapoints.Video, wrap_kernel=False) +@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False) def get_dimensions_video(video: torch.Tensor) -> List[int]: return get_dimensions_image_tensor(video) @@ -70,7 +70,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType ) -@_register_kernel_internal(get_num_channels, datapoints.Image, wrap_kernel=False) +@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False) def get_num_channels_image_tensor(image: torch.Tensor) -> int: chw = image.shape[-3:] ndims = len(chw) @@ -85,7 +85,7 @@ def get_num_channels_image_tensor(image: 
torch.Tensor) -> int: get_num_channels_image_pil = _FP.get_image_num_channels -@_register_kernel_internal(get_num_channels, datapoints.Video, wrap_kernel=False) +@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False) def get_num_channels_video(video: torch.Tensor) -> int: return get_num_channels_image_tensor(video) @@ -113,7 +113,7 @@ def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: ) -@_register_kernel_internal(get_size, datapoints.Image, wrap_kernel=False) +@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False) def get_size_image_tensor(image: torch.Tensor) -> List[int]: hw = list(image.shape[-2:]) ndims = len(hw) @@ -129,17 +129,17 @@ def get_size_image_pil(image: PIL.Image.Image) -> List[int]: return [height, width] -@_register_kernel_internal(get_size, datapoints.Video, wrap_kernel=False) +@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False) def get_size_video(video: torch.Tensor) -> List[int]: return get_size_image_tensor(video) -@_register_kernel_internal(get_size, datapoints.Mask, wrap_kernel=False) +@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False) def get_size_mask(mask: torch.Tensor) -> List[int]: return get_size_image_tensor(mask) -@_register_kernel_internal(get_size, datapoints.BoundingBoxes, wrap_kernel=False) +@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False) def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]: return list(bounding_box.canvas_size) @@ -161,7 +161,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: ) -@_register_kernel_internal(get_num_frames, datapoints.Video, wrap_kernel=False) +@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False) def get_num_frames_video(video: torch.Tensor) -> int: return video.shape[-4] diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 6a100341b2e..af179c0aa99 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -291,8 +291,8 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: return to_dtype_image_tensor(video, dtype, scale=scale) -@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, wrap_kernel=False) -@_register_kernel_internal(to_dtype, datapoints.Mask, wrap_kernel=False) +@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False) +@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False) def _to_dtype_tensor_dispatch( inpt: datapoints._InputTypeJIT, dtype: torch.dtype, scale: bool = False ) -> datapoints._InputTypeJIT: diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 803d2a2e250..81d6793f179 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -9,7 +9,7 @@ @_register_explicit_noop( - PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, future_warning=True + PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True ) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: if not torch.jit.is_scripting(): diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 
77291666ca9..e8201a05cea 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,10 +1,8 @@ import functools -import inspect import warnings -from typing import Any +from typing import Any, Callable, Dict, Type import torch -from torchvision import datapoints from torchvision.datapoints._datapoint import Datapoint @@ -12,54 +10,19 @@ def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) -_KERNEL_REGISTRY = {} +_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} -def _kernel_wrapper_internal(dispatcher, kernel): - dispatcher_params = list(inspect.signature(dispatcher).parameters)[1:] - kernel_params = list(inspect.signature(kernel).parameters)[1:] - - needs_args_kwargs_handling = kernel_params != dispatcher_params - - kernel_params = set(kernel_params) - explicit_metadata = { - input_type: available_metadata & kernel_params - for input_type, available_metadata in [(datapoints.BoundingBoxes, {"format", "canvas_size"})] - } - +def _kernel_datapoint_wrapper(kernel): @functools.wraps(kernel) def wrapper(inpt, *args, **kwargs): - input_type = type(inpt) - - if needs_args_kwargs_handling: - # Convert args to kwargs to simplify further processing - kwargs.update(dict(zip(dispatcher_params, args))) - args = () - - # drop parameters that are not relevant for the kernel, but have a default value - # in the dispatcher - for kwarg in kwargs.keys() - kernel_params: - del kwargs[kwarg] - - # add parameters that are passed implicitly to the dispatcher as metadata, - # but have to be explicit for the kernel - for kwarg in explicit_metadata.get(input_type, set()): - kwargs[kwarg] = getattr(inpt, kwarg) - output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) - - if isinstance(inpt, datapoints.BoundingBoxes) and isinstance(output, tuple): - output, canvas_size = output - metadata = dict(canvas_size=canvas_size) - else: - metadata = dict() - - return input_type.wrap_like(inpt, output, **metadata) + return type(inpt).wrap_like(inpt, output) return wrapper -def _register_kernel_internal(dispatcher, datapoint_cls, *, wrap_kernel=True): +def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=True): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) if datapoint_cls in registry: raise TypeError( @@ -67,23 +30,40 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, wrap_kernel=True): ) def decorator(kernel): - registry[datapoint_cls] = _kernel_wrapper_internal(dispatcher, kernel) if wrap_kernel else kernel + registry[datapoint_cls] = _kernel_datapoint_wrapper(kernel) if datapoint_wrapper else kernel return kernel return decorator def register_kernel(dispatcher, datapoint_cls): - return _register_kernel_internal(dispatcher, datapoint_cls, wrap_kernel=False) + return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) -def _noop(inpt, *args, __msg__=None, **kwargs): - if __msg__: - warnings.warn(__msg__, UserWarning, stacklevel=2) - return inpt +def _get_kernel(dispatcher, datapoint_cls): + registry = _KERNEL_REGISTRY.get(dispatcher) + if not registry: + raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") + + if datapoint_cls in registry: + return registry[datapoint_cls] + + for registered_cls, kernel in registry.items(): + if issubclass(datapoint_cls, registered_cls): + return kernel + return _noop -def _register_explicit_noop(*datapoints_classes, future_warning=False): + +# Everything 
below this block is stuff that we need right now, since it looks like we need to release in an intermediate +# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details. + + +# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we +# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher. +# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside +# register_kernel. +def _register_explicit_noop(*datapoints_classes, warn_passthrough=False): """ Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. @@ -104,12 +84,18 @@ def decorator(dispatcher): f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " f"This will likely change in the future." ) - register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if future_warning else None)) + register_kernel(dispatcher, cls)(functools.partial(_noop, __msg__=msg if warn_passthrough else None)) return dispatcher return decorator +def _noop(inpt, *args, __msg__=None, **kwargs): + if __msg__: + warnings.warn(__msg__, UserWarning, stacklevel=2) + return inpt + + # TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that # to error later, this decorator can be removed, since the error will be raised by _get_kernel def _register_unsupported_type(*datapoints_classes): @@ -124,21 +110,6 @@ def decorator(dispatcher): return decorator -def _get_kernel(dispatcher, datapoint_cls): - registry = _KERNEL_REGISTRY.get(dispatcher) - if not registry: - raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.") - - if datapoint_cls in registry: - return registry[datapoint_cls] - - for registered_cls, kernel in registry.items(): - if issubclass(datapoint_cls, registered_cls): - return kernel - - return _noop - - # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool # TODO: decide if we want that From 3730811feb3e688e276374df1d33070f5a44912d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 10:26:50 +0200 Subject: [PATCH 17/22] more cleanup --- torchvision/transforms/v2/_augment.py | 6 ++-- torchvision/transforms/v2/_geometry.py | 19 ++--------- .../transforms/v2/functional/_geometry.py | 33 ++++++++++--------- torchvision/transforms/v2/functional/_misc.py | 4 +-- 4 files changed, 24 insertions(+), 38 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 7779a68afda..4687ba5dc0f 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple import PIL.Image import torch @@ -129,9 +129,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] - ) -> Union[datapoints._ImageType, 
datapoints._VideoType]: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: inpt = F.erase(inpt, **params, inplace=self.inplace) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index f9485f887f2..0354ff2b382 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -358,9 +358,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform( - self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any] - ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.five_crop(inpt, self.size) def _check_inputs(self, flat_inputs: List[Any]) -> None: @@ -403,20 +401,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") - def _transform( - self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] - ) -> Tuple[ - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ]: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b0a3cfd354a..2d7d83e31f6 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1990,13 +1990,16 @@ def resized_crop( ) -ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT] - - @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def five_crop( - inpt: ImageOrVideoTypeJIT, size: List[int] -) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: + inpt: datapoints._InputTypeJIT, size: List[int] +) -> Tuple[ + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, +]: if not torch.jit.is_scripting(): _log_api_usage_once(five_crop) @@ -2077,16 +2080,16 @@ def five_crop_video( def ten_crop( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False ) -> Tuple[ - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, - ImageOrVideoTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, + datapoints._InputTypeJIT, ]: if not torch.jit.is_scripting(): _log_api_usage_once(ten_crop) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index af179c0aa99..c7ca4fe1d62 100644 --- 
a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -37,7 +37,7 @@ def normalize( return kernel(inpt, mean=mean, std=std, inplace=inplace) else: raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." + f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." ) @@ -209,7 +209,7 @@ def to_dtype( return kernel(inpt, dtype, scale=scale) else: raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." + f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." ) From 31bee5f6432148750f2ca1185891eff7c845bdd8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 28 Jul 2023 09:54:31 +0200 Subject: [PATCH 18/22] port all remaining dispatchers to the new mechanism --- test/datasets_utils.py | 6 +- test/test_transforms_v2_functional.py | 49 +- test/test_transforms_v2_refactored.py | 95 ++- test/transforms_v2_dispatcher_infos.py | 17 +- torchvision/datapoints/__init__.py | 2 +- torchvision/datapoints/_bounding_box.py | 128 +--- torchvision/datapoints/_datapoint.py | 124 ---- torchvision/datapoints/_image.py | 176 +----- torchvision/datapoints/_mask.py | 97 +-- torchvision/datapoints/_video.py | 170 +----- .../transforms/v2/functional/_augment.py | 2 +- .../transforms/v2/functional/_color.py | 253 ++++---- .../transforms/v2/functional/_geometry.py | 556 +++++++++++++----- torchvision/transforms/v2/functional/_meta.py | 8 +- torchvision/transforms/v2/functional/_misc.py | 46 +- .../transforms/v2/functional/_temporal.py | 4 +- .../transforms/v2/functional/_utils.py | 4 +- 17 files changed, 644 insertions(+), 1093 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index ab325a8062e..b6f22d766df 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -567,7 +567,7 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - from torchvision.datapoints._datapoint import Datapoint + from torchvision import datapoints from torchvision.datasets import wrap_dataset_for_transforms_v2 try: @@ -588,7 +588,9 @@ def test_transforms_v2_wrapper(self, config): assert len(wrapped_dataset) == info["num_examples"] wrapped_sample = wrapped_dataset[0] - assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) + assert tree_any( + lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample + ) except TypeError as error: msg = f"No wrapper exists for dataset class {type(dataset).__name__}" if str(error).startswith(msg): diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 230695ff93e..3075efab9a0 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -3,8 +3,6 @@ import os import re -from typing import get_type_hints - import numpy as np import PIL.Image import pytest @@ -417,22 +415,6 @@ def test_pil_output_type(self, info, args_kwargs): assert isinstance(output, PIL.Image.Image) - @make_info_args_kwargs_parametrization( - DISPATCHER_INFOS, - args_kwargs_fn=lambda info: info.sample_inputs(), - ) - def test_dispatch_datapoint(self, info, args_kwargs, spy_on): - (datapoint, *other_args), kwargs = args_kwargs.load() - - method_name = info.id - method = getattr(datapoint, method_name) - datapoint_type = type(datapoint) - spy = spy_on(method, 
module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}") - - info.dispatcher(datapoint, *other_args, **kwargs) - - spy.assert_called_once() - @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), @@ -462,9 +444,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi kernel_params = list(kernel_signature.parameters.values())[1:] # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be - # explicit passed to the kernel. - datapoint_type_metadata = datapoint_type.__annotations__.keys() - kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata] + # explicitly passed to the kernel. + input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel) + explicit_metadata = { + datapoints.BoundingBoxes: {"format", "canvas_size"}, + } + kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())] dispatcher_params = iter(dispatcher_params) for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): @@ -481,28 +466,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi assert dispatcher_param == kernel_param - @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) - def test_dispatcher_datapoint_signatures_consistency(self, info): - try: - datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id) - except AttributeError: - pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.") - - dispatcher_signature = inspect.signature(info.dispatcher) - dispatcher_params = list(dispatcher_signature.parameters.values())[1:] - - datapoint_signature = inspect.signature(datapoint_method) - datapoint_params = list(datapoint_signature.parameters.values())[1:] - - # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is - # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the - # natively concrete dispatcher annotations. - datapoint_annotations = get_type_hints(datapoint_method) - for param in datapoint_params: - param._annotation = datapoint_annotations[param.name] - - assert dispatcher_params == datapoint_params - @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) def test_unkown_type(self, info): unkown_input = object() diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 36e3e8d2601..45668fda1ca 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -3,7 +3,6 @@ import inspect import math import re -from typing import get_type_hints from unittest import mock import numpy as np @@ -178,28 +177,28 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is preserved in doing so. For bounding boxes also checks that the format is preserved. 
""" - if isinstance(input, datapoints._datapoint.Datapoint): - if dispatcher in {F.resize, F.adjust_brightness}: + input_type = type(input) + + if isinstance(input, datapoints.Datapoint): + wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type] + + # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the + # proper kernel was wrapped + if hasattr(wrapped_kernel, "__wrapped__"): + assert wrapped_kernel.__wrapped__ is kernel + + spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) + with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}): output = dispatcher(input, *args, **kwargs) - else: - # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly, - # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel. - spy = mock.MagicMock(wraps=kernel, name=kernel.__name__) - with mock.patch.object(F, kernel.__name__, spy): - # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class. - # Since that is not the case here, we need to prefix f"_{cls.__name__}" - # See https://docs.python.org/3/tutorial/classes.html#private-variables for details - with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F): - output = dispatcher(input, *args, **kwargs) - spy.assert_called_once() + spy.assert_called_once() else: with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy: output = dispatcher(input, *args, **kwargs) spy.assert_called_once() - assert isinstance(output, type(input)) + assert isinstance(output, input_type) if isinstance(input, datapoints.BoundingBoxes): assert output.format == input.format @@ -214,15 +213,13 @@ def check_dispatcher( check_dispatch=True, **kwargs, ): + unknown_input = object() with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy: - dispatcher(input, *args, **kwargs) + with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): + dispatcher(unknown_input, *args, **kwargs) spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}") - unknown_input = object() - with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): - dispatcher(unknown_input, *args, **kwargs) - if check_scripted_smoke: _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs) @@ -230,18 +227,18 @@ def check_dispatcher( _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs) -def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): +def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): """Checks if the signature of the dispatcher matches the kernel signature.""" - dispatcher_signature = inspect.signature(dispatcher) - dispatcher_params = list(dispatcher_signature.parameters.values())[1:] - - kernel_signature = inspect.signature(kernel) - kernel_params = list(kernel_signature.parameters.values())[1:] + dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:] + kernel_params = list(inspect.signature(kernel).parameters.values())[1:] - if issubclass(input_type, datapoints._datapoint.Datapoint): + if issubclass(input_type, datapoints.Datapoint): # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be # explicitly passed to the kernel. 
- kernel_params = [param for param in kernel_params if param.name not in input_type.__annotations__.keys()] + explicit_metadata = { + datapoints.BoundingBoxes: {"format", "canvas_size"}, + } + kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())] dispatcher_params = iter(dispatcher_params) for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): @@ -264,32 +261,6 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type): assert dispatcher_param == kernel_param -def _check_dispatcher_datapoint_signature_match(dispatcher): - """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class.""" - if dispatcher in {F.resize, F.adjust_brightness}: - return - dispatcher_signature = inspect.signature(dispatcher) - dispatcher_params = list(dispatcher_signature.parameters.values())[1:] - - datapoint_method = getattr(datapoints._datapoint.Datapoint, dispatcher.__name__) - datapoint_signature = inspect.signature(datapoint_method) - datapoint_params = list(datapoint_signature.parameters.values())[1:] - - # Some annotations in the `datapoints._datapoint` module - # are stored as strings. The block below makes them concrete again (non-strings), so they can be compared to the - # natively concrete dispatcher annotations. - datapoint_annotations = get_type_hints(datapoint_method) - for param in datapoint_params: - param._annotation = datapoint_annotations[param.name] - - assert dispatcher_params == datapoint_params - - -def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type): - _check_dispatcher_kernel_signature_match(dispatcher, kernel=kernel, input_type=input_type) - _check_dispatcher_datapoint_signature_match(dispatcher) - - def _check_transform_v1_compatibility(transform, input): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static ``get_params`` method, is scriptable, and the scripted version can be called without error.""" @@ -461,7 +432,7 @@ def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss): *[f"- {name}" for name in names], "", f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).", - f"If not, register explicit no-ops with @_register_explicit_noops({', '.join(names)})", + f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})", ] ) ) @@ -602,7 +573,7 @@ def test_dispatcher(self, size, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("size", OUTPUT_SIZES) @pytest.mark.parametrize("device", cpu_and_cuda()) @@ -800,7 +771,7 @@ def test_noop(self, size, make_input): # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there # is a good reason to break this, feel free to downgrade to an equality check. - if isinstance(input, datapoints._datapoint.Datapoint): + if isinstance(input, datapoints.Datapoint): # We can't test identity directly, since that checks for the identity of the Python object. Since all # datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. 
Thus, we check # that the underlying storage is the same @@ -884,7 +855,7 @@ def test_dispatcher(self, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1067,7 +1038,7 @@ def test_dispatcher(self, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1363,7 +1334,7 @@ def test_dispatcher(self, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1520,7 +1491,7 @@ def test_dispatcher(self, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.rotate, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "make_input", @@ -1971,7 +1942,7 @@ def test_dispatcher(self, kernel, make_input): ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.adjust_brightness, kernel=kernel, input_type=input_type) + check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS) def test_image_correctness(self, brightness_factor): diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 74f20466b99..cef5c360430 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -69,14 +69,15 @@ def sample_inputs(self, *datapoint_types, filter_metadata=True): import itertools for args_kwargs in sample_inputs: - for name in itertools.chain( - datapoint_type.__annotations__.keys(), - # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a - # per-dispatcher level. However, so far there is no option for that. - (f"old_{name}" for name in datapoint_type.__annotations__.keys()), - ): - if name in args_kwargs.kwargs: - del args_kwargs.kwargs[name] + if hasattr(datapoint_type, "__annotations__"): + for name in itertools.chain( + datapoint_type.__annotations__.keys(), + # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a + # per-dispatcher level. However, so far there is no option for that. 
+ (f"old_{name}" for name in datapoint_type.__annotations__.keys()), + ): + if name in args_kwargs.kwargs: + del args_kwargs.kwargs[name] yield args_kwargs diff --git a/torchvision/datapoints/__init__.py b/torchvision/datapoints/__init__.py index fb51f0497ea..03469ca0cde 100644 --- a/torchvision/datapoints/__init__.py +++ b/torchvision/datapoints/__init__.py @@ -1,7 +1,7 @@ from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS from ._bounding_box import BoundingBoxes, BoundingBoxFormat -from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT +from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image from ._mask import Mask from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py index 9ad5e6bc7c0..912cc3bca08 100644 --- a/torchvision/datapoints/_bounding_box.py +++ b/torchvision/datapoints/_bounding_box.py @@ -1,12 +1,11 @@ from __future__ import annotations from enum import Enum -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch -from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms -from ._datapoint import _FillTypeJIT, Datapoint +from ._datapoint import Datapoint class BoundingBoxFormat(Enum): @@ -97,126 +96,3 @@ def wrap_like( def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr(format=self.format, canvas_size=self.canvas_size) - - def horizontal_flip(self) -> BoundingBoxes: - output = self._F.horizontal_flip_bounding_boxes( - self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size - ) - return BoundingBoxes.wrap_like(self, output) - - def vertical_flip(self) -> BoundingBoxes: - output = self._F.vertical_flip_bounding_boxes( - self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size - ) - return BoundingBoxes.wrap_like(self, output) - - def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes: - output, canvas_size = self._F.crop_bounding_boxes( - self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width - ) - return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - - def center_crop(self, output_size: List[int]) -> BoundingBoxes: - output, canvas_size = self._F.center_crop_bounding_boxes( - self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size, output_size=output_size - ) - return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - - def resized_crop( - self, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> BoundingBoxes: - output, canvas_size = self._F.resized_crop_bounding_boxes( - self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size - ) - return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - - def pad( - self, - padding: Union[int, Sequence[int]], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", - ) -> BoundingBoxes: - output, canvas_size = self._F.pad_bounding_boxes( - self.as_subclass(torch.Tensor), - format=self.format, - 
canvas_size=self.canvas_size, - padding=padding, - padding_mode=padding_mode, - ) - return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - - def rotate( - self, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: _FillTypeJIT = None, - ) -> BoundingBoxes: - output, canvas_size = self._F.rotate_bounding_boxes( - self.as_subclass(torch.Tensor), - format=self.format, - canvas_size=self.canvas_size, - angle=angle, - expand=expand, - center=center, - ) - return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size) - - def affine( - self, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - center: Optional[List[float]] = None, - ) -> BoundingBoxes: - output = self._F.affine_bounding_boxes( - self.as_subclass(torch.Tensor), - self.format, - self.canvas_size, - angle, - translate=translate, - scale=scale, - shear=shear, - center=center, - ) - return BoundingBoxes.wrap_like(self, output) - - def perspective( - self, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - coefficients: Optional[List[float]] = None, - ) -> BoundingBoxes: - output = self._F.perspective_bounding_boxes( - self.as_subclass(torch.Tensor), - format=self.format, - canvas_size=self.canvas_size, - startpoints=startpoints, - endpoints=endpoints, - coefficients=coefficients, - ) - return BoundingBoxes.wrap_like(self, output) - - def elastic( - self, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - ) -> BoundingBoxes: - output = self._F.elastic_bounding_boxes( - self.as_subclass(torch.Tensor), self.format, self.canvas_size, displacement=displacement - ) - return BoundingBoxes.wrap_like(self, output) diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 58e15151474..384273301de 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -1,13 +1,11 @@ from __future__ import annotations -from types import ModuleType from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union import PIL.Image import torch from torch._C import DisableTorchFunctionSubclass from torch.types import _device, _dtype, _size -from torchvision.transforms import InterpolationMode D = TypeVar("D", bound="Datapoint") @@ -16,8 +14,6 @@ class Datapoint(torch.Tensor): - __F: Optional[ModuleType] = None - @staticmethod def _to_tensor( data: Any, @@ -99,18 +95,6 @@ def _make_repr(self, **kwargs: Any) -> str: extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items()) return f"{super().__repr__()[:-1]}, {extra_repr})" - @property - def _F(self) -> ModuleType: - # This implements a lazy import of the functional to get around the cyclic import. This import is deferred - # until the first time we need reference to the functional module and it's shared across all instances of - # the class. 
This approach avoids the DataLoader issue described at - # https://github.com/pytorch/vision/pull/6476#discussion_r953588621 - if Datapoint.__F is None: - from ..transforms.v2 import functional - - Datapoint.__F = functional - return Datapoint.__F - # Add properties for common attributes like shape, dtype, device, ndim etc # this way we return the result without passing into __torch_function__ @property @@ -142,114 +126,6 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D: # `BoundingBoxes.clone()`. return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value] - def horizontal_flip(self) -> Datapoint: - return self - - def vertical_flip(self) -> Datapoint: - return self - - def crop(self, top: int, left: int, height: int, width: int) -> Datapoint: - return self - - def center_crop(self, output_size: List[int]) -> Datapoint: - return self - - def resized_crop( - self, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Datapoint: - return self - - def pad( - self, - padding: List[int], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", - ) -> Datapoint: - return self - - def rotate( - self, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: _FillTypeJIT = None, - ) -> Datapoint: - return self - - def affine( - self, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - center: Optional[List[float]] = None, - ) -> Datapoint: - return self - - def perspective( - self, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - coefficients: Optional[List[float]] = None, - ) -> Datapoint: - return self - - def elastic( - self, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - ) -> Datapoint: - return self - - def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint: - return self - - def adjust_saturation(self, saturation_factor: float) -> Datapoint: - return self - - def adjust_contrast(self, contrast_factor: float) -> Datapoint: - return self - - def adjust_sharpness(self, sharpness_factor: float) -> Datapoint: - return self - - def adjust_hue(self, hue_factor: float) -> Datapoint: - return self - - def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint: - return self - - def posterize(self, bits: int) -> Datapoint: - return self - - def solarize(self, threshold: float) -> Datapoint: - return self - - def autocontrast(self) -> Datapoint: - return self - - def equalize(self) -> Datapoint: - return self - - def invert(self) -> Datapoint: - return self - - def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint: - return self - _InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] _InputTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_image.py b/torchvision/datapoints/_image.py index 4a1d6f064ad..dccfc81a605 100644 --- a/torchvision/datapoints/_image.py +++ b/torchvision/datapoints/_image.py @@ -1,12 +1,11 @@ from 
__future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import PIL.Image import torch -from torchvision.transforms.functional import InterpolationMode -from ._datapoint import _FillTypeJIT, Datapoint +from ._datapoint import Datapoint class Image(Datapoint): @@ -56,177 +55,6 @@ def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image: def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - def horizontal_flip(self) -> Image: - output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor)) - return Image.wrap_like(self, output) - - def vertical_flip(self) -> Image: - output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor)) - return Image.wrap_like(self, output) - - def crop(self, top: int, left: int, height: int, width: int) -> Image: - output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width) - return Image.wrap_like(self, output) - - def center_crop(self, output_size: List[int]) -> Image: - output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size) - return Image.wrap_like(self, output) - - def resized_crop( - self, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Image: - output = self._F.resized_crop_image_tensor( - self.as_subclass(torch.Tensor), - top, - left, - height, - width, - size=list(size), - interpolation=interpolation, - antialias=antialias, - ) - return Image.wrap_like(self, output) - - def pad( - self, - padding: List[int], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", - ) -> Image: - output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) - return Image.wrap_like(self, output) - - def rotate( - self, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: _FillTypeJIT = None, - ) -> Image: - output = self._F.rotate_image_tensor( - self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center - ) - return Image.wrap_like(self, output) - - def affine( - self, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - center: Optional[List[float]] = None, - ) -> Image: - output = self._F.affine_image_tensor( - self.as_subclass(torch.Tensor), - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - return Image.wrap_like(self, output) - - def perspective( - self, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - coefficients: Optional[List[float]] = None, - ) -> Image: - output = self._F.perspective_image_tensor( - self.as_subclass(torch.Tensor), - startpoints, - endpoints, - interpolation=interpolation, - fill=fill, - coefficients=coefficients, - ) - return Image.wrap_like(self, output) - - def elastic( - self, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = 
InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - ) -> Image: - output = self._F.elastic_image_tensor( - self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill - ) - return Image.wrap_like(self, output) - - def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image: - output = self._F.rgb_to_grayscale_image_tensor( - self.as_subclass(torch.Tensor), num_output_channels=num_output_channels - ) - return Image.wrap_like(self, output) - - def adjust_saturation(self, saturation_factor: float) -> Image: - output = self._F.adjust_saturation_image_tensor( - self.as_subclass(torch.Tensor), saturation_factor=saturation_factor - ) - return Image.wrap_like(self, output) - - def adjust_contrast(self, contrast_factor: float) -> Image: - output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor) - return Image.wrap_like(self, output) - - def adjust_sharpness(self, sharpness_factor: float) -> Image: - output = self._F.adjust_sharpness_image_tensor( - self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor - ) - return Image.wrap_like(self, output) - - def adjust_hue(self, hue_factor: float) -> Image: - output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor) - return Image.wrap_like(self, output) - - def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: - output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain) - return Image.wrap_like(self, output) - - def posterize(self, bits: int) -> Image: - output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits) - return Image.wrap_like(self, output) - - def solarize(self, threshold: float) -> Image: - output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold) - return Image.wrap_like(self, output) - - def autocontrast(self) -> Image: - output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor)) - return Image.wrap_like(self, output) - - def equalize(self) -> Image: - output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor)) - return Image.wrap_like(self, output) - - def invert(self) -> Image: - output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor)) - return Image.wrap_like(self, output) - - def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image: - output = self._F.gaussian_blur_image_tensor( - self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma - ) - return Image.wrap_like(self, output) - - def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image: - output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) - return Image.wrap_like(self, output) - _ImageType = Union[torch.Tensor, PIL.Image.Image, Image] _ImageTypeJIT = torch.Tensor diff --git a/torchvision/datapoints/_mask.py b/torchvision/datapoints/_mask.py index 4a75ecb5373..2b95eca72e2 100644 --- a/torchvision/datapoints/_mask.py +++ b/torchvision/datapoints/_mask.py @@ -1,12 +1,11 @@ from __future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import PIL.Image import torch -from torchvision.transforms import InterpolationMode -from ._datapoint import _FillTypeJIT, Datapoint +from ._datapoint import Datapoint class Mask(Datapoint): @@ -50,95 +49,3 @@ def wrap_like( tensor: torch.Tensor, ) -> Mask: return cls._wrap(tensor) 
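
# ---- Editor's sketch (illustrative, not part of the patch) ------------------------
# The per-datapoint methods removed below are now reached through the functional
# dispatchers instead. Assuming the kernel registrations added elsewhere in this
# patch, a call such as the following presumably still preserves the datapoint type:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

mask = datapoints.Mask(torch.zeros(1, 16, 16, dtype=torch.uint8))
flipped = F.horizontal_flip(mask)  # dispatched to horizontal_flip_mask via the registry
assert isinstance(flipped, datapoints.Mask)
# ---- End of editor's sketch --------------------------------------------------------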
- - def horizontal_flip(self) -> Mask: - output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor)) - return Mask.wrap_like(self, output) - - def vertical_flip(self) -> Mask: - output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor)) - return Mask.wrap_like(self, output) - - def crop(self, top: int, left: int, height: int, width: int) -> Mask: - output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width) - return Mask.wrap_like(self, output) - - def center_crop(self, output_size: List[int]) -> Mask: - output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size) - return Mask.wrap_like(self, output) - - def resized_crop( - self, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Mask: - output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size) - return Mask.wrap_like(self, output) - - def pad( - self, - padding: List[int], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", - ) -> Mask: - output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill) - return Mask.wrap_like(self, output) - - def rotate( - self, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: _FillTypeJIT = None, - ) -> Mask: - output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill) - return Mask.wrap_like(self, output) - - def affine( - self, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - center: Optional[List[float]] = None, - ) -> Mask: - output = self._F.affine_mask( - self.as_subclass(torch.Tensor), - angle, - translate=translate, - scale=scale, - shear=shear, - fill=fill, - center=center, - ) - return Mask.wrap_like(self, output) - - def perspective( - self, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - coefficients: Optional[List[float]] = None, - ) -> Mask: - output = self._F.perspective_mask( - self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients - ) - return Mask.wrap_like(self, output) - - def elastic( - self, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - ) -> Mask: - output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill) - return Mask.wrap_like(self, output) diff --git a/torchvision/datapoints/_video.py b/torchvision/datapoints/_video.py index b12dd5480eb..11d6e2a854d 100644 --- a/torchvision/datapoints/_video.py +++ b/torchvision/datapoints/_video.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch -from torchvision.transforms.functional import InterpolationMode -from ._datapoint import _FillTypeJIT, Datapoint +from ._datapoint import Datapoint class Video(Datapoint): @@ -46,171 +45,6 @@ def wrap_like(cls, other: Video, tensor: torch.Tensor) 
-> Video: def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr() - def horizontal_flip(self) -> Video: - output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor)) - return Video.wrap_like(self, output) - - def vertical_flip(self) -> Video: - output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor)) - return Video.wrap_like(self, output) - - def crop(self, top: int, left: int, height: int, width: int) -> Video: - output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width) - return Video.wrap_like(self, output) - - def center_crop(self, output_size: List[int]) -> Video: - output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size) - return Video.wrap_like(self, output) - - def resized_crop( - self, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", - ) -> Video: - output = self._F.resized_crop_video( - self.as_subclass(torch.Tensor), - top, - left, - height, - width, - size=list(size), - interpolation=interpolation, - antialias=antialias, - ) - return Video.wrap_like(self, output) - - def pad( - self, - padding: List[int], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", - ) -> Video: - output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode) - return Video.wrap_like(self, output) - - def rotate( - self, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: _FillTypeJIT = None, - ) -> Video: - output = self._F.rotate_video( - self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center - ) - return Video.wrap_like(self, output) - - def affine( - self, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: _FillTypeJIT = None, - center: Optional[List[float]] = None, - ) -> Video: - output = self._F.affine_video( - self.as_subclass(torch.Tensor), - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - return Video.wrap_like(self, output) - - def perspective( - self, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - coefficients: Optional[List[float]] = None, - ) -> Video: - output = self._F.perspective_video( - self.as_subclass(torch.Tensor), - startpoints, - endpoints, - interpolation=interpolation, - fill=fill, - coefficients=coefficients, - ) - return Video.wrap_like(self, output) - - def elastic( - self, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: _FillTypeJIT = None, - ) -> Video: - output = self._F.elastic_video( - self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill - ) - return Video.wrap_like(self, output) - - def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video: - output = self._F.rgb_to_grayscale_image_tensor( - self.as_subclass(torch.Tensor), num_output_channels=num_output_channels - ) - return Video.wrap_like(self, 
output) - - def adjust_saturation(self, saturation_factor: float) -> Video: - output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor) - return Video.wrap_like(self, output) - - def adjust_contrast(self, contrast_factor: float) -> Video: - output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor) - return Video.wrap_like(self, output) - - def adjust_sharpness(self, sharpness_factor: float) -> Video: - output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor) - return Video.wrap_like(self, output) - - def adjust_hue(self, hue_factor: float) -> Video: - output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor) - return Video.wrap_like(self, output) - - def adjust_gamma(self, gamma: float, gain: float = 1) -> Video: - output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain) - return Video.wrap_like(self, output) - - def posterize(self, bits: int) -> Video: - output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits) - return Video.wrap_like(self, output) - - def solarize(self, threshold: float) -> Video: - output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold) - return Video.wrap_like(self, output) - - def autocontrast(self) -> Video: - output = self._F.autocontrast_video(self.as_subclass(torch.Tensor)) - return Video.wrap_like(self, output) - - def equalize(self) -> Video: - output = self._F.equalize_video(self.as_subclass(torch.Tensor)) - return Video.wrap_like(self, output) - - def invert(self) -> Video: - output = self._F.invert_video(self.as_subclass(torch.Tensor)) - return Video.wrap_like(self, output) - - def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video: - output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma) - return Video.wrap_like(self, output) - - def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video: - output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace) - return Video.wrap_like(self, output) - _VideoType = Union[torch.Tensor, Video] _VideoTypeJIT = torch.Tensor diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index d70be15985d..95b4ed93786 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -25,7 +25,7 @@ def erase( if torch.jit.is_scripting() or is_simple_tensor(inpt): return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(erase, type(inpt)) return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) elif isinstance(inpt, PIL.Image.Image): diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 57ef96e879a..99dc1936259 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -13,29 +13,7 @@ from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor -def _rgb_to_grayscale_image_tensor( - image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True -) -> torch.Tensor: - if image.shape[-3] == 1: - 
return image.clone() - - r, g, b = image.unbind(dim=-3) - l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) - l_img = l_img.unsqueeze(dim=-3) - if preserve_dtype: - l_img = l_img.to(image.dtype) - if num_output_channels == 3: - l_img = l_img.expand(image.shape) - return l_img - - -def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: - return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) - - -rgb_to_grayscale_image_pil = _FP.to_grayscale - - +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video) def rgb_to_grayscale( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: @@ -45,8 +23,9 @@ def rgb_to_grayscale( raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") if torch.jit.is_scripting() or is_simple_tensor(inpt): return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.rgb_to_grayscale(num_output_channels=num_output_channels) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(rgb_to_grayscale, type(inpt)) + return kernel(inpt, num_output_channels=num_output_channels) elif isinstance(inpt, PIL.Image.Image): return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) else: @@ -61,6 +40,30 @@ def rgb_to_grayscale( to_grayscale = rgb_to_grayscale +def _rgb_to_grayscale_image_tensor( + image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True +) -> torch.Tensor: + if image.shape[-3] == 1: + return image.clone() + + r, g, b = image.unbind(dim=-3) + l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.unsqueeze(dim=-3) + if preserve_dtype: + l_img = l_img.to(image.dtype) + if num_output_channels == 3: + l_img = l_img.expand(image.shape) + return l_img + + +@_register_kernel_internal(rgb_to_grayscale, datapoints.Image) +def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) + + +rgb_to_grayscale_image_pil = _FP.to_grayscale + + def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: ratio = float(ratio) fp = image1.is_floating_point() @@ -76,7 +79,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) if torch.jit.is_scripting() or is_simple_tensor(inpt): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(adjust_brightness, type(inpt)) return kernel(inpt, brightness_factor=brightness_factor) elif isinstance(inpt, PIL.Image.Image): @@ -112,6 +115,26 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(adjust_saturation) + + if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, 
datapoints.Datapoint)): + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(adjust_saturation, type(inpt)) + return kernel(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@_register_kernel_internal(adjust_saturation, datapoints.Image) def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") @@ -133,22 +156,23 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float adjust_saturation_image_pil = _FP.adjust_saturation +@_register_kernel_internal(adjust_saturation, datapoints.Video) def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) -def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_saturation) + _log_api_usage_once(adjust_contrast) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): - return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_saturation(saturation_factor=saturation_factor) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(adjust_contrast, type(inpt)) + return kernel(inpt, contrast_factor=contrast_factor) elif isinstance(inpt, PIL.Image.Image): - return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -156,6 +180,7 @@ def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) ) +@_register_kernel_internal(adjust_contrast, datapoints.Image) def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") @@ -177,20 +202,23 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> adjust_contrast_image_pil = _FP.adjust_contrast +@_register_kernel_internal(adjust_contrast, datapoints.Video) def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) -def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> 
datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_contrast) + _log_api_usage_once(adjust_sharpness) - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_contrast(contrast_factor=contrast_factor) + if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)): + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(adjust_sharpness, type(inpt)) + return kernel(inpt, sharpness_factor=sharpness_factor) elif isinstance(inpt, PIL.Image.Image): - return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) + return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -198,6 +226,7 @@ def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> d ) +@_register_kernel_internal(adjust_sharpness, datapoints.Image) def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: num_channels, height, width = image.shape[-3:] if num_channels not in (1, 3): @@ -253,22 +282,23 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) adjust_sharpness_image_pil = _FP.adjust_sharpness +@_register_kernel_internal(adjust_sharpness, datapoints.Video) def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) -def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_sharpness) + _log_api_usage_once(adjust_hue) - if isinstance(inpt, torch.Tensor) and ( - torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) - ): - return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(adjust_hue, type(inpt)) + return kernel(inpt, hue_factor=hue_factor) elif isinstance(inpt, PIL.Image.Image): - return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) + return adjust_hue_image_pil(inpt, hue_factor=hue_factor) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -340,6 +370,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) +@_register_kernel_internal(adjust_hue, datapoints.Image) def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") @@ -370,20 +401,23 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten adjust_hue_image_pil = _FP.adjust_hue 
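
# ---- Editor's sketch (illustrative, not part of the patch) ------------------------
# A minimal model of the registry helpers these hunks rely on. The real helpers live
# in torchvision/transforms/v2/functional/_utils.py and additionally handle datapoint
# wrapping (cf. datapoint_wrapper=False), no-op registration and error reporting; the
# bodies below are assumptions kept only to show the dispatch flow.

from typing import Callable, Dict, Type

_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}

def _register_kernel_internal(dispatcher: Callable, datapoint_cls: Type) -> Callable:
    # Used as a decorator: stores the kernel under (dispatcher, datapoint type).
    def decorator(kernel: Callable) -> Callable:
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = kernel
        return kernel
    return decorator

def _get_kernel(dispatcher: Callable, datapoint_cls: Type) -> Callable:
    # Dispatchers call this with type(inpt) to look up the registered kernel.
    return _KERNEL_REGISTRY[dispatcher][datapoint_cls]
# ---- End of editor's sketch --------------------------------------------------------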
+@_register_kernel_internal(adjust_hue, datapoints.Video) def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image_tensor(video, hue_factor=hue_factor) -def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_hue) + _log_api_usage_once(adjust_gamma) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_hue(hue_factor=hue_factor) + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(adjust_gamma, type(inpt)) + return kernel(inpt, gamma=gamma, gain=gain) elif isinstance(inpt, PIL.Image.Image): - return adjust_hue_image_pil(inpt, hue_factor=hue_factor) + return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -391,6 +425,7 @@ def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints. ) +@_register_kernel_internal(adjust_gamma, datapoints.Image) def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: if gamma < 0: raise ValueError("Gamma should be a non-negative real number") @@ -413,20 +448,23 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 adjust_gamma_image_pil = _FP.adjust_gamma +@_register_kernel_internal(adjust_gamma, datapoints.Video) def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) -def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(adjust_gamma) + _log_api_usage_once(posterize) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.adjust_gamma(gamma=gamma, gain=gain) + return posterize_image_tensor(inpt, bits=bits) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(posterize, type(inpt)) + return kernel(inpt, bits=bits) elif isinstance(inpt, PIL.Image.Image): - return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) + return posterize_image_pil(inpt, bits=bits) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -434,6 +472,7 @@ def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) ) +@_register_kernel_internal(posterize, datapoints.Image) def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: if image.is_floating_point(): levels = 1 << bits @@ -450,20 +489,23 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: posterize_image_pil = _FP.posterize +@_register_kernel_internal(posterize, datapoints.Video) def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: return 
posterize_image_tensor(video, bits=bits) -def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(posterize) + _log_api_usage_once(solarize) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return posterize_image_tensor(inpt, bits=bits) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.posterize(bits=bits) + return solarize_image_tensor(inpt, threshold=threshold) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(solarize, type(inpt)) + return kernel(inpt, threshold=threshold) elif isinstance(inpt, PIL.Image.Image): - return posterize_image_pil(inpt, bits=bits) + return solarize_image_pil(inpt, threshold=threshold) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -471,6 +513,7 @@ def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTyp ) +@_register_kernel_internal(solarize, datapoints.Image) def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: if threshold > _max_value(image.dtype): raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") @@ -481,20 +524,25 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor solarize_image_pil = _FP.solarize +@_register_kernel_internal(solarize, datapoints.Video) def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: return solarize_image_tensor(video, threshold=threshold) -def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(solarize) + _log_api_usage_once(autocontrast) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return solarize_image_tensor(inpt, threshold=threshold) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.solarize(threshold=threshold) + return autocontrast_image_tensor(inpt) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(autocontrast, type(inpt)) + return kernel( + inpt, + ) elif isinstance(inpt, PIL.Image.Image): - return solarize_image_pil(inpt, threshold=threshold) + return autocontrast_image_pil(inpt) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -502,6 +550,7 @@ def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._In ) +@_register_kernel_internal(autocontrast, datapoints.Image) def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: c = image.shape[-3] if c not in [1, 3]: @@ -534,20 +583,25 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: autocontrast_image_pil = _FP.autocontrast +@_register_kernel_internal(autocontrast, datapoints.Video) def autocontrast_video(video: torch.Tensor) -> torch.Tensor: return autocontrast_image_tensor(video) -def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(autocontrast) 
+ _log_api_usage_once(equalize) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return autocontrast_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.autocontrast() + return equalize_image_tensor(inpt) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(equalize, type(inpt)) + return kernel( + inpt, + ) elif isinstance(inpt, PIL.Image.Image): - return autocontrast_image_pil(inpt) + return equalize_image_pil(inpt) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -555,6 +609,7 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ) +@_register_kernel_internal(equalize, datapoints.Image) def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: return image @@ -627,20 +682,25 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: equalize_image_pil = _FP.equalize +@_register_kernel_internal(equalize, datapoints.Video) def equalize_video(video: torch.Tensor) -> torch.Tensor: return equalize_image_tensor(video) -def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(equalize) + _log_api_usage_once(invert) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return equalize_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.equalize() + return invert_image_tensor(inpt) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(invert, type(inpt)) + return kernel( + inpt, + ) elif isinstance(inpt, PIL.Image.Image): - return equalize_image_pil(inpt) + return invert_image_pil(inpt) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -648,6 +708,7 @@ def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ) +@_register_kernel_internal(invert, datapoints.Image) def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.is_floating_point(): return 1.0 - image @@ -661,22 +722,6 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: invert_image_pil = _FP.invert +@_register_kernel_internal(invert, datapoints.Video) def invert_video(video: torch.Tensor) -> torch.Tensor: return invert_image_tensor(video) - - -def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(invert) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return invert_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.invert() - elif isinstance(inpt, PIL.Image.Image): - return invert_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 2d7d83e31f6..21f2aa8df0a 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -45,6 +45,27 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp return interpolation +def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(horizontal_flip) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return horizontal_flip_image_tensor(inpt) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(horizontal_flip, type(inpt)) + return kernel( + inpt, + ) + elif isinstance(inpt, PIL.Image.Image): + return horizontal_flip_image_pil(inpt) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + +@_register_kernel_internal(horizontal_flip, datapoints.Image) def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-1) @@ -53,6 +74,7 @@ def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) +@_register_kernel_internal(horizontal_flip, datapoints.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image_tensor(mask) @@ -74,20 +96,32 @@ def horizontal_flip_bounding_boxes( return bounding_boxes.reshape(shape) +@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes: + output = horizontal_flip_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output) + + +@_register_kernel_internal(horizontal_flip, datapoints.Video) def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: return horizontal_flip_image_tensor(video) -def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: +def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(horizontal_flip) + _log_api_usage_once(vertical_flip) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return horizontal_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.horizontal_flip() + return vertical_flip_image_tensor(inpt) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(vertical_flip, type(inpt)) + return kernel( + inpt, + ) elif isinstance(inpt, PIL.Image.Image): - return horizontal_flip_image_pil(inpt) + return vertical_flip_image_pil(inpt) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -95,6 +129,7 @@ def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ) +@_register_kernel_internal(vertical_flip, datapoints.Image) def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.flip(-2) @@ -103,6 +138,7 @@ def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: return _FP.vflip(image) +@_register_kernel_internal(vertical_flip, datapoints.Mask) def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image_tensor(mask) @@ -124,25 +160,17 @@ def vertical_flip_bounding_boxes( return 
bounding_boxes.reshape(shape) -def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(video) +@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes: + output = vertical_flip_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output) -def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(vertical_flip) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return vertical_flip_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.vertical_flip() - elif isinstance(inpt, PIL.Image.Image): - return vertical_flip_image_pil(inpt) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." - ) +@_register_kernel_internal(vertical_flip, datapoints.Video) +def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(video) # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are @@ -175,7 +203,7 @@ def resize( _log_api_usage_once(resize) if torch.jit.is_scripting() or is_simple_tensor(inpt): return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(resize, type(inpt)) return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) elif isinstance(inpt, PIL.Image.Image): @@ -353,6 +381,61 @@ def resize_video( return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) +def affine( + inpt: datapoints._InputTypeJIT, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + fill: datapoints._FillTypeJIT = None, + center: Optional[List[float]] = None, +) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(affine) + + # TODO: consider deprecating integers from angle and shear on the future + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return affine_image_tensor( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(affine, type(inpt)) + return kernel( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + elif isinstance(inpt, PIL.Image.Image): + return affine_image_pil( + inpt, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." 
+ ) + + def _affine_parse_args( angle: Union[int, float], translate: List[float], @@ -601,6 +684,7 @@ def _affine_grid( return output_grid.view(1, oh, ow, 2) +@_register_kernel_internal(affine, datapoints.Image) def affine_image_tensor( image: torch.Tensor, angle: Union[int, float], @@ -790,6 +874,29 @@ def affine_bounding_boxes( return out_box +@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _affine_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + center: Optional[List[float]] = None, + **kwargs, +) -> datapoints.BoundingBoxes: + output = affine_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output) + + def affine_mask( mask: torch.Tensor, angle: Union[int, float], @@ -822,6 +929,30 @@ def affine_mask( return output +@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False) +def _affine_mask_dispatch( + inpt: datapoints.Mask, + angle: Union[int, float], + translate: List[float], + scale: float, + shear: List[float], + fill: datapoints._FillTypeJIT = None, + center: Optional[List[float]] = None, + **kwargs, +) -> datapoints.Mask: + output = affine_mask( + inpt.as_subclass(torch.Tensor), + angle=angle, + translate=translate, + scale=scale, + shear=shear, + fill=fill, + center=center, + ) + return datapoints.Mask.wrap_like(inpt, output) + + +@_register_kernel_internal(affine, datapoints.Video) def affine_video( video: torch.Tensor, angle: Union[int, float], @@ -844,46 +975,24 @@ def affine_video( ) -def affine( +def rotate( inpt: datapoints._InputTypeJIT, - angle: Union[int, float], - translate: List[float], - scale: float, - shear: List[float], + angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - fill: datapoints._FillTypeJIT = None, + expand: bool = False, center: Optional[List[float]] = None, + fill: datapoints._FillTypeJIT = None, ) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(affine) + _log_api_usage_once(rotate) - # TODO: consider deprecating integers from angle and shear on the future if torch.jit.is_scripting() or is_simple_tensor(inpt): - return affine_image_tensor( - inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.affine( - angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center - ) + return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(rotate, type(inpt)) + return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) elif isinstance(inpt, PIL.Image.Image): - return affine_image_pil( - inpt, - angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) + return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -891,6 +1000,7 @@ def affine( ) 
+@_register_kernel_internal(rotate, datapoints.Image) def rotate_image_tensor( image: torch.Tensor, angle: float, @@ -978,6 +1088,21 @@ def rotate_bounding_boxes( ) +@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _rotate_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs +) -> datapoints.BoundingBoxes: + output, canvas_size = rotate_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + angle=angle, + expand=expand, + center=center, + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + + def rotate_mask( mask: torch.Tensor, angle: float, @@ -1006,6 +1131,20 @@ def rotate_mask( return output +@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False) +def _rotate_mask_dispatch( + inpt: datapoints.Mask, + angle: float, + expand: bool = False, + center: Optional[List[float]] = None, + fill: datapoints._FillTypeJIT = None, + **kwargs, +) -> datapoints.Mask: + output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center) + return datapoints.Mask.wrap_like(inpt, output) + + +@_register_kernel_internal(rotate, datapoints.Video) def rotate_video( video: torch.Tensor, angle: float, @@ -1017,23 +1156,23 @@ def rotate_video( return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) -def rotate( +def pad( inpt: datapoints._InputTypeJIT, - angle: float, - interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[float]] = None, - fill: datapoints._FillTypeJIT = None, + padding: List[int], + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", ) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(rotate) + _log_api_usage_once(pad) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) + + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(pad, type(inpt)) + return kernel(inpt, padding, fill=fill, padding_mode=padding_mode) elif isinstance(inpt, PIL.Image.Image): - return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1065,6 +1204,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: return [pad_left, pad_right, pad_top, pad_bottom] +@_register_kernel_internal(pad, datapoints.Image) def pad_image_tensor( image: torch.Tensor, padding: List[int], @@ -1166,6 +1306,7 @@ def _pad_with_vector_fill( pad_image_pil = _FP.pad +@_register_kernel_internal(pad, datapoints.Mask) def pad_mask( mask: torch.Tensor, padding: List[int], @@ -1219,6 +1360,21 @@ def pad_bounding_boxes( return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size +@_register_kernel_internal(pad, datapoints.BoundingBoxes, 
datapoint_wrapper=False) +def _pad_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs +) -> datapoints.BoundingBoxes: + output, canvas_size = pad_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + padding=padding, + padding_mode=padding_mode, + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + + +@_register_kernel_internal(pad, datapoints.Video) def pad_video( video: torch.Tensor, padding: List[int], @@ -1228,22 +1384,17 @@ def pad_video( return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) -def pad( - inpt: datapoints._InputTypeJIT, - padding: List[int], - fill: Optional[Union[int, float, List[float]]] = None, - padding_mode: str = "constant", -) -> datapoints._InputTypeJIT: +def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(pad) + _log_api_usage_once(crop) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) - - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.pad(padding, fill=fill, padding_mode=padding_mode) + return crop_image_tensor(inpt, top, left, height, width) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(crop, type(inpt)) + return kernel(inpt, top, left, height, width) elif isinstance(inpt, PIL.Image.Image): - return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) + return crop_image_pil(inpt, top, left, height, width) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1251,6 +1402,7 @@ def pad( ) +@_register_kernel_internal(crop, datapoints.Image) def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: h, w = image.shape[-2:] @@ -1293,6 +1445,17 @@ def crop_bounding_boxes( return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size +@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _crop_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int +) -> datapoints.BoundingBoxes: + output, canvas_size = crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + + +@_register_kernel_internal(crop, datapoints.Mask) def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -1308,20 +1471,32 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) return output +@_register_kernel_internal(crop, datapoints.Video) def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: return crop_image_tensor(video, top, left, height, width) -def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT: +def perspective( + inpt: datapoints._InputTypeJIT, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: datapoints._FillTypeJIT = None, + coefficients: 
Optional[List[float]] = None, +) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(crop) - + _log_api_usage_once(perspective) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return crop_image_tensor(inpt, top, left, height, width) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.crop(top, left, height, width) + return perspective_image_tensor( + inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(perspective, type(inpt)) + return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients) elif isinstance(inpt, PIL.Image.Image): - return crop_image_pil(inpt, top, left, height, width) + return perspective_image_pil( + inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients + ) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1376,6 +1551,7 @@ def _perspective_coefficients( raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.") +@_register_kernel_internal(perspective, datapoints.Image) def perspective_image_tensor( image: torch.Tensor, startpoints: Optional[List[List[int]]], @@ -1530,6 +1706,25 @@ def perspective_bounding_boxes( ).reshape(original_shape) +@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _perspective_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + coefficients: Optional[List[float]] = None, + **kwargs, +) -> datapoints.BoundingBoxes: + output = perspective_bounding_boxes( + inpt.as_subclass(torch.Tensor), + format=inpt.format, + canvas_size=inpt.canvas_size, + startpoints=startpoints, + endpoints=endpoints, + coefficients=coefficients, + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output) + + def perspective_mask( mask: torch.Tensor, startpoints: Optional[List[List[int]]], @@ -1553,6 +1748,26 @@ def perspective_mask( return output +@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False) +def _perspective_mask_dispatch( + inpt: datapoints.Mask, + startpoints: Optional[List[List[int]]], + endpoints: Optional[List[List[int]]], + fill: datapoints._FillTypeJIT = None, + coefficients: Optional[List[float]] = None, + **kwargs, +) -> datapoints.Mask: + output = perspective_mask( + inpt.as_subclass(torch.Tensor), + startpoints=startpoints, + endpoints=endpoints, + fill=fill, + coefficients=coefficients, + ) + return datapoints.Mask.wrap_like(inpt, output) + + +@_register_kernel_internal(perspective, datapoints.Video) def perspective_video( video: torch.Tensor, startpoints: Optional[List[List[int]]], @@ -1566,28 +1781,25 @@ def perspective_video( ) -def perspective( +def elastic( inpt: datapoints._InputTypeJIT, - startpoints: Optional[List[List[int]]], - endpoints: Optional[List[List[int]]], + displacement: torch.Tensor, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: datapoints._FillTypeJIT = None, - coefficients: Optional[List[float]] = None, ) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(perspective) + _log_api_usage_once(elastic) + + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + if torch.jit.is_scripting() 
or is_simple_tensor(inpt): - return perspective_image_tensor( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.perspective( - startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) + return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(elastic, type(inpt)) + return kernel(inpt, displacement, interpolation=interpolation, fill=fill) elif isinstance(inpt, PIL.Image.Image): - return perspective_image_pil( - inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients - ) + return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1595,6 +1807,10 @@ def perspective( ) +elastic_transform = elastic + + +@_register_kernel_internal(elastic, datapoints.Image) def elastic_image_tensor( image: torch.Tensor, displacement: torch.Tensor, @@ -1726,6 +1942,16 @@ def elastic_bounding_boxes( ).reshape(original_shape) +@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _elastic_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs +) -> datapoints.BoundingBoxes: + output = elastic_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output) + + def elastic_mask( mask: torch.Tensor, displacement: torch.Tensor, @@ -1745,6 +1971,15 @@ def elastic_mask( return output +@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False) +def _elastic_mask_dispatch( + inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs +) -> datapoints.Mask: + output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill) + return datapoints.Mask.wrap_like(inpt, output) + + +@_register_kernel_internal(elastic, datapoints.Video) def elastic_video( video: torch.Tensor, displacement: torch.Tensor, @@ -1754,24 +1989,17 @@ def elastic_video( return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) -def elastic( - inpt: datapoints._InputTypeJIT, - displacement: torch.Tensor, - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - fill: datapoints._FillTypeJIT = None, -) -> datapoints._InputTypeJIT: +def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(elastic) - - if not isinstance(displacement, torch.Tensor): - raise TypeError("Argument displacement should be a Tensor") + _log_api_usage_once(center_crop) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.elastic(displacement, interpolation=interpolation, fill=fill) + return center_crop_image_tensor(inpt, output_size) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(center_crop, type(inpt)) + return kernel(inpt, output_size) elif isinstance(inpt, PIL.Image.Image): - return elastic_image_pil(inpt, displacement, 
interpolation=interpolation, fill=fill) + return center_crop_image_pil(inpt, output_size) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1779,9 +2007,6 @@ def elastic( ) -elastic_transform = elastic - - def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): s = int(output_size) @@ -1809,6 +2034,7 @@ def _center_crop_compute_crop_anchor( return crop_top, crop_left +@_register_kernel_internal(center_crop, datapoints.Image) def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) shape = image.shape @@ -1858,6 +2084,17 @@ def center_crop_bounding_boxes( ) +@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _center_crop_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, output_size: List[int] +) -> datapoints.BoundingBoxes: + output, canvas_size = center_crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + + +@_register_kernel_internal(center_crop, datapoints.Mask) def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) @@ -1873,20 +2110,33 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor return output +@_register_kernel_internal(center_crop, datapoints.Video) def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: return center_crop_image_tensor(video, output_size) -def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT: +def resized_crop( + inpt: datapoints._InputTypeJIT, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + antialias: Optional[Union[str, bool]] = "warn", +) -> datapoints._InputTypeJIT: if not torch.jit.is_scripting(): - _log_api_usage_once(center_crop) + _log_api_usage_once(resized_crop) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return center_crop_image_tensor(inpt, output_size) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.center_crop(output_size) + return resized_crop_image_tensor( + inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation + ) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(resized_crop, type(inpt)) + return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) elif isinstance(inpt, PIL.Image.Image): - return center_crop_image_pil(inpt, output_size) + return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -1894,6 +2144,7 @@ def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datap ) +@_register_kernel_internal(resized_crop, datapoints.Image) def resized_crop_image_tensor( image: torch.Tensor, top: int, @@ -1931,8 +2182,18 @@ def resized_crop_bounding_boxes( width: int, size: List[int], ) -> Tuple[torch.Tensor, Tuple[int, int]]: - bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width) - return 
resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size) + bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width) + return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size) + + +@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False) +def _resized_crop_bounding_boxes_dispatch( + inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs +) -> datapoints.BoundingBoxes: + output, canvas_size = resized_crop_bounding_boxes( + inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size + ) + return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) def resized_crop_mask( @@ -1947,6 +2208,17 @@ def resized_crop_mask( return resize_mask(mask, size) +@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False) +def _resized_crop_mask_dispatch( + inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs +) -> datapoints.Mask: + output = resized_crop_mask( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size + ) + return datapoints.Mask.wrap_like(inpt, output) + + +@_register_kernel_internal(resized_crop, datapoints.Video) def resized_crop_video( video: torch.Tensor, top: int, @@ -1962,34 +2234,6 @@ def resized_crop_video( ) -def resized_crop( - inpt: datapoints._InputTypeJIT, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", -) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(resized_crop) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return resized_crop_image_tensor( - inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation - ) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation) - elif isinstance(inpt, PIL.Image.Image): - return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) - - @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) def five_crop( inpt: datapoints._InputTypeJIT, size: List[int] @@ -2005,7 +2249,7 @@ def five_crop( if torch.jit.is_scripting() or is_simple_tensor(inpt): return five_crop_image_tensor(inpt, size) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(five_crop, type(inpt)) return kernel(inpt, size) elif isinstance(inpt, PIL.Image.Image): @@ -2096,7 +2340,7 @@ def ten_crop( if torch.jit.is_scripting() or is_simple_tensor(inpt): return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(ten_crop, type(inpt)) return kernel(inpt, size, vertical_flip=vertical_flip) elif isinstance(inpt, PIL.Image.Image): diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 7148c003b28..a4bfe7df8e4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -18,7 +18,7 @@ def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJI if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_dimensions_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(get_dimensions, type(inpt)) return kernel(inpt) elif isinstance(inpt, PIL.Image.Image): @@ -58,7 +58,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_channels_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(get_num_channels, type(inpt)) return kernel(inpt) elif isinstance(inpt, PIL.Image.Image): @@ -101,7 +101,7 @@ def get_size(inpt: datapoints._InputTypeJIT) -> List[int]: if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_size_image_tensor(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(get_size, type(inpt)) return kernel(inpt) elif isinstance(inpt, PIL.Image.Image): @@ -151,7 +151,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: if torch.jit.is_scripting() or is_simple_tensor(inpt): return get_num_frames_video(inpt) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(get_num_frames, type(inpt)) return kernel(inpt) else: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index c7ca4fe1d62..90a3e44e9d3 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -32,7 +32,7 @@ def normalize( _log_api_usage_once(normalize) if torch.jit.is_scripting() or is_simple_tensor(inpt): return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(normalize, type(inpt)) return kernel(inpt, mean=mean, std=std, inplace=inplace) else: @@ -82,6 +82,27 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in return normalize_image_tensor(video, mean, std, inplace=inplace) 
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) +def gaussian_blur( + inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> datapoints._InputTypeJIT: + if not torch.jit.is_scripting(): + _log_api_usage_once(gaussian_blur) + + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) + elif isinstance(inpt, datapoints.Datapoint): + kernel = _get_kernel(gaussian_blur, type(inpt)) + return kernel(inpt, kernel_size=kernel_size, sigma=sigma) + elif isinstance(inpt, PIL.Image.Image): + return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma) x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) @@ -98,6 +119,7 @@ def _get_gaussian_kernel2d( return kernel2d +@_register_kernel_internal(gaussian_blur, datapoints.Image) def gaussian_blur_image_tensor( image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> torch.Tensor: @@ -171,31 +193,13 @@ def gaussian_blur_image_pil( return to_pil_image(output, mode=image.mode) +@_register_kernel_internal(gaussian_blur, datapoints.Video) def gaussian_blur_video( video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> torch.Tensor: return gaussian_blur_image_tensor(video, kernel_size, sigma) -def gaussian_blur( - inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None -) -> datapoints._InputTypeJIT: - if not torch.jit.is_scripting(): - _log_api_usage_once(gaussian_blur) - - if torch.jit.is_scripting() or is_simple_tensor(inpt): - return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma) - elif isinstance(inpt, PIL.Image.Image): - return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma) - else: - raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " - f"but got {type(inpt)} instead." 
- ) - - def to_dtype( inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False ) -> datapoints._InputTypeJIT: @@ -204,7 +208,7 @@ def to_dtype( if torch.jit.is_scripting() or is_simple_tensor(inpt): return to_dtype_image_tensor(inpt, dtype, scale=scale) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(to_dtype, type(inpt)) return kernel(inpt, dtype, scale=scale) else: diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 81d6793f179..52c745f9901 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -17,12 +17,12 @@ def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) if torch.jit.is_scripting() or is_simple_tensor(inpt): return uniform_temporal_subsample_video(inpt, num_samples) - elif isinstance(inpt, datapoints._datapoint.Datapoint): + elif isinstance(inpt, datapoints.Datapoint): kernel = _get_kernel(uniform_temporal_subsample, type(inpt)) return kernel(inpt, num_samples) else: raise TypeError( - f"Input can either be a plain tensor, any TorchVision datapoint, " f"but got {type(inpt)} instead." + f"Input can either be a plain tensor or any TorchVision datapoint, but got {type(inpt)} instead." ) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index e8201a05cea..3cd9b7be55c 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -3,11 +3,11 @@ from typing import Any, Callable, Dict, Type import torch -from torchvision.datapoints._datapoint import Datapoint +from torchvision import datapoints def is_simple_tensor(inpt: Any) -> bool: - return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint) + return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint) _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} From a92401379b25095ffdba4d70946955642798cc20 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 14:12:50 +0200 Subject: [PATCH 19/22] put back legacy test_dispatch_datapoint --- test/test_transforms_v2_functional.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 3075efab9a0..8d529732610 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -2,11 +2,11 @@ import math import os import re +from unittest import mock import numpy as np import PIL.Image import pytest - import torch from common_utils import ( @@ -25,6 +25,7 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes +from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY from torchvision.transforms.v2.utils import is_simple_tensor from transforms_v2_dispatcher_infos import DISPATCHER_INFOS from transforms_v2_kernel_infos import KERNEL_INFOS @@ -415,6 +416,28 @@ def test_pil_output_type(self, info, args_kwargs): assert isinstance(output, PIL.Image.Image) + @make_info_args_kwargs_parametrization( + DISPATCHER_INFOS, + args_kwargs_fn=lambda info: info.sample_inputs(), + ) + def test_dispatch_datapoint(self, info, 
args_kwargs, spy_on): + (datapoint, *other_args), kwargs = args_kwargs.load() + + input_type = type(datapoint) + + wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type] + + # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the + # proper kernel was wrapped + if hasattr(wrapped_kernel, "__wrapped__"): + assert wrapped_kernel.__wrapped__ is info.kernels[input_type] + + spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__) + with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}): + info.dispatcher(datapoint, *other_args, **kwargs) + + spy.assert_called_once() + @make_info_args_kwargs_parametrization( DISPATCHER_INFOS, args_kwargs_fn=lambda info: info.sample_inputs(), From b3c2c88a4ebec34a692fbd589d2a02158ba3a154 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 14:20:28 +0200 Subject: [PATCH 20/22] minor test fixes --- test/test_transforms_v2.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 0a187c4903d..437ceb6621d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1331,7 +1331,6 @@ def test__transform(self, inpt): def test_antialias_warning(): pil_img = PIL.Image.new("RGB", size=(10, 10), color=127) tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8) - tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8) match = "The default value of the antialias parameter" with pytest.warns(UserWarning, match=match): @@ -1343,12 +1342,6 @@ def test_antialias_warning(): with pytest.warns(UserWarning, match=match): transforms.RandomResize(10, 20)(tensor_img) - with pytest.warns(UserWarning, match=match): - datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20)) - - with pytest.warns(UserWarning, match=match): - datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20)) - with warnings.catch_warnings(): warnings.simplefilter("error") transforms.RandomResizedCrop((20, 20))(pil_img) @@ -1361,9 +1354,6 @@ def test_antialias_warning(): transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img) transforms.RandomResize(10, 20, antialias=True)(tensor_img) - datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True) - datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True) - @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("label_type", (torch.Tensor, int)) From a1f5ea47a5b3693c39feedd6c7836328230d145d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 15:31:44 +0200 Subject: [PATCH 21/22] Update torchvision/transforms/v2/functional/_utils.py Co-authored-by: Nicolas Hug --- torchvision/transforms/v2/functional/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 3cd9b7be55c..63e029d6c77 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -10,6 +10,7 @@ def is_simple_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint) +# {dispatcher: {input_type: type_specific_kernel}} _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {} From d29d95b126734cafb17bd743cc1cf6279119fa5c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Aug 2023 15:36:56 +0200 Subject: [PATCH 
22/22] reinstante antialias tests --- test/test_transforms_v2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 437ceb6621d..49455b05dc5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1331,6 +1331,7 @@ def test__transform(self, inpt): def test_antialias_warning(): pil_img = PIL.Image.new("RGB", size=(10, 10), color=127) tensor_img = torch.randint(0, 256, size=(3, 10, 10), dtype=torch.uint8) + tensor_video = torch.randint(0, 256, size=(2, 3, 10, 10), dtype=torch.uint8) match = "The default value of the antialias parameter" with pytest.warns(UserWarning, match=match): @@ -1342,6 +1343,14 @@ def test_antialias_warning(): with pytest.warns(UserWarning, match=match): transforms.RandomResize(10, 20)(tensor_img) + with pytest.warns(UserWarning, match=match): + F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20)) + + with pytest.warns(UserWarning, match=match): + F.resize(datapoints.Video(tensor_video), (20, 20)) + with pytest.warns(UserWarning, match=match): + F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20)) + with warnings.catch_warnings(): warnings.simplefilter("error") transforms.RandomResizedCrop((20, 20))(pil_img) @@ -1354,6 +1363,9 @@ def test_antialias_warning(): transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img) transforms.RandomResize(10, 20, antialias=True)(tensor_img) + F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True) + F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True) + @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("label_type", (torch.Tensor, int))
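
Note on the dispatch mechanism these patches introduce: the diffs replace the per-type `Datapoint` methods with a kernel registry keyed by dispatcher and input type (`_KERNEL_REGISTRY`), populated through `@_register_kernel_internal(...)` / `@_register_explicit_noop(...)` and queried with `_get_kernel(...)` inside each dispatcher. The standalone Python sketch below is an illustrative approximation of that pattern only, not the torchvision implementation: it assumes a plain MRO lookup in `_get_kernel`, omits the datapoint unwrapping/re-wrapping that the real decorator presumably applies unless `datapoint_wrapper=False` is passed, and uses toy stand-in types and kernels (`Image`, `BoundingBoxes`, `horizontal_flip_image`, `horizontal_flip_boxes`) that do not correspond to the real kernels.

    from typing import Callable, Dict, Type

    # {dispatcher: {input_type: type_specific_kernel}}, mirroring _KERNEL_REGISTRY in the patch
    _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


    def _register_kernel_internal(dispatcher: Callable, input_type: Type) -> Callable:
        # Decorator factory: records the decorated kernel as the implementation of
        # `dispatcher` for `input_type` and returns the kernel unchanged.
        registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})

        def decorator(kernel: Callable) -> Callable:
            registry[input_type] = kernel
            return kernel

        return decorator


    def _get_kernel(dispatcher: Callable, input_type: Type) -> Callable:
        # Assumed lookup strategy: walk the MRO so subclasses fall back to kernels
        # registered for a parent class.
        registry = _KERNEL_REGISTRY.get(dispatcher, {})
        for cls in input_type.__mro__:
            if cls in registry:
                return registry[cls]
        raise TypeError(f"No kernel registered for {dispatcher.__name__} and {input_type.__name__}")


    # Toy stand-ins for datapoints.Image / datapoints.BoundingBoxes; illustrative only.
    class Image:
        def __init__(self, pixels):
            self.pixels = pixels


    class BoundingBoxes:
        def __init__(self, boxes):
            self.boxes = boxes


    def horizontal_flip(inpt):
        # Dispatcher: look up the kernel registered for the input type and delegate to it.
        kernel = _get_kernel(horizontal_flip, type(inpt))
        return kernel(inpt)


    @_register_kernel_internal(horizontal_flip, Image)
    def horizontal_flip_image(image: Image) -> Image:
        # Reverse each pixel row (stand-in for flipping the last spatial dimension).
        return Image([list(reversed(row)) for row in image.pixels])


    @_register_kernel_internal(horizontal_flip, BoundingBoxes)
    def horizontal_flip_boxes(bbs: BoundingBoxes) -> BoundingBoxes:
        # Mirror (x1, x2) pairs given in normalized [0, 1] coordinates.
        return BoundingBoxes([(1.0 - x2, 1.0 - x1) for x1, x2 in bbs.boxes])


    if __name__ == "__main__":
        print(horizontal_flip(Image([[1, 2, 3]])).pixels)          # [[3, 2, 1]]
        print(horizontal_flip(BoundingBoxes([(0.1, 0.4)])).boxes)  # [(0.6, 0.9)]

The design point visible in the diffs is that the registry keeps type-specific behavior next to the kernels themselves: plain tensor kernels register directly, bounding boxes and masks register thin `_*_dispatch` adapters (with `datapoint_wrapper=False`) that unwrap the datapoint and re-wrap the result via `wrap_like`, and types for which an op is a no-op are handled by `_register_explicit_noop`, so no `Datapoint` subclass needs to grow a method per functional.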