diff --git a/docs/source/conf.py b/docs/source/conf.py index 33fd64e3a59..66138c2d12e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -76,14 +76,45 @@ """ +class CustomGalleryExampleSortKey: + # See https://sphinx-gallery.github.io/stable/configuration.html#sorting-gallery-examples + # and https://github.com/sphinx-gallery/sphinx-gallery/blob/master/sphinx_gallery/sorting.py + def __init__(self, src_dir): + self.src_dir = src_dir + + transforms_subsection_order = [ + "plot_transforms_getting_started.py", + "plot_transforms_illustrations.py", + "plot_transforms_e2e.py", + "plot_cutmix_mixup.py", + "plot_custom_transforms.py", + "plot_datapoints.py", + "plot_custom_datapoints.py", + ] + + def __call__(self, filename): + if "gallery/transforms" in self.src_dir: + try: + return self.transforms_subsection_order.index(filename) + except ValueError as e: + raise ValueError( + "Looks like you added an example in gallery/transforms? " + "You need to specify its order in docs/source/conf.py. Look for CustomGalleryExampleSortKey." + ) from e + else: + # For other subsections we just sort alphabetically by filename + return filename + + sphinx_gallery_conf = { "examples_dirs": "../../gallery/", # path to your example scripts "gallery_dirs": "auto_examples", # path to where to save gallery generated output - "subsection_order": ExplicitOrder(["../../gallery/v2_transforms", "../../gallery/others"]), + "subsection_order": ExplicitOrder(["../../gallery/transforms", "../../gallery/others"]), "backreferences_dir": "gen_modules/backreferences", "doc_module": ("torchvision",), "remove_config_comments": True, "ignore_pattern": "helpers.py", + "within_subsection_order": CustomGalleryExampleSortKey, } napoleon_use_ivar = True diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst index 4a2a8e9fc5c..2ecfdec54c2 100644 --- a/docs/source/datapoints.rst +++ b/docs/source/datapoints.rst @@ -8,7 +8,7 @@ Datapoints Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to dispatch their inputs to the appropriate lower-level kernels. Most users do not need to manipulate datapoints directly and can simply rely on dataset wrapping - -see e.g. :ref:`sphx_glr_auto_examples_v2_transforms_plot_transforms_v2_e2e.py`. +see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`. .. autosummary:: :toctree: generated/ diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6bf2c3753db..1527000443d 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -113,7 +113,7 @@ do to is to update the import to ``torchvision.transforms.v2``. In terms of output, there might be negligible differences due to implementation differences. To learn more about the v2 transforms, check out -:ref:`sphx_glr_auto_examples_v2_transforms_plot_transforms_v2.py`. +:ref:`sphx_glr_auto_examples_transforms_plot_transforms_getting_started.py`. .. TODO: make sure link is still good!! @@ -479,7 +479,7 @@ CutMix and MixUp are special transforms that are meant to be used on batches rather than on individual images, because they are combining pairs of images together. These can be used after the dataloader (once the samples are batched), or part of a collation function. See -:ref:`sphx_glr_auto_examples_v2_transforms_plot_cutmix_mixup.py` for detailed usage examples. +:ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage examples. .. 
autosummary:: :toctree: generated/ diff --git a/gallery/others/plot_scripted_tensor_transforms.py b/gallery/others/plot_scripted_tensor_transforms.py index 85b332c4ca1..27e8cc166a8 100644 --- a/gallery/others/plot_scripted_tensor_transforms.py +++ b/gallery/others/plot_scripted_tensor_transforms.py @@ -62,7 +62,7 @@ def show(imgs): # -------------------------- # Most transforms natively support tensors on top of PIL images (to visualize # the effect of the transforms, you may refer to see -# :ref:`sphx_glr_auto_examples_others_plot_transforms.py`). +# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`). # Using tensor images, we can run the transforms on GPUs if cuda is available! import torch.nn as nn diff --git a/gallery/transforms/README.rst b/gallery/transforms/README.rst new file mode 100644 index 00000000000..1b8b1b08155 --- /dev/null +++ b/gallery/transforms/README.rst @@ -0,0 +1,4 @@ +.. _transforms_gallery: + +Transforms +---------- diff --git a/gallery/v2_transforms/helpers.py b/gallery/transforms/helpers.py similarity index 87% rename from gallery/v2_transforms/helpers.py rename to gallery/transforms/helpers.py index 3c92df4322e..957d9bcb709 100644 --- a/gallery/v2_transforms/helpers.py +++ b/gallery/transforms/helpers.py @@ -5,7 +5,7 @@ from torchvision.transforms.v2 import functional as F -def plot(imgs): +def plot(imgs, row_title=None, **imshow_kwargs): if not isinstance(imgs[0], list): # Make a 2d grid even if there's just 1 row imgs = [imgs] @@ -40,7 +40,11 @@ def plot(imgs): img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65) ax = axs[row_idx, col_idx] - ax.imshow(img.permute(1, 2, 0).numpy()) + ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs) ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + if row_title is not None: + for row_idx in range(num_rows): + axs[row_idx, 0].set(ylabel=row_title[row_idx]) + plt.tight_layout() diff --git a/gallery/v2_transforms/plot_custom_datapoints.py b/gallery/transforms/plot_custom_datapoints.py similarity index 95% rename from gallery/v2_transforms/plot_custom_datapoints.py rename to gallery/transforms/plot_custom_datapoints.py index 0859adb6d93..674aceb6e5b 100644 --- a/gallery/v2_transforms/plot_custom_datapoints.py +++ b/gallery/transforms/plot_custom_datapoints.py @@ -5,12 +5,12 @@ .. note:: Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + or :ref:`go to the end ` to download the full example code. This guide is intended for advanced users and downstream library maintainers. We explain how to write your own datapoint class, and how to make it compatible with the built-in Torchvision v2 transforms. Before continuing, make sure you have read -:ref:`sphx_glr_auto_examples_v2_transforms_plot_datapoints.py`. +:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`. """ # %% diff --git a/gallery/v2_transforms/plot_custom_transforms.py b/gallery/transforms/plot_custom_transforms.py similarity index 97% rename from gallery/v2_transforms/plot_custom_transforms.py rename to gallery/transforms/plot_custom_transforms.py index 912ddf323ff..55e8e3f060f 100644 --- a/gallery/v2_transforms/plot_custom_transforms.py +++ b/gallery/transforms/plot_custom_transforms.py @@ -5,7 +5,7 @@ .. note:: Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + or :ref:`go to the end ` to download the full example code. 
This guide explains how to write transforms that are compatible with the torchvision transforms V2 API. diff --git a/gallery/v2_transforms/plot_cutmix_mixup.py b/gallery/transforms/plot_cutmix_mixup.py similarity index 97% rename from gallery/v2_transforms/plot_cutmix_mixup.py rename to gallery/transforms/plot_cutmix_mixup.py index 6bf21933d5e..d26b027b121 100644 --- a/gallery/v2_transforms/plot_cutmix_mixup.py +++ b/gallery/transforms/plot_cutmix_mixup.py @@ -6,7 +6,7 @@ .. note:: Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + or :ref:`go to the end ` to download the full example code. :class:`~torchvision.transforms.v2.CutMix` and :class:`~torchvision.transforms.v2.MixUp` are popular augmentation strategies diff --git a/gallery/v2_transforms/plot_datapoints.py b/gallery/transforms/plot_datapoints.py similarity index 98% rename from gallery/v2_transforms/plot_datapoints.py rename to gallery/transforms/plot_datapoints.py index b56de809f37..726046097a9 100644 --- a/gallery/v2_transforms/plot_datapoints.py +++ b/gallery/transforms/plot_datapoints.py @@ -5,7 +5,7 @@ .. note:: Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + or :ref:`go to the end ` to download the full example code. Datapoints are Tensor subclasses introduced together with diff --git a/gallery/v2_transforms/plot_transforms_v2_e2e.py b/gallery/transforms/plot_transforms_e2e.py similarity index 97% rename from gallery/v2_transforms/plot_transforms_v2_e2e.py rename to gallery/transforms/plot_transforms_e2e.py index fa47dbfef5d..313c7b7e606 100644 --- a/gallery/v2_transforms/plot_transforms_v2_e2e.py +++ b/gallery/transforms/plot_transforms_e2e.py @@ -4,8 +4,8 @@ =============================================================== .. note:: - Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + Try on `collab `_ + or :ref:`go to the end ` to download the full example code. Object detection and segmentation tasks are natively supported: ``torchvision.transforms.v2`` enables jointly transforming images, videos, diff --git a/gallery/v2_transforms/plot_transforms_v2.py b/gallery/transforms/plot_transforms_getting_started.py similarity index 96% rename from gallery/v2_transforms/plot_transforms_v2.py rename to gallery/transforms/plot_transforms_getting_started.py index 92a92545a58..da23ccd81fe 100644 --- a/gallery/v2_transforms/plot_transforms_v2.py +++ b/gallery/transforms/plot_transforms_getting_started.py @@ -4,8 +4,8 @@ ================================== .. note:: - Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + Try on `collab `_ + or :ref:`go to the end ` to download the full example code. This example illustrates all of what you need to know to get started with the new :mod:`torchvision.transforms.v2` API. We'll cover simple tasks like @@ -70,7 +70,7 @@ # ` to learn more about recommended practices and conventions, or # explore more :ref:`examples ` e.g. how to use augmentation # transforms like :ref:`CutMix and MixUp -# `. +# `. # # .. note:: # @@ -148,7 +148,7 @@ # # You don't need to know much more about datapoints at this point, but advanced # users who want to learn more can refer to -# :ref:`sphx_glr_auto_examples_v2_transforms_plot_datapoints.py`. +# :ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`. # # What do I pass as input? 
# ------------------------ @@ -243,7 +243,7 @@ # # from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2 # -# dataset = CocoDetection(..., transforms=my_v2_transforms) +# dataset = CocoDetection(..., transforms=my_transforms) # dataset = wrap_dataset_for_transforms_v2(dataset) # # Now the dataset returns datapoints! # diff --git a/gallery/others/plot_transforms.py b/gallery/transforms/plot_transforms_illustrations.py similarity index 73% rename from gallery/others/plot_transforms.py rename to gallery/transforms/plot_transforms_illustrations.py index 9702bc9c3ba..95ab455d0fd 100644 --- a/gallery/others/plot_transforms.py +++ b/gallery/transforms/plot_transforms_illustrations.py @@ -4,55 +4,33 @@ ========================== .. note:: - Try on `collab `_ - or :ref:`go to the end ` to download the full example code. + Try on `collab `_ + or :ref:`go to the end ` to download the full example code. -This example illustrates the various transforms available in :ref:`the -torchvision.transforms module `. +This example illustrates some of the various transforms available in :ref:`the +torchvision.transforms.v2 module `. """ +# %% # sphinx_gallery_thumbnail_path = "../../gallery/assets/transforms_thumbnail.png" from PIL import Image from pathlib import Path import matplotlib.pyplot as plt -import numpy as np import torch -import torchvision.transforms as T - +from torchvision.transforms import v2 plt.rcParams["savefig.bbox"] = 'tight' -orig_img = Image.open(Path('../assets') / 'astronaut.jpg') + # if you change the seed, make sure that the randomly-applied transforms # properly show that the image can be both transformed and *not* transformed! torch.manual_seed(0) - -def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): - if not isinstance(imgs[0], list): - # Make a 2d grid even if there's just 1 row - imgs = [imgs] - - num_rows = len(imgs) - num_cols = len(imgs[0]) + with_orig - fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) - for row_idx, row in enumerate(imgs): - row = [orig_img] + row if with_orig else row - for col_idx, img in enumerate(row): - ax = axs[row_idx, col_idx] - ax.imshow(np.asarray(img), **imshow_kwargs) - ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) - - if with_orig: - axs[0, 0].set(title='Original image') - axs[0, 0].title.set_size(8) - if row_title is not None: - for row_idx in range(num_rows): - axs[row_idx, 0].set(ylabel=row_title[row_idx]) - - plt.tight_layout() - +# If you're trying to run that on collab, you can download the assets and the +# helpers from https://github.com/pytorch/vision/tree/main/gallery/ +from helpers import plot +orig_img = Image.open(Path('../assets') / 'astronaut.jpg') # %% # Geometric Transforms @@ -66,8 +44,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.Pad` transform # (see also :func:`~torchvision.transforms.functional.pad`) # pads all image borders with some pixel values. -padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)] -plot(padded_imgs) +padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)] +plot([orig_img] + padded_imgs) # %% # Resize @@ -75,8 +53,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.Resize` transform # (see also :func:`~torchvision.transforms.functional.resize`) # resizes an image. 
-resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] -plot(resized_imgs) +resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] +plot([orig_img] + resized_imgs) # %% # CenterCrop @@ -84,8 +62,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.CenterCrop` transform # (see also :func:`~torchvision.transforms.functional.center_crop`) # crops the given image at the center. -center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] -plot(center_crops) +center_crops = [v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)] +plot([orig_img] + center_crops) # %% # FiveCrop @@ -93,8 +71,8 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.FiveCrop` transform # (see also :func:`~torchvision.transforms.functional.five_crop`) # crops the given image into four corners and the central crop. -(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img) -plot([top_left, top_right, bottom_left, bottom_right, center]) +(top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(orig_img) +plot([orig_img] + [top_left, top_right, bottom_left, bottom_right, center]) # %% # RandomPerspective @@ -102,9 +80,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomPerspective` transform # (see also :func:`~torchvision.transforms.functional.perspective`) # performs random perspective transform on an image. -perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0) +perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0) perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)] -plot(perspective_imgs) +plot([orig_img] + perspective_imgs) # %% # RandomRotation @@ -112,9 +90,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomRotation` transform # (see also :func:`~torchvision.transforms.functional.rotate`) # rotates an image with random angle. -rotater = T.RandomRotation(degrees=(0, 180)) +rotater = v2.RandomRotation(degrees=(0, 180)) rotated_imgs = [rotater(orig_img) for _ in range(4)] -plot(rotated_imgs) +plot([orig_img] + rotated_imgs) # %% # RandomAffine @@ -122,9 +100,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomAffine` transform # (see also :func:`~torchvision.transforms.functional.affine`) # performs random affine transform on an image. -affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75)) +affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75)) affine_imgs = [affine_transfomer(orig_img) for _ in range(4)] -plot(affine_imgs) +plot([orig_img] + affine_imgs) # %% # ElasticTransform @@ -133,9 +111,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # (see also :func:`~torchvision.transforms.functional.elastic_transform`) # Randomly transforms the morphology of objects in images and produces a # see-through-water-like effect. 
-elastic_transformer = T.ElasticTransform(alpha=250.0) +elastic_transformer = v2.ElasticTransform(alpha=250.0) transformed_imgs = [elastic_transformer(orig_img) for _ in range(2)] -plot(transformed_imgs) +plot([orig_img] + transformed_imgs) # %% # RandomCrop @@ -143,9 +121,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomCrop` transform # (see also :func:`~torchvision.transforms.functional.crop`) # crops an image at a random location. -cropper = T.RandomCrop(size=(128, 128)) +cropper = v2.RandomCrop(size=(128, 128)) crops = [cropper(orig_img) for _ in range(4)] -plot(crops) +plot([orig_img] + crops) # %% # RandomResizedCrop @@ -154,9 +132,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # (see also :func:`~torchvision.transforms.functional.resized_crop`) # crops an image at a random location, and then resizes the crop to a given # size. -resize_cropper = T.RandomResizedCrop(size=(32, 32)) +resize_cropper = v2.RandomResizedCrop(size=(32, 32)) resized_crops = [resize_cropper(orig_img) for _ in range(4)] -plot(resized_crops) +plot([orig_img] + resized_crops) # %% # Photometric Transforms @@ -175,17 +153,17 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.Grayscale` transform # (see also :func:`~torchvision.transforms.functional.to_grayscale`) # converts an image to grayscale -gray_img = T.Grayscale()(orig_img) -plot([gray_img], cmap='gray') +gray_img = v2.Grayscale()(orig_img) +plot([orig_img, gray_img], cmap='gray') # %% # ColorJitter # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.ColorJitter` transform # randomly changes the brightness, contrast, saturation, hue, and other properties of an image. -jitter = T.ColorJitter(brightness=.5, hue=.3) -jitted_imgs = [jitter(orig_img) for _ in range(4)] -plot(jitted_imgs) +jitter = v2.ColorJitter(brightness=.5, hue=.3) +jittered_imgs = [jitter(orig_img) for _ in range(4)] +plot([orig_img] + jittered_imgs) # %% # GaussianBlur @@ -193,9 +171,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.GaussianBlur` transform # (see also :func:`~torchvision.transforms.functional.gaussian_blur`) # performs gaussian blur transform on an image. -blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)) +blurrer = v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)) blurred_imgs = [blurrer(orig_img) for _ in range(4)] -plot(blurred_imgs) +plot([orig_img] + blurred_imgs) # %% # RandomInvert @@ -203,9 +181,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomInvert` transform # (see also :func:`~torchvision.transforms.functional.invert`) # randomly inverts the colors of the given image. -inverter = T.RandomInvert() +inverter = v2.RandomInvert() invertered_imgs = [inverter(orig_img) for _ in range(4)] -plot(invertered_imgs) +plot([orig_img] + invertered_imgs) # %% # RandomPosterize @@ -214,9 +192,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # (see also :func:`~torchvision.transforms.functional.posterize`) # randomly posterizes the image by reducing the number of bits # of each color channel. 
-posterizer = T.RandomPosterize(bits=2) +posterizer = v2.RandomPosterize(bits=2) posterized_imgs = [posterizer(orig_img) for _ in range(4)] -plot(posterized_imgs) +plot([orig_img] + posterized_imgs) # %% # RandomSolarize @@ -225,9 +203,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # (see also :func:`~torchvision.transforms.functional.solarize`) # randomly solarizes the image by inverting all pixel values above # the threshold. -solarizer = T.RandomSolarize(threshold=192.0) +solarizer = v2.RandomSolarize(threshold=192.0) solarized_imgs = [solarizer(orig_img) for _ in range(4)] -plot(solarized_imgs) +plot([orig_img] + solarized_imgs) # %% # RandomAdjustSharpness @@ -235,9 +213,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomAdjustSharpness` transform # (see also :func:`~torchvision.transforms.functional.adjust_sharpness`) # randomly adjusts the sharpness of the given image. -sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2) +sharpness_adjuster = v2.RandomAdjustSharpness(sharpness_factor=2) sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)] -plot(sharpened_imgs) +plot([orig_img] + sharpened_imgs) # %% # RandomAutocontrast @@ -245,9 +223,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomAutocontrast` transform # (see also :func:`~torchvision.transforms.functional.autocontrast`) # randomly applies autocontrast to the given image. -autocontraster = T.RandomAutocontrast() +autocontraster = v2.RandomAutocontrast() autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)] -plot(autocontrasted_imgs) +plot([orig_img] + autocontrasted_imgs) # %% # RandomEqualize @@ -255,9 +233,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomEqualize` transform # (see also :func:`~torchvision.transforms.functional.equalize`) # randomly equalizes the histogram of the given image. -equalizer = T.RandomEqualize() +equalizer = v2.RandomEqualize() equalized_imgs = [equalizer(orig_img) for _ in range(4)] -plot(equalized_imgs) +plot([orig_img] + equalized_imgs) # %% # Augmentation Transforms @@ -270,22 +248,22 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.AutoAugment` transform # automatically augments data based on a given auto-augmentation policy. # See :class:`~torchvision.transforms.AutoAugmentPolicy` for the available policies. -policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN] -augmenters = [T.AutoAugment(policy) for policy in policies] +policies = [v2.AutoAugmentPolicy.CIFAR10, v2.AutoAugmentPolicy.IMAGENET, v2.AutoAugmentPolicy.SVHN] +augmenters = [v2.AutoAugment(policy) for policy in policies] imgs = [ [augmenter(orig_img) for _ in range(4)] for augmenter in augmenters ] row_title = [str(policy).split('.')[-1] for policy in policies] -plot(imgs, row_title=row_title) +plot([[orig_img] + row for row in imgs], row_title=row_title) # %% # RandAugment # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandAugment` is an alternate version of AutoAugment. 
-augmenter = T.RandAugment() +augmenter = v2.RandAugment() imgs = [augmenter(orig_img) for _ in range(4)] -plot(imgs) +plot([orig_img] + imgs) # %% # TrivialAugmentWide @@ -293,17 +271,17 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.TrivialAugmentWide` is an alternate implementation of AutoAugment. # However, instead of transforming an image multiple times, it transforms an image only once # using a random transform from a given list with a random strength number. -augmenter = T.TrivialAugmentWide() +augmenter = v2.TrivialAugmentWide() imgs = [augmenter(orig_img) for _ in range(4)] -plot(imgs) +plot([orig_img] + imgs) # %% # AugMix # ~~~~~~ # The :class:`~torchvision.transforms.AugMix` transform interpolates between augmented versions of an image. -augmenter = T.AugMix() +augmenter = v2.AugMix() imgs = [augmenter(orig_img) for _ in range(4)] -plot(imgs) +plot([orig_img] + imgs) # %% # Randomly-applied Transforms @@ -318,9 +296,9 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomHorizontalFlip` transform # (see also :func:`~torchvision.transforms.functional.hflip`) # performs horizontal flip of an image, with a given probability. -hflipper = T.RandomHorizontalFlip(p=0.5) +hflipper = v2.RandomHorizontalFlip(p=0.5) transformed_imgs = [hflipper(orig_img) for _ in range(4)] -plot(transformed_imgs) +plot([orig_img] + transformed_imgs) # %% # RandomVerticalFlip @@ -328,15 +306,15 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): # The :class:`~torchvision.transforms.RandomVerticalFlip` transform # (see also :func:`~torchvision.transforms.functional.vflip`) # performs vertical flip of an image, with a given probability. -vflipper = T.RandomVerticalFlip(p=0.5) +vflipper = v2.RandomVerticalFlip(p=0.5) transformed_imgs = [vflipper(orig_img) for _ in range(4)] -plot(transformed_imgs) +plot([orig_img] + transformed_imgs) # %% # RandomApply # ~~~~~~~~~~~ # The :class:`~torchvision.transforms.RandomApply` transform # randomly applies a list of transforms, with a given probability. -applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5) +applier = v2.RandomApply(transforms=[v2.RandomCrop(size=(64, 64))], p=0.5) transformed_imgs = [applier(orig_img) for _ in range(4)] -plot(transformed_imgs) +plot([orig_img] + transformed_imgs) diff --git a/gallery/v2_transforms/README.rst b/gallery/v2_transforms/README.rst deleted file mode 100644 index 371af30a14b..00000000000 --- a/gallery/v2_transforms/README.rst +++ /dev/null @@ -1,4 +0,0 @@ -.. 
_transforms_gallery: - -V2 transforms -------------- diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 0d11f610a89..a06ecb74824 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1,4 +1,3 @@ -import enum import importlib.machinery import importlib.util import inspect @@ -83,35 +82,6 @@ def __init__( supports_pil=False, make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]), ), - ConsistencyConfig( - v2_transforms.Resize, - legacy_transforms.Resize, - [ - NotScriptableArgsKwargs(32), - ArgsKwargs([32]), - ArgsKwargs((32, 29)), - ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST), - ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST), - ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR), - NotScriptableArgsKwargs(31, max_size=32), - ArgsKwargs([31], max_size=32), - NotScriptableArgsKwargs(30, max_size=100), - ArgsKwargs([31], max_size=32), - ArgsKwargs((29, 32), antialias=False), - ArgsKwargs((28, 31), antialias=True), - ], - # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes - closeness_kwargs=dict(rtol=0, atol=1), - ), - ConsistencyConfig( - v2_transforms.Resize, - legacy_transforms.Resize, - [ - ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True), - ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True), - ], - closeness_kwargs=dict(rtol=0, atol=21), - ), ConsistencyConfig( v2_transforms.CenterCrop, legacy_transforms.CenterCrop, @@ -187,20 +157,6 @@ def __init__( # Use default tolerances of `torch.testing.assert_close` closeness_kwargs=dict(rtol=None, atol=None), ), - ConsistencyConfig( - v2_transforms.ConvertImageDtype, - legacy_transforms.ConvertImageDtype, - [ - ArgsKwargs(torch.float16), - ArgsKwargs(torch.bfloat16), - ArgsKwargs(torch.float32), - ArgsKwargs(torch.float64), - ArgsKwargs(torch.uint8), - ], - supports_pil=False, - # Use default tolerances of `torch.testing.assert_close` - closeness_kwargs=dict(rtol=None, atol=None), - ), ConsistencyConfig( v2_transforms.ToPILImage, legacy_transforms.ToPILImage, @@ -226,22 +182,6 @@ def __init__( # images given that the transform does nothing but call it anyway. 
supports_pil=False, ), - ConsistencyConfig( - v2_transforms.RandomHorizontalFlip, - legacy_transforms.RandomHorizontalFlip, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), - ConsistencyConfig( - v2_transforms.RandomVerticalFlip, - legacy_transforms.RandomVerticalFlip, - [ - ArgsKwargs(p=0), - ArgsKwargs(p=1), - ], - ), ConsistencyConfig( v2_transforms.RandomEqualize, legacy_transforms.RandomEqualize, @@ -367,30 +307,6 @@ def __init__( ], closeness_kwargs={"atol": 1e-5, "rtol": 1e-5}, ), - *[ - ConsistencyConfig( - v2_transforms.ElasticTransform, - legacy_transforms.ElasticTransform, - [ - ArgsKwargs(), - ArgsKwargs(alpha=20.0), - ArgsKwargs(alpha=(15.3, 27.2)), - ArgsKwargs(sigma=3.0), - ArgsKwargs(sigma=(2.5, 3.9)), - ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST), - ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC), - ArgsKwargs(interpolation=PIL.Image.NEAREST), - ArgsKwargs(interpolation=PIL.Image.BICUBIC), - ArgsKwargs(fill=1), - ], - # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]), - # We updated gaussian blur kernel generation with a faster and numerically more stable version - # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance - closeness_kwargs=ckw, - ) - for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})] - ], ConsistencyConfig( v2_transforms.GaussianBlur, legacy_transforms.GaussianBlur, @@ -402,26 +318,6 @@ def __init__( ], closeness_kwargs={"rtol": 1e-5, "atol": 1e-5}, ), - ConsistencyConfig( - v2_transforms.RandomAffine, - legacy_transforms.RandomAffine, - [ - ArgsKwargs(degrees=30.0), - ArgsKwargs(degrees=(-20.0, 10.0)), - ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)), - ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)), - ArgsKwargs(degrees=0.0, shear=13), - ArgsKwargs(degrees=0.0, shear=(8, 17)), - ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), - ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), - ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST), - ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST), - ArgsKwargs(degrees=30.0, fill=1), - ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), - ArgsKwargs(degrees=30.0, center=(0, 0)), - ], - removed_params=["fillcolor", "resample"], - ), ConsistencyConfig( v2_transforms.RandomCrop, legacy_transforms.RandomCrop, @@ -456,21 +352,6 @@ def __init__( ], closeness_kwargs={"atol": None, "rtol": None}, ), - ConsistencyConfig( - v2_transforms.RandomRotation, - legacy_transforms.RandomRotation, - [ - ArgsKwargs(degrees=30.0), - ArgsKwargs(degrees=(-20.0, 10.0)), - ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR), - ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR), - ArgsKwargs(degrees=30.0, expand=True), - ArgsKwargs(degrees=30.0, center=(0, 0)), - ArgsKwargs(degrees=30.0, fill=1), - ArgsKwargs(degrees=30.0, fill=(1, 2, 3)), - ], - removed_params=["resample"], - ), ConsistencyConfig( v2_transforms.PILToTensor, legacy_transforms.PILToTensor, @@ -514,23 +395,6 @@ def __init__( ] -def test_automatic_coverage(): - available = { - name - for name, obj in legacy_transforms.__dict__.items() - if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum) - } - - checked = {config.legacy_cls.__name__ for 
config in CONSISTENCY_CONFIGS} - - missing = available - checked - if missing: - raise AssertionError( - f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} " - f"are not checked for consistency although a legacy counterpart exists." - ) - - @pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__) def test_signature_consistency(config): legacy_params = dict(inspect.signature(config.legacy_cls).parameters) @@ -708,15 +572,9 @@ def test_call_consistency(config, args_kwargs): (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), - (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), - ( - v2_transforms.RandomAffine, - ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), - ), (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), - (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), (v2_transforms.AutoAugment, ArgsKwargs(5)), ] ], diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 1c22c002a36..2fea19e8190 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -228,26 +228,37 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type): assert functional_param == kernel_param -def _check_transform_v1_compatibility(transform, input): +def _check_transform_v1_compatibility(transform, input, rtol, atol): """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static - ``get_params`` method, is scriptable, and the scripted version can be called without error.""" - if transform._v1_transform_cls is None: + ``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version + can be called without error.""" + if type(input) is not torch.Tensor or isinstance(input, PIL.Image.Image): return - if type(input) is not torch.Tensor: + v1_transform_cls = transform._v1_transform_cls + if v1_transform_cls is None: return - if hasattr(transform._v1_transform_cls, "get_params"): - assert type(transform).get_params is transform._v1_transform_cls.get_params + if hasattr(v1_transform_cls, "get_params"): + assert type(transform).get_params is v1_transform_cls.get_params - scripted_transform = _script(transform) - with ignore_jit_no_profile_information_warning(): - scripted_transform(input) + v1_transform = v1_transform_cls(**transform._extract_params_for_v1_transform()) + + with freeze_rng_state(): + output_v2 = transform(input) + + with freeze_rng_state(): + output_v1 = v1_transform(input) + + assert_close(output_v2, output_v1, rtol=rtol, atol=atol) + if isinstance(input, PIL.Image.Image): + return + + _script(v1_transform)(input) -def check_transform(transform_cls, input, *args, **kwargs): - transform = transform_cls(*args, **kwargs) +def check_transform(transform, input, check_v1_compatibility=True): pickle.loads(pickle.dumps(transform)) output = transform(input) @@ -256,7 +267,8 @@ def check_transform(transform_cls, input, *args, **kwargs): if 
isinstance(input, datapoints.BoundingBoxes): assert output.format == input.format - _check_transform_v1_compatibility(transform, input) + if check_v1_compatibility: + _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility)) def transform_cls_to_functional(transform_cls, **transform_specific_kwargs): @@ -541,7 +553,12 @@ def test_functional_signature(self, kernel, input_type): ], ) def test_transform(self, size, device, make_input): - check_transform(transforms.Resize, make_input(self.INPUT_SIZE, device=device), size=size, antialias=True) + check_transform( + transforms.Resize(size=size, antialias=True), + make_input(self.INPUT_SIZE, device=device), + # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes + check_v1_compatibility=dict(rtol=0, atol=1), + ) def _check_output_size(self, input, output, *, size, max_size): assert tuple(F.get_size(output)) == self._compute_output_size( @@ -862,7 +879,7 @@ def test_functional_signature(self, kernel, input_type): ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): - check_transform(transforms.RandomHorizontalFlip, make_input(device=device), p=1) + check_transform(transforms.RandomHorizontalFlip(p=1), make_input(device=device)) @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] @@ -1032,7 +1049,7 @@ def test_functional_signature(self, kernel, input_type): def test_transform(self, make_input, device): input = make_input(device=device) - check_transform(transforms.RandomAffine, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES) + check_transform(transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), input) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) @@ -1312,7 +1329,7 @@ def test_functional_signature(self, kernel, input_type): ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): - check_transform(transforms.RandomVerticalFlip, make_input(device=device), p=1) + check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) def test_image_correctness(self, fn): @@ -1455,7 +1472,7 @@ def test_functional_signature(self, kernel, input_type): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform(self, make_input, device): check_transform( - transforms.RandomRotation, make_input(device=device), **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES + transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES), make_input(device=device) ) @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) @@ -1752,7 +1769,7 @@ def test_transform(self, make_input, input_dtype, output_dtype, device, scale, a input = make_input(dtype=input_dtype, device=device) if as_dict: output_dtype = {type(input): output_dtype} - check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale) + check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input) def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False): input_dtype = image.dtype @@ -2441,7 +2458,12 @@ def test_displacement_error(self, make_input): @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)]) @pytest.mark.parametrize("device", 
cpu_and_cuda()) def test_transform(self, make_input, size, device): - check_transform(transforms.ElasticTransform, make_input(size, device=device)) + check_transform( + transforms.ElasticTransform(), + make_input(size, device=device), + # We updated gaussian blur kernel generation with a faster and numerically more stable version + check_v1_compatibility=dict(rtol=0, atol=1), + ) class TestToPureTensor: diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 11f869103b0..64103f5834e 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -17,7 +17,7 @@ class Datapoint(torch.Tensor): You probably don't want to use this class unless you're defining your own custom Datapoints. See - :ref:`sphx_glr_auto_examples_v2_transforms_plot_custom_datapoints.py` for details. + :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for details. """ @staticmethod diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index a9bad8f9bf7..a5c98382540 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -216,7 +216,7 @@ class MixUp(_BaseMixUpCutMix): .. note:: This transform is meant to be used on **batches** of samples, not individual images. See - :ref:`sphx_glr_auto_examples_v2_transforms_plot_cutmix_mixup.py` for detailed usage + :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage examples. The sample pairing is deterministic and done by matching consecutive samples in the batch, so the batch needs to be shuffled (this is an @@ -266,7 +266,7 @@ class CutMix(_BaseMixUpCutMix): .. note:: This transform is meant to be used on **batches** of samples, not individual images. See - :ref:`sphx_glr_auto_examples_v2_transforms_plot_cutmix_mixup.py` for detailed usage + :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage examples. The sample pairing is deterministic and done by matching consecutive samples in the batch, so the batch needs to be shuffled (this is an diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index dd77816462a..5a907121b92 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -69,7 +69,7 @@ def _name_to_functional(name): def register_kernel(functional, datapoint_cls): """[BETA] Decorate a kernel to register it for a functional and a (custom) datapoint type. - See :ref:`sphx_glr_auto_examples_v2_transforms_plot_custom_datapoints.py` for usage + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for usage details. """ if isinstance(functional, str):
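
A quick way to sanity-check the new gallery ordering outside of Sphinx: sphinx-gallery instantiates the class passed as ``within_subsection_order`` once per example directory and then uses that instance as the ``key=`` of ``sorted()``. The sketch below assumes ``CustomGalleryExampleSortKey`` from the ``docs/source/conf.py`` hunk above has been copied into a scratch script; the ``src_dir`` value and the filenames are illustrative only.

```python
# Mirrors how sphinx-gallery drives the key: one instance per source
# directory, used as the key function for sorted().
key = CustomGalleryExampleSortKey(src_dir="gallery/transforms")

filenames = [
    "plot_datapoints.py",
    "plot_cutmix_mixup.py",
    "plot_transforms_getting_started.py",
]
print(sorted(filenames, key=key))
# ['plot_transforms_getting_started.py', 'plot_cutmix_mixup.py', 'plot_datapoints.py']

# For any other subsection the key falls back to plain alphabetical order,
# and an unlisted file under gallery/transforms raises the ValueError above.
```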
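
On the docs side, the ``transforms.rst`` hunk keeps the note that CutMix and MixUp operate on whole batches, either after the DataLoader or inside a collation function. A minimal sketch of the first option, using the real ``torchvision.transforms.v2`` classes but a made-up stand-in dataset, batch size, and class count:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import v2

NUM_CLASSES = 10  # hypothetical

# Stand-in for a real dataset: 32 random "images" with integer labels.
dataset = TensorDataset(torch.rand(32, 3, 224, 224), torch.randint(0, NUM_CLASSES, (32,)))
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Randomly pick CutMix or MixUp for each batch, as the renamed
# plot_cutmix_mixup.py example does.
cutmix_or_mixup = v2.RandomChoice([v2.CutMix(num_classes=NUM_CLASSES), v2.MixUp(num_classes=NUM_CLASSES)])

for images, labels in loader:
    # Applied *after* the DataLoader, i.e. on an entire batch at once.
    images, labels = cutmix_or_mixup(images, labels)
    # labels are now soft targets of shape (batch_size, NUM_CLASSES).
    print(images.shape, labels.shape)
```

The alternative mentioned in the docs is to make the same call inside a custom ``collate_fn`` passed to the DataLoader, so the mixing happens in the data-loading workers instead of the training loop.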