Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expand has_any and has_all to also accept check callables #6447

Merged
merged 10 commits into from
Aug 18, 2022
83 changes: 83 additions & 0 deletions test/test_prototype_transforms_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import PIL.Image
import pytest

import torch

from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
from torchvision.prototype.transforms.functional import to_image_pil


IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
],
)
def test_has_any(sample, types, expected):
assert has_any(sample, *types) is expected


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
True,
),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
],
)
def test_has_all(sample, types, expected):
assert has_all(sample, *types) is expected
6 changes: 4 additions & 2 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision.prototype.transforms import functional as F, Transform

from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -105,7 +105,9 @@ def __init__(self, *, alpha: float) -> None:

def forward(self, *inpts: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0]
if not has_all(sample, features.Image, features.OneHotLabel):
if not (
has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError(
Expand Down
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
# TODO: Allow image to be a torch.Tensor
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
Expand Down
30 changes: 25 additions & 5 deletions torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Tuple, Type, Union
from typing import Any, Callable, Tuple, Type, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -39,14 +39,34 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
return channels, height, width


def has_any(sample: Any, *types: Type) -> bool:
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
return any(issubclass(type(obj), types) for obj in flat_sample)
for type_or_check in types_or_checks:
passed_check = False
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
passed_check = True
break

if passed_check:
return True
pmeier marked this conversation as resolved.
Show resolved Hide resolved

def has_all(sample: Any, *types: Type) -> bool:
return False


def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
return not bool(set(types) - set([type(obj) for obj in flat_sample]))
for type_or_check in types_or_checks:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are asymmetric because the behavior we want for has_any is any(any(...)) and for has_all is any(all(...)). If has_all would be all(all(...)) they would be symmetric.

passed_check = False
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
passed_check = True
break

if not passed_check:
return False
pmeier marked this conversation as resolved.
Show resolved Hide resolved

return True


def is_simple_tensor(inpt: Any) -> bool:
Expand Down