add example for v2 wrapping for custom datasets #7514

Merged · 2 commits · Apr 14, 2023
64 changes: 64 additions & 0 deletions gallery/plot_datapoints.py
@@ -20,6 +20,7 @@
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


########################################################################################################################
@@ -93,6 +94,68 @@
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
# also don't have to wrap manually.
#
# If you have a custom dataset, for example the ``PennFudanDataset`` from
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
#
# 1. Perform the wrapping inside ``__getitem__``:

class PennFudanDataset(torch.utils.data.Dataset):
...

def __getitem__(self, item):
...

target["boxes"] = datapoints.BoundingBox(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
target["labels"] = labels
target["masks"] = datapoints.Mask(masks)

...

if self.transforms is not None:
img, target = self.transforms(img, target)

...

########################################################################################################################
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:


class WrapPennFudanDataset:
def __call__(self, img, target):
target["boxes"] = datapoints.BoundingBox(
target["boxes"],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
target["masks"] = datapoints.Mask(target["masks"])
return img, target


...


def get_transform(train):
transforms = []
transforms.append(WrapPennFudanDataset())
transforms.append(T.PILToTensor())
...
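
########################################################################################################################
# A hypothetical illustration: ``WrapPennFudanDataset`` turns the plain-tensor targets into datapoints, so the v2
# transforms later in the pipeline can dispatch on them. The tensors below are made-up placeholder values, not data
# from the actual dataset.

demo_img = torch.zeros(3, 16, 16, dtype=torch.uint8)
demo_target = {
    "boxes": torch.tensor([[0.0, 0.0, 8.0, 8.0]]),
    "masks": torch.zeros(1, 16, 16, dtype=torch.uint8),
}
demo_img, demo_target = WrapPennFudanDataset()(demo_img, demo_target)
assert isinstance(demo_target["boxes"], datapoints.BoundingBox)
assert isinstance(demo_target["masks"], datapoints.Mask)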

########################################################################################################################
# .. note::
#
# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
# the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
#    at least not wrapping the obsolete parts can lead to a significant performance boost.
Review comment from the PR author, quoting "not wrapping the obsolete parts":
This works due to the heuristic we have in place. I wonder if we should link to the explanation here or refrain from it to avoid confusion.
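
A rough sketch of that heuristic (illustrative only; the sample values are made up and the exact behavior is described in the v2 transforms documentation): once a sample contains an image or other datapoints, the v2 transforms pass plain tensors through untouched, which is why unwrapped masks are not transformed.

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as T
from torchvision import datapoints

sample = {
    "image": datapoints.Image(torch.zeros(3, 32, 32, dtype=torch.uint8)),
    "boxes": datapoints.BoundingBox(
        [[0, 0, 10, 10]], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(32, 32)
    ),
    "masks": torch.zeros(1, 32, 32, dtype=torch.uint8),  # plain tensor, deliberately not wrapped
}
out = T.RandomHorizontalFlip(p=1.0)(sample)
# out["image"] and out["boxes"] are flipped; out["masks"] comes back unchanged.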

#
#    For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
#    transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
# even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are
# generated from the masks.
#
# How do the datapoints behave inside a computation?
# --------------------------------------------------
#
@@ -101,6 +164,7 @@
# Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):


Review comment from the PR author:
flake8 was complaining here. No idea why this wasn't the case before.

assert isinstance(image, datapoints.Image)

new_image = image + 0
4 changes: 3 additions & 1 deletion torchvision/datapoints/_dataset_wrapper.py
@@ -124,7 +124,9 @@ def __init__(self, dataset, target_keys):
if not isinstance(dataset, datasets.VisionDataset):
raise TypeError(
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
f"but got a '{dataset_cls.__name__}' instead."
f"but got a '{dataset_cls.__name__}' instead.\n"
f"For an example of how to perform the wrapping for custom datasets, see\n\n"
"https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
)

for cls in dataset_cls.mro():
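
To make the new error message concrete, here is a minimal sketch (the dataset class is hypothetical, and it assumes ``wrap_dataset_for_transforms_v2`` is exposed under ``torchvision.datapoints`` as in this module): wrapping anything that is not a ``VisionDataset`` subclass now fails with a pointer to the gallery example added above.

import torch
import torchvision

torchvision.disable_beta_transforms_warning()
from torchvision import datapoints


class MyCustomDataset(torch.utils.data.Dataset):  # hypothetical dataset that is not a VisionDataset
    def __len__(self):
        return 0


try:
    datapoints.wrap_dataset_for_transforms_v2(MyCustomDataset())
except TypeError as error:
    print(error)  # the message now ends with a link to the wrapping example for custom datasets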