[fbsync] Simple tensor -> pure tensor (#7846)
Reviewed By: matteobettini

Differential Revision: D48642272

fbshipit-source-id: 72faafd56c8582531e5e639fcc0daa757f06bee6
NicolasHug authored and facebook-github-bot committed Aug 25, 2023
1 parent acf1d79 commit 2ed29ba
Showing 20 changed files with 89 additions and 89 deletions.
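The rename is purely mechanical: every `is_simple_tensor` call site and test name becomes `is_pure_tensor`, with unchanged semantics. For context, here is a minimal sketch of the predicate being renamed (assumed to match `torchvision/transforms/v2/utils.py` at this revision; the function body itself is not part of this diff):

import torch
from torchvision import datapoints

def is_pure_tensor(inpt):
    # A "pure" tensor is a plain torch.Tensor that is not one of the
    # datapoint subclasses (Image, Video, BoundingBoxes, Mask, ...).
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)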
10 changes: 5 additions & 5 deletions test/test_prototype_datasets_builtin.py
@@ -25,7 +25,7 @@
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
@@ -140,18 +140,18 @@ def make_msg_and_close(head):
         raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_pure_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
         sample = next_consume(iter(dataset))

-        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
+        pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}

-        if simple_tensors and not any(
+        if pure_tensors and not any(
             isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
         ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, "
                 f"but didn't find any (encoded) image or video."
             )
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
@@ -18,7 +18,7 @@
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

@@ -296,7 +296,7 @@ def test_call(self, dims, inverse_dims):
             value_type = type(value)
             transformed_value = transformed_sample[key]

-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transform.dims.get(value_type) is not None:
                     assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                 assert type(transformed_value) == torch.Tensor
@@ -341,7 +341,7 @@ def test_call(self, dims):
             transformed_value = transformed_sample[key]

             transposed_dims = transform.dims.get(value_type)
-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transposed_dims is not None:
                     assert transformed_value.transpose(*transposed_dims).equal(value)
                 assert type(transformed_value) == torch.Tensor
38 changes: 19 additions & 19 deletions test/test_transforms_v2.py
@@ -29,7 +29,7 @@
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw


 def make_vanilla_tensor_images(*args, **kwargs):
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
         if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
             # AA transforms don't support bounding boxes or masks
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
             if image_or_video_found:
                 # AA transforms only support a single image or video
                 continue
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
         if isinstance(value, PIL.Image.Image):
             # normalize doesn't support PIL images
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
             # normalize doesn't support integer images
             value = F.to_dtype(value, torch.float32, scale=True)
         adapted_input[key] = value
@@ -357,19 +357,19 @@ def test_random_resized_crop(self, transform, input):
         3,
     ),
 )
-def test_simple_tensor_heuristic(flat_inputs):
-    def split_on_simple_tensor(to_split):
+def test_pure_tensor_heuristic(flat_inputs):
+    def split_on_pure_tensor(to_split):
         # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
-        # 1. The first simple tensor. If none is present, this will be `None`
-        # 2. A list of the remaining simple tensors
+        # 1. The first pure tensor. If none is present, this will be `None`
+        # 2. A list of the remaining pure tensors
         # 3. A list of all other items
-        simple_tensors = []
+        pure_tensors = []
         others = []
         # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
         # affect the splitting.
         for item, inpt in zip(to_split, flat_inputs):
-            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
-        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+            (pure_tensors if is_pure_tensor(inpt) else others).append(item)
+        return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others

     class CopyCloneTransform(transforms.Transform):
         def _transform(self, inpt, params):
@@ -385,20 +385,20 @@ def was_applied(output, inpt):
             assert_equal(output, inpt)
             return True

-    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+    first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs)

     transform = CopyCloneTransform()
     transformed_sample = transform(flat_inputs)

-    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+    first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample)

-    if first_simple_tensor_input is not None:
+    if first_pure_tensor_input is not None:
         if other_inputs:
-            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
         else:
-            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)

-    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+    for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs):
         assert not transform.was_applied(output, inpt)

     for input, output in zip(other_inputs, other_outputs):
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = 1 if label_type is int else torch.tensor([1])

@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = torch.randint(0, 10, size=(num_boxes,))

@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     out = t(sample)

     if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
-        assert is_simple_tensor(out["image"])
+        assert is_pure_tensor(out["image"])
     else:
         assert isinstance(out["image"], datapoints.Image)
     assert isinstance(out["label"], type(sample["label"]))
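The renamed `test_pure_tensor_heuristic` above pins down the dispatch rule for plain tensors in transforms v2: when the sample contains an explicit image or video, pure tensors pass through untouched; otherwise only the first pure tensor is transformed as an image. A hedged illustration of that behavior (not part of this commit; assumes the heuristic as documented for v2 transforms):

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

flip = transforms.RandomHorizontalFlip(p=1.0)

plain = torch.rand(3, 8, 8)
out = flip(plain)  # no explicit image in the sample -> the pure tensor is flipped

image = datapoints.Image(torch.rand(3, 8, 8))
out_image, out_plain = flip(image, plain)  # explicit image present -> `plain` passes through unchanged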
2 changes: 1 addition & 1 deletion test/test_transforms_v2_consistency.py
@@ -602,7 +602,7 @@ def check_call_consistency(
             raise AssertionError(
                 f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                 f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                f"`is_simple_tensor` path in `_transform`."
+                f"`is_pure_tensor` path in `_transform`."
             ) from exc

         assert_close(
26 changes: 13 additions & 13 deletions test/test_transforms_v2_functional.py
@@ -24,7 +24,7 @@
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS

@@ -168,7 +168,7 @@ def _unbatch(self, batch, *, data_dims):
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
+        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
@@ -333,9 +333,9 @@ def test_scripted_smoke(self, info, args_kwargs, device):
         dispatcher = script(info.dispatcher)

         (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_datapoint)
+        image_pure_tensor = torch.Tensor(image_datapoint)

-        dispatcher(image_simple_tensor, *other_args, **kwargs)
+        dispatcher(image_pure_tensor, *other_args, **kwargs)

     # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
     # replaces this test for them.
@@ -358,11 +358,11 @@ def test_scriptable(self, dispatcher):
         script(dispatcher)

     @image_sample_inputs
-    def test_simple_tensor_output_type(self, info, args_kwargs):
+    def test_pure_tensor_output_type(self, info, args_kwargs):
         (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
+        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)

-        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
+        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

         # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
             dict(canvas_size=(1, 1)),
         ],
     )
-    def test_simple_tensor_insufficient_metadata(self, metadata):
-        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+    def test_pure_tensor_insufficient_metadata(self, metadata):
+        pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
-            F.clamp_bounding_boxes(simple_tensor, **metadata)
+            F.clamp_bounding_boxes(pure_tensor, **metadata)

     @pytest.mark.parametrize(
         "metadata",
@@ -538,11 +538,11 @@ def test_missing_new_format(self, inpt, old_format):
         with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
             F.convert_format_bounding_boxes(inpt, old_format)

-    def test_simple_tensor_insufficient_metadata(self):
-        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+    def test_pure_tensor_insufficient_metadata(self):
+        pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+            F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

     def test_datapoint_explicit_metadata(self):
         datapoint = next(make_bounding_boxes())
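These tests also show the practical cost of working with pure tensors: `.as_subclass(torch.Tensor)` strips the datapoint metadata, so the bounding-box kernels must then receive `format` and `canvas_size` explicitly. A short sketch of both paths (hedged; it mirrors the error messages asserted above but is not part of the commit):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)

F.clamp_bounding_boxes(boxes)  # datapoint: metadata travels with the tensor

pure = boxes.as_subclass(torch.Tensor)  # downcast to a pure tensor: metadata is lost
F.clamp_bounding_boxes(pure, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(32, 32))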
6 changes: 3 additions & 3 deletions test/test_transforms_v2_utils.py
@@ -37,15 +37,15 @@
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
     ],
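`check_type`, exercised by these parametrizations, accepts both classes and predicates, which is what lets `is_pure_tensor` sit in the same tuple as `datapoints.Image` and `PIL.Image.Image`. A minimal sketch of that dispatch (assumed to match the v2 utility; its body is not shown in this diff):

def check_type(obj, types_or_checks):
    for type_or_check in types_or_checks:
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):  # a predicate such as is_pure_tensor
            return True
    return False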
2 changes: 1 addition & 1 deletion test/transforms_v2_dispatcher_infos.py
@@ -107,7 +107,7 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
         ("TestDispatchers", test_name),
         pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
     )
-    for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
 ]
 multi_crop_skips.append(skip_dispatch_datapoint)
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_augment.py
@@ -9,7 +9,7 @@
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform

 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 class SimpleCopyPaste(Transform):
@@ -109,7 +109,7 @@ def _extract_image_targets(
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
-            if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
+            if isinstance(obj, datapoints.Image) or is_pure_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
                 images.append(F.to_image(obj))
@@ -146,7 +146,7 @@ def _insert_outputs(
             elif isinstance(obj, PIL.Image.Image):
                 flat_sample[i] = F.to_pil_image(output_images[c0])
                 c0 += 1
-            elif is_simple_tensor(obj):
+            elif is_pure_tensor(obj):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
             elif isinstance(obj, datapoints.BoundingBoxes):
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_geometry.py
@@ -7,7 +7,7 @@
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size


 class FixedSizeCrop(Transform):
@@ -32,7 +32,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
             flat_inputs,
             PIL.Image.Image,
             datapoints.Image,
-            is_simple_tensor,
+            is_pure_tensor,
             datapoints.Video,
         ):
             raise TypeError(
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_misc.py
@@ -8,7 +8,7 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform

-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 T = TypeVar("T")
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:


 class PermuteDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
         super().__init__()
@@ -47,7 +47,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:


 class TransposeDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
         super().__init__()
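Because `_transformed_types` mixes a predicate with datapoint classes, `PermuteDimensions` and `TransposeDimensions` act on pure tensors, Images, and Videos and ignore everything else. A hedged usage sketch based on the `dims` annotations above (the prototype tests earlier in this commit assert that outputs come back as plain `torch.Tensor`):

import torch
from torchvision import datapoints
from torchvision.prototype import transforms

# A single permutation for every matched type ...
transform = transforms.PermuteDimensions(dims=(1, 0, 2))

# ... or a per-type mapping, where None opts a type out entirely.
transform = transforms.PermuteDimensions(dims={datapoints.Image: (1, 0, 2), datapoints.Video: None})

out = transform(datapoints.Image(torch.rand(3, 16, 16)))  # returned as a plain torch.Tensor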
6 changes: 3 additions & 3 deletions torchvision/transforms/v2/_augment.py
@@ -12,7 +12,7 @@

 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _parse_labels_getter
-from .utils import has_any, is_simple_tensor, query_chw, query_size
+from .utils import has_any, is_pure_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
@@ -243,7 +243,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=lam)
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
@@ -310,7 +310,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             x1, y1, x2, y2 = params["box"]
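In both MixUp and CutMix, pure tensors take the same branch as Image/Video datapoints and must be batched, hence the `batch_size` check. The `roll`-based pairing visible in the hunk mixes each sample with its neighbor in the batch; a stripped-down sketch of that arithmetic (hedged: `lam` is fixed here for illustration, whereas the transform samples it from a Beta distribution):

import torch

lam = 0.3                        # mixing coefficient, ordinarily sampled per batch
batch = torch.rand(4, 3, 8, 8)   # a pure tensor treated as a batch of images
mixed = batch.roll(1, 0).mul_(1.0 - lam).add_(batch.mul(lam))  # pair each sample with its neighbor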
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_auto_augment.py
@@ -13,7 +13,7 @@
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

 from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_simple_tensor
+from .utils import check_type, is_pure_tensor


 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
@@ -50,7 +50,7 @@ def _flatten_and_extract_image_or_video(
         (
             datapoints.Image,
             PIL.Image.Image,
-            is_simple_tensor,
+            is_pure_tensor,
             datapoints.Video,
         ),
     ):