Skip to content

Commit

Permalink
expand has_any and has_all to also accept check callables (#6447)
Browse files Browse the repository at this point in the history
* expand has_any and has_all to also accept check callables

* add test and fix has_all

* add support for simple tensor images to CutMix, MixUp and RandomIoUCrop

* remove TODO

* remove pythonic syntax sugar

* simplify

* use concreate examples in test rather than abstract ones

* simplify further
  • Loading branch information
pmeier authored Aug 18, 2022
1 parent 80c197a commit 330b6c9
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 9 deletions.
83 changes: 83 additions & 0 deletions test/test_prototype_transforms_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import PIL.Image
import pytest

import torch

from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
from torchvision.prototype.transforms.functional import to_image_pil


IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
],
)
def test_has_any(sample, types, expected):
assert has_any(sample, *types) is expected


@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
True,
),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
],
)
def test_has_all(sample, types, expected):
assert has_all(sample, *types) is expected
6 changes: 4 additions & 2 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision.prototype.transforms import functional as F

from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image


class RandomErasing(_RandomApplyTransform):
Expand Down Expand Up @@ -105,7 +105,9 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:

def forward(self, *inpts: Any) -> Any:
sample = inpts if len(inpts) > 1 else inpts[0]
if not has_all(sample, features.Image, features.OneHotLabel):
if not (
has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError(
Expand Down
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
# TODO: Allow image to be a torch.Tensor
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
raise TypeError(
Expand Down
20 changes: 15 additions & 5 deletions torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Tuple, Type, Union
from typing import Any, Callable, Tuple, Type, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -39,14 +39,24 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
return channels, height, width


def has_any(sample: Any, *types: Type) -> bool:
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
return any(issubclass(type(obj), types) for obj in flat_sample)
for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False


def has_all(sample: Any, *types: Type) -> bool:
def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
return not bool(set(types) - set([type(obj) for obj in flat_sample]))
for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
break
else:
return False
return True


def is_simple_tensor(inpt: Any) -> bool:
Expand Down

0 comments on commit 330b6c9

Please sign in to comment.