remove custom type definitions from datapoints module #7814

Merged · 2 commits · Aug 9, 2023
6 changes: 3 additions & 3 deletions torchvision/datapoints/__init__.py
@@ -1,10 +1,10 @@
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
-from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
+from ._datapoint import Datapoint
+from ._image import Image
 from ._mask import Mask
-from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
+from ._video import Video

 if _WARN_ABOUT_BETA_TRANSFORMS:
     import warnings
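
Net effect on the public surface: the datapoint classes stay importable, only the private aliases leave the namespace. A minimal migration sketch (the inline spellings come from the alias definitions deleted in the files below):

# Unchanged public imports:
from torchvision.datapoints import BoundingBoxes, Datapoint, Image, Mask, Video

# Removed aliases are spelled inline from here on, e.g.:
#   datapoints._ImageType       -> Union[torch.Tensor, PIL.Image.Image, datapoints.Image]
#   datapoints._TensorImageType -> Union[torch.Tensor, datapoints.Image]
# _FillType / _FillTypeJIT are imported from torchvision.transforms.v2.functional._utils instead.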
9 changes: 1 addition & 8 deletions torchvision/datapoints/_datapoint.py
@@ -1,16 +1,13 @@
 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

 import PIL.Image
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size


 D = TypeVar("D", bound="Datapoint")
-_FillType = Union[int, float, Sequence[int], Sequence[float], None]
-_FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):
@@ -132,7 +129,3 @@ def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
         # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied by
         # `BoundingBoxes.clone()`.
         return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
-
-
-_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
-_InputTypeJIT = torch.Tensor
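
The fill aliases are moved rather than dropped: every import site updated later in this diff pulls them from the functional utilities. For reference, a sketch of the new home, assuming the moved definitions are verbatim copies of the two lines deleted above:

# torchvision/transforms/v2/functional/_utils.py
from typing import List, Optional, Sequence, Union

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]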
6 changes: 0 additions & 6 deletions torchvision/datapoints/_image.py
@@ -45,9 +45,3 @@ def __new__(

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
-_ImageTypeJIT = torch.Tensor
-_TensorImageType = Union[torch.Tensor, Image]
-_TensorImageTypeJIT = torch.Tensor
6 changes: 0 additions & 6 deletions torchvision/datapoints/_video.py
@@ -35,9 +35,3 @@ def __new__(

     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-
-_VideoType = Union[torch.Tensor, Video]
-_VideoTypeJIT = torch.Tensor
-_TensorVideoType = Union[torch.Tensor, Video]
-_TensorVideoTypeJIT = torch.Tensor
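
Compiled from the deletions above, the full alias-to-inline mapping used throughout the rest of this diff; every JIT variant collapses to plain torch.Tensor:

# Removed alias    -> inline replacement
# _ImageType       -> Union[torch.Tensor, PIL.Image.Image, datapoints.Image]
# _TensorImageType -> Union[torch.Tensor, datapoints.Image]
# _VideoType       -> Union[torch.Tensor, datapoints.Video]
# _TensorVideoType -> Union[torch.Tensor, datapoints.Video]
# _InputType       -> Union[torch.Tensor, PIL.Image.Image, datapoints.Datapoint]
# _ImageTypeJIT, _VideoTypeJIT, _TensorImageTypeJIT,
# _TensorVideoTypeJIT, _InputTypeJIT -> torch.Tensor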
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/_augment.py
@@ -26,15 +26,15 @@ def __init__(

     def _copy_paste(
         self,
-        image: datapoints._TensorImageType,
+        image: Union[torch.Tensor, datapoints.Image],
         target: Dict[str, Any],
-        paste_image: datapoints._TensorImageType,
+        paste_image: Union[torch.Tensor, datapoints.Image],
         paste_target: Dict[str, Any],
         random_selection: torch.Tensor,
         blending: bool,
         resize_interpolation: F.InterpolationMode,
         antialias: Optional[bool],
-    ) -> Tuple[datapoints._TensorImageType, Dict[str, Any]]:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:

         paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
         paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
@@ -106,7 +106,7 @@ def _copy_paste(

     def _extract_image_targets(
         self, flat_sample: List[Any]
-    ) -> Tuple[List[datapoints._TensorImageType], List[Dict[str, Any]]]:
+    ) -> Tuple[List[Union[torch.Tensor, datapoints.Image]], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
@@ -137,7 +137,7 @@ def _extract_image_targets(
     def _insert_outputs(
         self,
         flat_sample: List[Any],
-        output_images: List[datapoints._TensorImageType],
+        output_images: List[torch.Tensor],
         output_targets: List[Dict[str, Any]],
     ) -> None:
         c0, c1, c2, c3 = 0, 0, 0, 0
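
Why the return annotation can narrow to plain torch.Tensor: datapoints.Image is a Tensor subclass (see class Datapoint(torch.Tensor) above), so the annotation still covers Image values. A quick check against the beta API:

import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 32, 32))
assert isinstance(img, torch.Tensor)  # Image subclasses Tensor via Datapoint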
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -6,15 +6,15 @@
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
+from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
 from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


 class FixedSizeCrop(Transform):
     def __init__(
         self,
         size: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: str = "constant",
     ) -> None:
         super().__init__()
8 changes: 2 additions & 6 deletions torchvision/prototype/transforms/_misc.py
@@ -39,9 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
         )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
@@ -63,9 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
         )
         self.dims = dims

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> torch.Tensor:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         dims = self.dims[type(inpt)]
         if dims is None:
             return inpt.as_subclass(torch.Tensor)
24 changes: 14 additions & 10 deletions torchvision/transforms/v2/_auto_augment.py
@@ -10,17 +10,21 @@
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

 from ._utils import _get_fill, _setup_fill_arg
 from .utils import check_type, is_simple_tensor


+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+
+
 class _AutoAugmentBase(Transform):
     def __init__(
         self,
         *,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__()
         self.interpolation = _check_interpolation(interpolation)
@@ -35,7 +39,7 @@ def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
         unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
-    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[datapoints._ImageType, datapoints._VideoType]]:
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)

@@ -68,20 +72,20 @@ def _flatten_and_extract_image_or_video(
     def _unflatten_and_insert_image_or_video(
         self,
         flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
-        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        image_or_video: ImageOrVideo,
     ) -> Any:
         flat_inputs, spec, idx = flat_inputs_with_spec
         flat_inputs[idx] = image_or_video
         return tree_unflatten(flat_inputs, spec)

     def _apply_image_or_video_transform(
         self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        image: ImageOrVideo,
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
-        fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+        fill: Dict[Union[Type, str], _FillTypeJIT],
+    ) -> ImageOrVideo:
         fill_ = _get_fill(fill, type(image))

         if transform_id == "Identity":
@@ -214,7 +218,7 @@ def __init__(
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.policy = policy
@@ -394,7 +398,7 @@ def __init__(
         magnitude: int = 9,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_ops = num_ops
@@ -467,7 +471,7 @@ def __init__(
         self,
         num_magnitude_bins: int = 31,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ):
         super().__init__(interpolation=interpolation, fill=fill)
         self.num_magnitude_bins = num_magnitude_bins
@@ -550,7 +554,7 @@ def __init__(
         alpha: float = 1.0,
         all_ops: bool = True,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
     ) -> None:
         super().__init__(interpolation=interpolation, fill=fill)
         self._PARAMETER_MAX = 10
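
Callers are unaffected by the alias shuffle; fill still accepts a single value or a per-type dict, exactly as annotated above. A usage sketch against the beta v2 API of this branch (the gray fill value is arbitrary):

import torch
from torchvision import datapoints
from torchvision.transforms import v2

augment = v2.AutoAugment(fill={datapoints.Image: [128, 128, 128]})
img = datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
out = augment(img)  # fill is looked up per input type via _get_fill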
4 changes: 1 addition & 3 deletions torchvision/transforms/v2/_color.py
@@ -261,9 +261,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
         return params

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["brightness_factor"] is not None:
             inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
18 changes: 8 additions & 10 deletions torchvision/transforms/v2/_geometry.py
@@ -11,6 +11,7 @@
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _FillType

 from ._transform import _RandomApplyTransform
 from ._utils import (
@@ -311,9 +312,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )


-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-
-
 class FiveCrop(Transform):
     """[BETA] Crop the image or video into four corners and the central crop.

@@ -459,7 +457,7 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -514,7 +512,7 @@ class RandomZoomOut(_RandomApplyTransform):

     def __init__(
         self,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         side_range: Sequence[float] = (1.0, 4.0),
         p: float = 0.5,
     ) -> None:
@@ -592,7 +590,7 @@ def __init__(
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
         expand: bool = False,
         center: Optional[List[float]] = None,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
@@ -674,7 +672,7 @@ def __init__(
         scale: Optional[Sequence[float]] = None,
         shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         center: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
@@ -812,7 +810,7 @@ def __init__(
         size: Union[int, Sequence[int]],
         padding: Optional[Union[int, Sequence[int]]] = None,
         pad_if_needed: bool = False,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
@@ -931,7 +929,7 @@ def __init__(
         distortion_scale: float = 0.5,
         p: float = 0.5,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__(p=p)

@@ -1033,7 +1031,7 @@ def __init__(
         alpha: Union[float, Sequence[float]] = 50.0,
         sigma: Union[float, Sequence[float]] = 5.0,
         interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
+        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
     ) -> None:
         super().__init__()
         self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
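
The same _FillType union backs every fill parameter touched above, including the Dict[Union[Type, str], _FillType] form for per-datapoint fills. A sketch with Pad, assuming dict-structured inputs as supported by the v2 transforms:

import torch
from torchvision import datapoints
from torchvision.transforms import v2

pad = v2.Pad(padding=4, fill={datapoints.Image: 127, datapoints.Mask: 0})
sample = {
    "image": datapoints.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)),
    "mask": datapoints.Mask(torch.zeros(64, 64, dtype=torch.int64)),
}
out = pad(sample)  # image padded with 127, mask padded with 0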
4 changes: 1 addition & 3 deletions torchvision/transforms/v2/_misc.py
@@ -169,9 +169,7 @@ def _check_inputs(self, sample: Any) -> Any:
         if has_any(sample, PIL.Image.Image):
             raise TypeError(f"{type(self).__name__}() does not support PIL images.")

-    def _transform(
-        self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-    ) -> Any:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
3 changes: 1 addition & 2 deletions torchvision/transforms/v2/_temporal.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict

 import torch
-from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform


@@ -25,5 +24,5 @@ def __init__(self, num_samples: int):
         super().__init__()
         self.num_samples = num_samples

-    def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.uniform_temporal_subsample(inpt, self.num_samples)
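
Behavior is unchanged by the Any annotation; the transform still subsamples along the temporal dimension. A usage sketch, assuming the [T, C, H, W] video layout the v2 kernels expect:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import UniformTemporalSubsample

video = datapoints.Video(torch.rand(16, 3, 224, 224))  # [T, C, H, W]
out = UniformTemporalSubsample(num_samples=8)(video)
print(out.shape)  # torch.Size([8, 3, 224, 224])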
5 changes: 2 additions & 3 deletions torchvision/transforms/v2/_utils.py
@@ -5,9 +5,8 @@

 import torch

-from torchvision import datapoints
-from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


 def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
@@ -36,7 +35,7 @@ def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -
         raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")


-def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
+def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
     # So, we can't reassign fill to 0
     # if fill is None:
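
For orientation, the values _FillType admits, per the definition that moved here (the per-channel reading is the conventional one for image fills):

from torchvision.transforms.v2.functional._utils import _FillType

fill_int: _FillType = 0                    # single number, applied to all channels
fill_float: _FillType = 0.5
fill_per_channel: _FillType = [255, 0, 0]  # one value per channel
fill_none: _FillType = None                # no fill
# _convert_fill_arg normalizes these toward _FillTypeJIT, i.e. Optional[List[float]]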
6 changes: 2 additions & 4 deletions torchvision/transforms/v2/functional/_augment.py
@@ -1,5 +1,3 @@
-from typing import Union
-
 import PIL.Image

 import torch
@@ -12,14 +10,14 @@

 @_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    inpt: torch.Tensor,
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+) -> torch.Tensor:
     if torch.jit.is_scripting():
         return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

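
With the JIT aliases gone, erase is annotated with plain torch.Tensor, which also covers the Image and Video subclasses at runtime. A usage sketch via the public functional namespace:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 64, 64)
patch = torch.zeros(3, 16, 16)
# overwrite the 16x16 region whose top-left corner is at (i=8, j=8) with v
out = F.erase(img, i=8, j=8, h=16, w=16, v=patch)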