More datapoints docs and comments #7830

Merged, Aug 15, 2023 (5 commits)
6 changes: 1 addition & 5 deletions gallery/plot_custom_datapoints.py
@@ -3,7 +3,7 @@
How to write your own Datapoint class
=====================================

This guide is intended for downstream library maintainers. We explain how to
This guide is intended for advanced users and downstream library maintainers. We explain how to
write your own datapoint class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
:ref:`sphx_glr_auto_examples_plot_datapoints.py`.
@@ -68,10 +68,6 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# could also have used the functional *itself*, i.e.
# ``@register_kernel(functional=F.hflip, ...)``.
#
# The functionals that can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# :ref:`functional_transforms`.
#
# Now that we have registered our kernel, we can call the functional API on a
# ``MyDatapoint`` instance:
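For readers outside the torchvision codebase, the registration-and-dispatch idea this hunk relies on can be sketched in plain Python. The names `register_kernel`, `dispatch`, and `MyDatapoint` below mirror the API discussed above but are hypothetical stand-ins; this is a standalone illustration of the mechanism, not torchvision's implementation:

```python
# A minimal, hypothetical sketch of type-based kernel dispatch: a functional
# looks up the kernel registered for the input's type, walking the MRO so
# subclasses fall back to kernels registered for their parents.

_KERNEL_REGISTRY = {}

def register_kernel(functional_name, input_type):
    def decorator(kernel):
        _KERNEL_REGISTRY[(functional_name, input_type)] = kernel
        return kernel
    return decorator

def dispatch(functional_name, inpt, *args, **kwargs):
    for klass in type(inpt).__mro__:
        kernel = _KERNEL_REGISTRY.get((functional_name, klass))
        if kernel is not None:
            return kernel(inpt, *args, **kwargs)
    raise TypeError(f"No kernel registered for {type(inpt).__name__}")

class MyDatapoint(list):  # stand-in for a tensor-like container
    pass

@register_kernel("hflip", MyDatapoint)
def hflip_my_datapoint(my_dp):
    return MyDatapoint(reversed(my_dp))

print(dispatch("hflip", MyDatapoint([1, 2, 3])))  # [3, 2, 1]
```

Registering against the functional's name (or, as the text notes, the functional itself) keeps the functional generic while letting downstream types plug in their own kernels.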

82 changes: 57 additions & 25 deletions gallery/plot_datapoints.py
@@ -48,26 +48,22 @@
# Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
# for the input data.
#
# :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# What can I do with a datapoint?
# -------------------------------
#
# Datapoints look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
# any ``torch.*`` operator will also works on datapoints. See
# any ``torch.*`` operator will also work on datapoints. See
# :ref:`datapoint_unwrapping_behaviour` for a few gotchas.

# %%
#
# What datapoints are supported?
# ------------------------------
#
# So far :mod:`torchvision.datapoints` supports four types of datapoints:
#
# * :class:`~torchvision.datapoints.Image`
# * :class:`~torchvision.datapoints.Video`
# * :class:`~torchvision.datapoints.BoundingBoxes`
# * :class:`~torchvision.datapoints.Mask`
#
# .. _datapoint_creation:
#
# How do I construct a datapoint?
@@ -209,42 +205,78 @@ def get_transform(train):
# I had a Datapoint but now I have a Tensor. Help!
# ------------------------------------------------
#
# For a lot of operations involving datapoints, we cannot safely infer whether
# the result should retain the datapoint type, so we choose to return a plain
# tensor instead of a datapoint (this might change, see note below):
# By default, operations on :class:`~torchvision.datapoints.Datapoint` objects
# will return a pure Tensor:


assert isinstance(bboxes, datapoints.BoundingBoxes)

# Shift bboxes by 3 pixels in both H and W
new_bboxes = bboxes + 3

assert isinstance(new_bboxes, torch.Tensor) and not isinstance(new_bboxes, datapoints.BoundingBoxes)
assert isinstance(new_bboxes, torch.Tensor)
assert not isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# .. note::
#
# This behavior only affects native ``torch`` operations. If you are using
# the built-in ``torchvision`` transforms or functionals, you will always get
# as output the same type that you passed as input (pure ``Tensor`` or
# ``Datapoint``).

# %%
# If you're writing your own custom transforms or code involving datapoints, you
# can re-wrap the output into a datapoint by just calling their constructor, or
# by using the ``.wrap_like()`` class method:
# But I want a Datapoint back!
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the ``.wrap_like()`` class method (see more details
# above in :ref:`datapoint_creation`):

new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
Collaborator:

Non-blocking thought: to reduce the number of chars to type, it could be simpler to have bboxes.wrap_like(new_bboxes) or bboxes.wrap_to(new_bboxes).
I think we already thought about that earlier but failed to find a good name ...

Member Author (@NicolasHug, Aug 15, 2023):

Yeah, I find it a bit long and clunky as well.

bboxes.wrap_to(new_bboxes)

That would be the best UX, but unfortunately we can't do that, because wrap_to would have to be a method on pure Tensors.

We can implement instance methods on BBoxes objects, but I can't think of a good name either right now.

Collaborator:

We have #7780 to go from BoundingBoxes to BBoxes. I'm somewhat against doing bboxes.wrap_like(...) here, since it no longer makes it clear that this is a class method and not an instance method.

assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# See more details above in :ref:`datapoint_creation`.
# Alternatively, you can use :func:`~torchvision.datapoints.set_return_type`,
# either as a global config setting for the whole program or as a context manager:

with datapoints.set_return_type("datapoint"):
new_bboxes = bboxes + 3
assert isinstance(new_bboxes, datapoints.BoundingBoxes)

# %%
# Why is this happening?
# ^^^^^^^^^^^^^^^^^^^^^^
#
# .. note::
# **For performance reasons**. :class:`~torchvision.datapoints.Datapoint`
# classes are Tensor subclasses, so any operation involving a
# :class:`~torchvision.datapoints.Datapoint` object will go through the
# `__torch_function__
# <https://pytorch.org/docs/stable/notes/extending.html#extending-torch>`_
# protocol. This induces a small overhead, which we want to avoid when possible.
# This doesn't matter for built-in ``torchvision`` transforms because we can
# avoid the overhead there, but it could be a problem in your model's
# ``forward``.
#
# You never need to re-wrap manually if you're using the built-in transforms
# or their functional equivalents: this is automatically taken care of for
# you.
# **The alternative isn't much better anyway.** For every operation where
# preserving the :class:`~torchvision.datapoints.Datapoint` type makes
# sense, there are just as many operations where returning a pure Tensor is
# preferable: for example, is ``img.sum()`` still an :class:`~torchvision.datapoints.Image`?
# If we were to preserve :class:`~torchvision.datapoints.Datapoint` types all
# the way, even a model's logits or the output of the loss function would end up
# being of type :class:`~torchvision.datapoints.Image`, and surely that's not
# desirable.
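The unwrapping behaviour argued for above can be mimicked in a torch-free sketch: a subclass whose operations deliberately return the base type, so downstream results (sums, losses, logits) don't silently keep a label that no longer applies. `LabeledList` is a hypothetical analogue of a Datapoint, not torchvision code:

```python
# Hypothetical pure-Python analogue of the unwrapping behaviour: operations
# on the subclass return the plain base type, mirroring how most operations
# on a Datapoint return a pure Tensor.

class LabeledList(list):
    """A list subclass carrying an implicit 'kind', like a Datapoint does."""

    def __add__(self, other):
        # Return a plain list rather than a LabeledList: after an arbitrary
        # operation we can no longer guarantee the label still makes sense.
        return list(self) + list(other)

result = LabeledList([1, 2]) + [3]
print(type(result).__name__)  # list
```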
#
# .. note::
#
# This "unwrapping" behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# This behaviour is something we're actively seeking feedback on. If you find this surprising or if you
# have any suggestions on how to better support your use-cases, please reach out to us via this issue:
# https://github.com/pytorch/vision/issues/7319
#
# Exceptions
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
1 change: 1 addition & 0 deletions test/test_datapoints.py
@@ -101,6 +101,7 @@ def test_to_datapoint_reference(make_input, return_type):

assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
28 changes: 11 additions & 17 deletions torchvision/datapoints/_datapoint.py
@@ -66,19 +66,12 @@ def __torch_function__(
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.

The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Datapoint`
use case, this has two downsides:
Why do we override this? Because the base implementation in torch.Tensor would preserve the Datapoint type
of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
"Datapoints FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

1. Since some :class:`Datapoint`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output.

For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in _FORCE_TORCHFUNCTION_SUBCLASS
Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.

if not all(issubclass(cls, t) for t in types):
return NotImplemented

@@ -89,12 +82,13 @@

must_return_subclass = _must_return_subclass()
if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
# We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`.
# If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
# in test_to_datapoint_reference().
# The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
# the computation by walking the MRO upwards. For example,
# `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
# `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
# be wrapped into an `Image`.
return cls._wrap_output(output, args, kwargs)

if not must_return_subclass and isinstance(output, cls):
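The `isinstance(args[0], cls)` guard discussed in this hunk can be illustrated without torch. A wrapping hook that gets invoked for every participating type must check that the *primary* operand really is an instance of the class doing the wrapping, otherwise a plain value gets promoted by accident. `Image.maybe_wrap` is a hypothetical, simplified stand-in for the guarded wrapping step:

```python
# Toy illustration of the primary-operand guard: without the isinstance
# check on the first argument, a plain value involved in a mixed-type call
# (e.g. a_pure_tensor.to(an_image)) would be wrapped into the subclass.

class Image(list):
    @classmethod
    def maybe_wrap(cls, primary, output):
        # Guard: only wrap when the primary operand really is an Image.
        if isinstance(primary, cls):
            return cls(output)
        return output

plain = [1, 2]
img = Image([3, 4])

# Mixed-type call where Image's hook runs but must not wrap the plain value:
assert type(Image.maybe_wrap(plain, plain)) is list
# Call whose primary operand is an Image: wrapping is appropriate.
assert type(Image.maybe_wrap(img, [4, 5])) is Image
```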
9 changes: 8 additions & 1 deletion torchvision/datapoints/_torch_function_helpers.py
@@ -18,19 +18,26 @@ def __exit__(self, *args):
def set_return_type(return_type: str):
"""Set the return type of torch operations on datapoints.

This only affects the behaviour of torch operations. It has no effect on
``torchvision`` transforms or functionals, which will always return as
output the same type that was passed as input.

Can be used as a global flag for the entire program:

.. code:: python

img = datapoints.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor (default behaviour)

set_return_type("datapoint")
img + 2 # This is an Image

or as a context manager to restrict the scope:

.. code:: python

img = datapoints.Image(torch.rand(3, 5, 5))
img + 2 # This is a pure Tensor
with set_return_type("datapoint"):
img + 2 # This is an Image
img + 2 # This is a pure Tensor
11 changes: 9 additions & 2 deletions torchvision/transforms/v2/functional/_utils.py
@@ -19,8 +19,15 @@ def is_simple_tensor(inpt: Any) -> bool:
def _kernel_datapoint_wrapper(kernel):
@functools.wraps(kernel)
def wrapper(inpt, *args, **kwargs):
# We always pass datapoints as pure tensors to the kernels to avoid going through the
# Tensor.__torch_function__ logic, which is costly.
# If you're wondering whether we could / should get rid of this wrapper,
# the answer is no: we want to pass pure Tensors to avoid the overhead
# of the __torch_function__ machinery. Note that this is always valid,
# regardless of whether we override __torch_function__ in our base class
# or not.
# Also, even if we didn't call `as_subclass` here, we would still need
# this wrapper to call wrap_like(), because the Datapoint type would be
# lost after the first operation due to our own __torch_function__
# logic.
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
return type(inpt).wrap_like(inpt, output)
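The unwrap/compute/re-wrap shape of `_kernel_datapoint_wrapper` can be sketched with ordinary Python types. `Tagged` and `kernel_wrapper` below are hypothetical analogues (a list subclass standing in for a Datapoint, `list(inpt)` standing in for `as_subclass(torch.Tensor)`), not the real torchvision code:

```python
# Standalone sketch of the unwrap/re-wrap pattern: pass the base type to the
# kernel (avoiding any subclass dispatch overhead), then restore the subclass
# on the way out via wrap_like.

import functools

class Tagged(list):
    @classmethod
    def wrap_like(cls, reference, output):
        # In torchvision, this step would also carry over metadata
        # from `reference`; here we only restore the type.
        return cls(output)

def kernel_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        output = kernel(list(inpt), *args, **kwargs)  # "as base type"
        return type(inpt).wrap_like(inpt, output)
    return wrapper

@kernel_wrapper
def double(values):
    return [v * 2 for v in values]

out = double(Tagged([1, 2]))
assert isinstance(out, Tagged) and out == [2, 4]
```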
