From 1bf5049cfe55a16ad2a6bc00e52caa6916e83b0d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 13 Apr 2023 10:26:08 +0200 Subject: [PATCH] add example for v2 wrapping for custom datasets --- gallery/plot_datapoints.py | 64 ++++++++++++++++++++++ torchvision/datapoints/_dataset_wrapper.py | 4 +- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index 83ca6793598..5094de13a3e 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -20,6 +20,7 @@ torchvision.disable_beta_transforms_warning() from torchvision import datapoints +from torchvision.transforms.v2 import functional as F ######################################################################################################################## @@ -93,6 +94,68 @@ # built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you # also don't have to wrap manually. # +# If you have a custom dataset, for example the ``PennFudanDataset`` from +# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options: +# +# 1. Perform the wrapping inside ``__getitem__``:

class PennFudanDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, item):
        ...

        target["boxes"] = datapoints.BoundingBox(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=F.get_spatial_size(img),
        )
        target["labels"] = labels
        target["masks"] = datapoints.Mask(masks)

        ...

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        ...

######################################################################################################################## # 2. 
Perform the wrapping inside a custom transformation at the beginning of your pipeline: 


class WrapPennFudanDataset:
    def __call__(self, img, target):
        target["boxes"] = datapoints.BoundingBox(
            target["boxes"],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=F.get_spatial_size(img),
        )
        target["masks"] = datapoints.Mask(target["masks"])
        return img, target


...


def get_transform(train):
    transforms = []
    transforms.append(WrapPennFudanDataset())
    transforms.append(T.PILToTensor())
    ...

######################################################################################################################## # .. note:: # # If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in # the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or # at least not wrapping the unneeded parts can lead to a significant performance boost. # # For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids # transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be # even better to not load the masks at all, but this is not possible in this example, since the bounding boxes are # generated from the masks. # # How do the datapoints behave inside a computation? 
# -------------------------------------------------- # @@ -101,6 +164,7 @@ # Since for most operations involving datapoints, it cannot be safely inferred whether the result should retain the # datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below): + assert isinstance(image, datapoints.Image) new_image = image + 0 diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py index 09a5469dde1..d88bc81e62b 100644 --- a/torchvision/datapoints/_dataset_wrapper.py +++ b/torchvision/datapoints/_dataset_wrapper.py @@ -124,7 +124,9 @@ def __init__(self, dataset, target_keys): if not isinstance(dataset, datasets.VisionDataset): raise TypeError( f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, " - f"but got a '{dataset_cls.__name__}' instead." + f"but got a '{dataset_cls.__name__}' instead.\n" + f"For an example of how to perform the wrapping for custom datasets, see\n\n" + "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself" ) for cls in dataset_cls.mro():