diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py
index f21799e86c8..137a00ccf55 100644
--- a/references/segmentation/v2_extras.py
+++ b/references/segmentation/v2_extras.py
@@ -11,7 +11,7 @@ def __init__(self, size, fill=0):
         self.fill = v2._utils._setup_fill_arg(fill)
 
     def _get_params(self, sample):
-        _, height, width = v2.utils.query_chw(sample)
+        _, height, width = v2._utils.query_chw(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py
index e29dfb17fe1..8497ea27b54 100644
--- a/test/test_prototype_datasets_builtin.py
+++ b/test/test_prototype_datasets_builtin.py
@@ -25,7 +25,7 @@
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor
 
 
 def assert_samples_equal(*args, msg=None, **kwargs):
diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 0410ecadc48..b4e1d108748 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -10,8 +10,8 @@
 
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor
 from transforms_v2_legacy_utils import (
     DEFAULT_EXTRA_DIMS,
     make_bounding_boxes,
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 9630132e271..26dde640788 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -16,7 +16,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
 from transforms_v2_legacy_utils import (
     make_bounding_boxes,
     make_detection_mask,
diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py
index 4e8595d2185..0d11f610a89 100644
--- a/test/test_transforms_v2_consistency.py
+++ b/test/test_transforms_v2_consistency.py
@@ -19,9 +19,8 @@
 
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
-from torchvision.transforms.v2._utils import _get_fill
+from torchvision.transforms.v2._utils import _get_fill, query_size
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import query_size
 from transforms_v2_legacy_utils import (
     ArgsKwargs,
     make_bounding_boxes,
diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index e6a540ae06b..826ba8b57e1 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -13,9 +13,9 @@
 from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2._utils import is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
-from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
 from transforms_v2_legacy_utils import (
diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py
index 55825d652e6..511b0c364aa 100644
--- a/test/test_transforms_v2_utils.py
+++ b/test/test_transforms_v2_utils.py
@@ -3,12 +3,12 @@
 
 import torch
 
-import torchvision.transforms.v2.utils
+import torchvision.transforms.v2._utils
 from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
 
 from torchvision import datapoints
+from torchvision.transforms.v2._utils import has_all, has_any
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import has_all, has_any
 
 
 IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
@@ -37,15 +37,15 @@
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
     ],
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index a2f6ebbf498..f4013ffa718 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,9 +7,9 @@
 
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2._utils import is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_pure_tensor
 
 
 class SimpleCopyPaste(Transform):
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index bf97d1f605b..3b7e6878170 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -6,8 +6,16 @@
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size
+from torchvision.transforms.v2._utils import (
+    _FillType,
+    _get_fill,
+    _setup_fill_arg,
+    _setup_size,
+    get_bounding_boxes,
+    has_any,
+    is_pure_tensor,
+    query_size,
+)
 
 
 class FixedSizeCrop(Transform):
diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py
index 0dd495ab05b..fa812bbbbe9 100644
--- a/torchvision/prototype/transforms/_misc.py
+++ b/torchvision/prototype/transforms/_misc.py
@@ -8,7 +8,7 @@
 
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor
 
 
 T = TypeVar("T")
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index b60962748d1..dbc0474d307 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -1,6 +1,6 @@
 from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
 
-from . import functional, utils  # usort: skip
+from . import functional  # usort: skip
 
 from ._transform import Transform  # usort: skip
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index a6af96a5ef6..130950fee34 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -11,8 +11,7 @@
 from torchvision.transforms.v2 import functional as F
 
 from ._transform import _RandomApplyTransform, Transform
-from ._utils import _parse_labels_getter
-from .utils import has_any, is_pure_tensor, query_chw, query_size
+from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size
 
 
 class RandomErasing(_RandomApplyTransform):
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index 2c82d092ec2..664210ff7e7 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -12,8 +12,7 @@
 from torchvision.transforms.v2.functional._meta import get_size
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 
-from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_pure_tensor
+from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
 
 
 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index a3792797959..efe731b5ec9 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -6,7 +6,7 @@
 from torchvision.transforms.v2 import functional as F, Transform
 
 from ._transform import _RandomApplyTransform
-from .utils import query_chw
+from ._utils import query_chw
 
 
 class Grayscale(Transform):
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 2c54b53d4c5..4f94b37aa31 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -23,8 +23,12 @@
     _setup_fill_arg,
     _setup_float_or_seq,
     _setup_size,
+    get_bounding_boxes,
+    has_all,
+    has_any,
+    is_pure_tensor,
+    query_size,
 )
-from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 6974c62b02e..c17530ecfb9 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -9,8 +9,7 @@
 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
 
-from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import get_bounding_boxes, has_any, is_pure_tensor
+from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
 
 
 # TODO: do we want/need to expose this?
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index e9af4b426fa..f377c822a2d 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -8,7 +8,7 @@
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
-from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
 
 from .functional._utils import _get_kernel
diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py
index 26d23375400..e92c98e6cb3 100644
--- a/torchvision/transforms/v2/_type_conversion.py
+++ b/torchvision/transforms/v2/_type_conversion.py
@@ -7,7 +7,7 @@
 from torchvision import datapoints
 
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor
 
 
 class PILToTensor(Transform):
@@ -44,23 +44,24 @@ def _transform(
 
 
 class ToPILImage(Transform):
-    """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
+    """[BETA] Convert a tensor or an ndarray to PIL Image
 
     .. v2betastatus:: ToPILImage transform
 
     This transform does not support torchscript.
 
     Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
-    H x W x C to a PIL Image while preserving the value range.
+    H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.
 
     Args:
         mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
             If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+
             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
             - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
-            ``short``).
+              ``short``).
 
     .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
     """
diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py
index f9d9bae49e9..3c6977fae91 100644
--- a/torchvision/transforms/v2/_utils.py
+++ b/torchvision/transforms/v2/_utils.py
@@ -1,11 +1,20 @@
+from __future__ import annotations
+
 import collections.abc
 import numbers
 from contextlib import suppress
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
+
+import PIL.Image
 
 import torch
+from torchvision import datapoints
+
+from torchvision._utils import sequence_to_str
+
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
@@ -138,3 +147,73 @@ def _parse_labels_getter(
         return lambda _: None
     else:
         raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
+
+
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")
+
+
+def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
+    chws = {
+        tuple(get_dimensions(inpt))
+        for inpt in flat_inputs
+        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+    }
+    if not chws:
+        raise TypeError("No image or video was found in the sample")
+    elif len(chws) > 1:
+        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
+    c, h, w = chws.pop()
+    return c, h, w
+
+
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
+    sizes = {
+        tuple(get_size(inpt))
+        for inpt in flat_inputs
+        if check_type(
+            inpt,
+            (
+                is_pure_tensor,
+                datapoints.Image,
+                PIL.Image.Image,
+                datapoints.Video,
+                datapoints.Mask,
+                datapoints.BoundingBoxes,
+            ),
+        )
+    }
+    if not sizes:
+        raise TypeError("No image, video, mask or bounding box was found in the sample")
+    elif len(sizes) > 1:
+        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
+    h, w = sizes.pop()
+    return h, w
+
+
+def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
+    for type_or_check in types_or_checks:
+        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+            return True
+    return False
+
+
+def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for inpt in flat_inputs:
+        if check_type(inpt, types_or_checks):
+            return True
+    return False
+
+
+def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for type_or_check in types_or_checks:
+        for inpt in flat_inputs:
+            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
+                break
+        else:
+            return False
+    return True
diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py
deleted file mode 100644
index 1e4ff2d05aa..00000000000
--- a/torchvision/transforms/v2/utils.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from __future__ import annotations
-
-from typing import Any, Callable, List, Tuple, Type, Union
-
-import PIL.Image
-from torchvision import datapoints
-
-from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
-
-
-def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    # This assumes there is only one bbox per sample as per the general convention
-    try:
-        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
-    except StopIteration:
-        raise ValueError("No bounding boxes were found in the sample")
-
-
-def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
-    chws = {
-        tuple(get_dimensions(inpt))
-        for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
-    }
-    if not chws:
-        raise TypeError("No image or video was found in the sample")
-    elif len(chws) > 1:
-        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    c, h, w = chws.pop()
-    return c, h, w
-
-
-def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
-    sizes = {
-        tuple(get_size(inpt))
-        for inpt in flat_inputs
-        if check_type(
-            inpt,
-            (
-                is_pure_tensor,
-                datapoints.Image,
-                PIL.Image.Image,
-                datapoints.Video,
-                datapoints.Mask,
-                datapoints.BoundingBoxes,
-            ),
-        )
-    }
-    if not sizes:
-        raise TypeError("No image, video, mask or bounding box was found in the sample")
-    elif len(sizes) > 1:
-        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
-    h, w = sizes.pop()
-    return h, w
-
-
-def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
-    for type_or_check in types_or_checks:
-        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
-            return True
-    return False
-
-
-def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    for inpt in flat_inputs:
-        if check_type(inpt, types_or_checks):
-            return True
-    return False
-
-
-def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    for type_or_check in types_or_checks:
-        for inpt in flat_inputs:
-            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
-                break
-        else:
-            return False
-    return True
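Below is a minimal usage sketch, not part of the patch, showing the helpers after the move. It assumes a torchvision build from this branch where torchvision.datapoints exists and the now-private module torchvision.transforms.v2._utils is importable; since the new path is internal, downstream code should not depend on it.

    # Illustrative sketch only: exercising query_chw, query_size and has_any from
    # torchvision.transforms.v2._utils (previously torchvision.transforms.v2.utils).
    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2._utils import has_any, is_pure_tensor, query_chw, query_size

    image = datapoints.Image(torch.rand(3, 32, 32))                  # C x H x W
    mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.uint8))   # H x W
    sample = [image, mask]

    print(query_chw(sample))                 # (3, 32, 32): only image/video inputs are inspected
    print(query_size(sample))                # (32, 32): image and mask must agree on HxW
    print(has_any(sample, datapoints.Mask))  # True
    print(has_any(sample, is_pure_tensor))   # False: both inputs are datapoint subclasses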