
Commit cf4e6c4

Merge branch 'main' of github.com:pytorch/vision into AAAAAAAAAaaaaaaaaaaa
NicolasHug committed Aug 23, 2023
2 parents d3d6731 + 11e49de commit cf4e6c4
Showing 20 changed files with 119 additions and 110 deletions.
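The substance of this commit is making the transforms v2 helper module private: torchvision.transforms.v2.utils is folded into the existing private module torchvision.transforms.v2._utils, and every import site is updated accordingly. A minimal before/after sketch of the change downstream callers would need, with symbol names taken from the diffs below:

    # Before this commit (public path, removed here):
    # from torchvision.transforms.v2.utils import is_pure_tensor, query_chw

    # After this commit (private path; the leading underscore signals no API stability guarantee):
    from torchvision.transforms.v2._utils import is_pure_tensor, query_chw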
2 changes: 1 addition & 1 deletion references/segmentation/v2_extras.py
@@ -11,7 +11,7 @@ def __init__(self, size, fill=0):
         self.fill = v2._utils._setup_fill_arg(fill)

     def _get_params(self, sample):
-        _, height, width = v2.utils.query_chw(sample)
+        _, height, width = v2._utils.query_chw(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
2 changes: 1 addition & 1 deletion test/test_prototype_datasets_builtin.py
@@ -25,7 +25,7 @@
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
@@ -10,8 +10,8 @@

 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor
 from transforms_v2_legacy_utils import (
     DEFAULT_EXTRA_DIMS,
     make_bounding_boxes,
2 changes: 1 addition & 1 deletion test/test_transforms_v2.py
@@ -16,7 +16,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
 from transforms_v2_legacy_utils import (
     make_bounding_boxes,
     make_detection_mask,
3 changes: 1 addition & 2 deletions test/test_transforms_v2_consistency.py
@@ -19,9 +19,8 @@

 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
-from torchvision.transforms.v2._utils import _get_fill
+from torchvision.transforms.v2._utils import _get_fill, query_size
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import query_size
 from transforms_v2_legacy_utils import (
     ArgsKwargs,
     make_bounding_boxes,
2 changes: 1 addition & 1 deletion test/test_transforms_v2_functional.py
@@ -13,9 +13,9 @@
 from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2._utils import is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
-from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
 from transforms_v2_legacy_utils import (
10 changes: 5 additions & 5 deletions test/test_transforms_v2_utils.py
@@ -3,12 +3,12 @@

 import torch

-import torchvision.transforms.v2.utils
+import torchvision.transforms.v2._utils
 from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image

 from torchvision import datapoints
+from torchvision.transforms.v2._utils import has_all, has_any
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import has_all, has_any


 IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
@@ -37,15 +37,15 @@
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
     ],
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
@@ -7,9 +7,9 @@
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2._utils import is_pure_tensor

 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_pure_tensor


 class SimpleCopyPaste(Transform):
12 changes: 10 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -6,8 +6,16 @@
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size
+from torchvision.transforms.v2._utils import (
+    _FillType,
+    _get_fill,
+    _setup_fill_arg,
+    _setup_size,
+    get_bounding_boxes,
+    has_any,
+    is_pure_tensor,
+    query_size,
+)


 class FixedSizeCrop(Transform):
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_misc.py
@@ -8,7 +8,7 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform

-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor


 T = TypeVar("T")
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
@@ -1,6 +1,6 @@
 from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip

-from . import functional, utils  # usort: skip
+from . import functional  # usort: skip

 from ._transform import Transform  # usort: skip
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_augment.py
@@ -11,8 +11,7 @@
 from torchvision.transforms.v2 import functional as F

 from ._transform import _RandomApplyTransform, Transform
-from ._utils import _parse_labels_getter
-from .utils import has_any, is_pure_tensor, query_chw, query_size
+from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_auto_augment.py
@@ -12,8 +12,7 @@
 from torchvision.transforms.v2.functional._meta import get_size
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

-from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_pure_tensor
+from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor


 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_color.py
@@ -6,7 +6,7 @@
 from torchvision.transforms.v2 import functional as F, Transform

 from ._transform import _RandomApplyTransform
-from .utils import query_chw
+from ._utils import query_chw


 class Grayscale(Transform):
6 changes: 5 additions & 1 deletion torchvision/transforms/v2/_geometry.py
@@ -23,8 +23,12 @@
     _setup_fill_arg,
     _setup_float_or_seq,
     _setup_size,
+    get_bounding_boxes,
+    has_all,
+    has_any,
+    is_pure_tensor,
+    query_size,
 )
-from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size


 class RandomHorizontalFlip(_RandomApplyTransform):
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_misc.py
@@ -9,8 +9,7 @@
 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

-from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import get_bounding_boxes, has_any, is_pure_tensor
+from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor


 # TODO: do we want/need to expose this?
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_transform.py
@@ -8,7 +8,7 @@
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
-from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

 from .functional._utils import _get_kernel
9 changes: 5 additions & 4 deletions torchvision/transforms/v2/_type_conversion.py
@@ -7,7 +7,7 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform

-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor


 class PILToTensor(Transform):
@@ -44,23 +44,24 @@ def _transform(


 class ToPILImage(Transform):
-    """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
+    """[BETA] Convert a tensor or an ndarray to PIL Image

     .. v2betastatus:: ToPILImage transform

     This transform does not support torchscript.

     Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
-    H x W x C to a PIL Image while preserving the value range.
+    H x W x C to a PIL Image while adjusting the value range depending on the ``mode``.

     Args:
         mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
             If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+
             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
-            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
-            ``short``).
+            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
+              ``short``).

     .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
     """
81 changes: 80 additions & 1 deletion torchvision/transforms/v2/_utils.py
@@ -1,11 +1,20 @@
 from __future__ import annotations

 import collections.abc
 import numbers
 from contextlib import suppress
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union

+import PIL.Image
 import torch

 from torchvision import datapoints
+
+from torchvision._utils import sequence_to_str
+
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
@@ -138,3 +147,73 @@ def _parse_labels_getter(
         return lambda _: None
     else:
         raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
+
+
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")
+
+
+def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
+    chws = {
+        tuple(get_dimensions(inpt))
+        for inpt in flat_inputs
+        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+    }
+    if not chws:
+        raise TypeError("No image or video was found in the sample")
+    elif len(chws) > 1:
+        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
+    c, h, w = chws.pop()
+    return c, h, w
+
+
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
+    sizes = {
+        tuple(get_size(inpt))
+        for inpt in flat_inputs
+        if check_type(
+            inpt,
+            (
+                is_pure_tensor,
+                datapoints.Image,
+                PIL.Image.Image,
+                datapoints.Video,
+                datapoints.Mask,
+                datapoints.BoundingBoxes,
+            ),
+        )
+    }
+    if not sizes:
+        raise TypeError("No image, video, mask or bounding box was found in the sample")
+    elif len(sizes) > 1:
+        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
+    h, w = sizes.pop()
+    return h, w
+
+
+def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
+    for type_or_check in types_or_checks:
+        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+            return True
+    return False
+
+
+def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for inpt in flat_inputs:
+        if check_type(inpt, types_or_checks):
+            return True
+    return False
+
+
+def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for type_or_check in types_or_checks:
+        for inpt in flat_inputs:
+            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
+                break
+        else:
+            return False
+    return True
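Taken together, the moved helpers let a transform inspect a flattened sample: check_type unifies isinstance checks with predicate callables, has_any/has_all quantify over the sample, and query_chw/query_size assert that all relevant inputs agree on their dimensions. A small usage sketch, assuming the torchvision.datapoints constructors available around this commit (the sample values are hypothetical):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2._utils import has_any, query_chw, query_size

    sample = [
        datapoints.Image(torch.rand(3, 32, 32)),
        datapoints.BoundingBoxes(
            torch.tensor([[0, 0, 10, 10]]),
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=(32, 32),
        ),
    ]

    print(has_any(sample, datapoints.BoundingBoxes))  # True
    print(query_chw(sample))   # (3, 32, 32); only image/video inputs are considered
    print(query_size(sample))  # (32, 32); the boxes' canvas size must agree with the image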