Skip to content

Commit

Permalink
[fbsync] Make v2.utils private. (#7863)
Browse files Browse the repository at this point in the history
Summary: (Note: this ignores all push blocking failures!)

Reviewed By: matteobettini

Differential Revision: D48900397

fbshipit-source-id: 8271d35783c58b979ae6cce041854b59bde98e9a
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Sep 6, 2023
1 parent 0d4aa66 commit dca5b0a
Show file tree
Hide file tree
Showing 20 changed files with 115 additions and 107 deletions.
2 changes: 1 addition & 1 deletion references/segmentation/v2_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, size, fill=0):
self.fill = v2._utils._setup_fill_arg(fill)

def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
_, height, width = v2._utils.query_chw(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding)
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_datasets_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.transforms.v2.utils import is_pure_tensor
from torchvision.transforms.v2._utils import is_pure_tensor


def assert_samples_equal(*args, msg=None, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.transforms.v2.utils import check_type, is_pure_tensor
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw
from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
from transforms_v2_legacy_utils import (
make_bounding_boxes,
make_detection_mask,
Expand Down
3 changes: 1 addition & 2 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2._utils import _get_fill, query_size
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.transforms.v2.utils import query_size
from transforms_v2_legacy_utils import (
ArgsKwargs,
make_bounding_boxes,
Expand Down
2 changes: 1 addition & 1 deletion test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from torchvision.transforms.v2.utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
Expand Down
10 changes: 5 additions & 5 deletions test/test_transforms_v2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import torch

import torchvision.transforms.v2.utils
import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image

from torchvision import datapoints
from torchvision.transforms.v2._utils import has_all, has_any
from torchvision.transforms.v2.functional import to_pil_image
from torchvision.transforms.v2.utils import has_all, has_any


IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
Expand Down Expand Up @@ -37,15 +37,15 @@
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True,
),
(
(to_pil_image(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True,
),
],
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2._utils import is_pure_tensor

from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.utils import is_pure_tensor


class SimpleCopyPaste(Transform):
Expand Down
12 changes: 10 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size
from torchvision.transforms.v2._utils import (
_FillType,
_get_fill,
_setup_fill_arg,
_setup_size,
get_bounding_boxes,
has_any,
is_pure_tensor,
query_size,
)


class FixedSizeCrop(Transform):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torchvision import datapoints
from torchvision.transforms.v2 import Transform

from torchvision.transforms.v2.utils import is_pure_tensor
from torchvision.transforms.v2._utils import is_pure_tensor


T = TypeVar("T")
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip

from . import functional, utils # usort: skip
from . import functional # usort: skip

from ._transform import Transform # usort: skip

Expand Down
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from torchvision.transforms.v2 import functional as F

from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter
from .utils import has_any, is_pure_tensor, query_chw, query_size
from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
Expand Down
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_pure_tensor
from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor


ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision.transforms.v2 import functional as F, Transform

from ._transform import _RandomApplyTransform
from .utils import query_chw
from ._utils import query_chw


class Grayscale(Transform):
Expand Down
6 changes: 5 additions & 1 deletion torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
_setup_fill_arg,
_setup_float_or_seq,
_setup_size,
get_bounding_boxes,
has_all,
has_any,
is_pure_tensor,
query_size,
)
from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size


class RandomHorizontalFlip(_RandomApplyTransform):
Expand Down
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
from .utils import get_bounding_boxes, has_any, is_pure_tensor
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor


# TODO: do we want/need to expose this?
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform

from torchvision.transforms.v2.utils import is_pure_tensor
from torchvision.transforms.v2._utils import is_pure_tensor


class PILToTensor(Transform):
Expand Down
81 changes: 80 additions & 1 deletion torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

import collections.abc
import numbers
from contextlib import suppress
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union

from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union

import PIL.Image
import torch

from torchvision import datapoints

from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


Expand Down Expand Up @@ -138,3 +147,73 @@ def _parse_labels_getter(
return lambda _: None
else:
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")


def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
# This assumes there is only one bbox per sample as per the general convention
try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
except StopIteration:
raise ValueError("No bounding boxes were found in the sample")


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
c, h, w = chws.pop()
return c, h, w


def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
sizes = {
tuple(get_size(inpt))
for inpt in flat_inputs
if check_type(
inpt,
(
is_pure_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
datapoints.Mask,
datapoints.BoundingBoxes,
),
)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
elif len(sizes) > 1:
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
h, w = sizes.pop()
return h, w


def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False


def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for inpt in flat_inputs:
if check_type(inpt, types_or_checks):
return True
return False


def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for type_or_check in types_or_checks:
for inpt in flat_inputs:
if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
break
else:
return False
return True
79 changes: 0 additions & 79 deletions torchvision/transforms/v2/utils.py

This file was deleted.

0 comments on commit dca5b0a

Please sign in to comment.