
promote Mixup and Cutmix from prototype to transforms v2 #7731

Merged Jul 28, 2023 (22 commits)
16 changes: 16 additions & 0 deletions docs/source/transforms.rst
@@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
v2.AugMix

Cutmix - Mixup
--------------

Collaborator: Nit: technically, paper names of these techniques are CutMix and MixUp.

Member: Sorry, saw this after I merged. I'll address via #7766

Cutmix and Mixup are special transforms that
are meant to be used on batches rather than on individual images, because they
combine pairs of images within a batch. They can be used after the dataloader,
or as part of a collation function. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

.. autosummary::
:toctree: generated/
:template: class.rst

v2.Cutmix
v2.Mixup
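
For illustration, a minimal sketch of the standalone usage described above. The class names (v2.Cutmix, v2.Mixup) and the alpha/num_classes arguments follow the tests added in this PR; the dataset and sizes are made up:

    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import v2

    NUM_CLASSES = 100

    # Toy dataset of (image tensor, integer label) pairs.
    dataset = [(torch.rand(3, 224, 224), i % NUM_CLASSES) for i in range(64)]
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    mixup = v2.Mixup(alpha=0.2, num_classes=NUM_CLASSES)

    for images, labels in loader:
        # Applied on the whole batch: pairs of images are blended and the integer
        # labels become soft labels of shape (batch_size, NUM_CLASSES).
        images, labels = mixup(images, labels)
        assert labels.shape == (8, NUM_CLASSES)
        break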

.. _functional_transforms:

Functional Transforms
8 changes: 8 additions & 0 deletions gallery/plot_cutmix_mixup.py
@@ -0,0 +1,8 @@

"""
===========================
How to use Cutmix and Mixup
===========================

TODO
"""
19 changes: 9 additions & 10 deletions references/classification/train.py
@@ -8,12 +8,12 @@
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
Member: I'll validate the changes made to the references once this is merged.



def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@@ -218,18 +218,17 @@ def main(args):
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

collate_fn = None
num_classes = len(dataset.classes)
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:

def collate_fn(batch):
return mixupcutmix(*default_collate(batch))
return mixup_cutmix(*default_collate(batch))

else:
collate_fn = default_collate

data_loader = torch.utils.data.DataLoader(
dataset,
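To make the collation wiring above concrete, a self-contained sketch of the same pattern. The selection logic mirrors the get_mixup_cutmix helper added in references/classification/transforms.py in this PR; the dataset, alphas, and class count are invented for illustration:

    import torch
    from torch.utils.data import DataLoader, default_collate
    from torchvision.transforms import v2

    num_classes = 10
    dataset = [(torch.rand(3, 32, 32), i % num_classes) for i in range(100)]

    # Pick randomly between Mixup and Cutmix for every batch, as the reference
    # script does via RandomChoice.
    mixup_cutmix = v2.RandomChoice(
        [
            v2.Mixup(alpha=0.2, num_classes=num_classes),
            v2.Cutmix(alpha=1.0, num_classes=num_classes),
        ]
    )

    def collate_fn(batch):
        # default_collate stacks the samples into (images, labels); the transform
        # then mixes pairs within the batch and returns soft labels.
        return mixup_cutmix(*default_collate(batch))

    data_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
    images, labels = next(iter(data_loader))
    assert labels.shape == (16, num_classes)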
23 changes: 23 additions & 0 deletions references/classification/transforms.py
@@ -2,10 +2,33 @@
from typing import Tuple

import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F


def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
    transforms_module = get_module(use_v2)

    mixup_cutmix = []
    if mixup_alpha > 0:
        mixup_cutmix.append(
            transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
            if use_v2
            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
        )
    if cutmix_alpha > 0:
        mixup_cutmix.append(
            transforms_module.Cutmix(alpha=cutmix_alpha, num_categories=num_categories)
            if use_v2
            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
        )
    if not mixup_cutmix:
        return None

    return transforms_module.RandomChoice(mixup_cutmix)


class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
26 changes: 2 additions & 24 deletions test/test_transforms_v2.py
@@ -1558,9 +1558,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):

@@ -1648,22 +1646,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10)
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBox(
@@ -1674,17 +1656,13 @@ def test_sanitize_bounding_boxes_errors():

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
transforms.SanitizeBoundingBox(labels_getter=12)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBox()(bad_labels_key)

with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
Member: Note that I had to delete this. I feel like it was a valid error to raise. We could put it back if we were to have a different labels_getter logic for SanitizeBBox and the Cutmix/Mixup ones (which is probably going to be needed eventually anyway).

This is OK for now.

not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBox()(not_a_dict)

with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBox()(not_a_tensor)
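
As a side note on the labels_getter comment above, a small sketch of how the shared default heuristic and an explicit labels_getter interact, using SanitizeBoundingBox. The keys and box values are invented, and the v2/datapoints names are assumed to match this PR's API:

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms

    boxes = datapoints.BoundingBox(
        torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 5.0, 5.0]]),  # second box is degenerate
        format="XYXY",
        spatial_size=(32, 32),
    )

    # Default heuristic: a dict key containing "labels" (case-insensitive) is used.
    sample = {"bbox": boxes, "labels": torch.tensor([1, 2])}
    out = transforms.SanitizeBoundingBox(min_size=1)(sample)

    # Explicit labels_getter: point the transform at a differently named key.
    sample = {"bbox": boxes, "targets": torch.tensor([1, 2])}
    out = transforms.SanitizeBoundingBox(min_size=1, labels_getter=lambda inputs: inputs["targets"])(sample)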
148 changes: 146 additions & 2 deletions test/test_transforms_v2_refactored.py
@@ -17,6 +17,7 @@
assert_no_warnings,
cache,
cpu_and_cuda,
freeze_rng_state,
ignore_jit_no_profile_information_warning,
make_bounding_box,
make_detection_mask,
@@ -25,12 +26,14 @@
make_image_tensor,
make_segmentation_mask,
make_video,
needs_cuda,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints

from torchvision.transforms._functional_tensor import _max_value as get_max_value
@@ -61,8 +64,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
input_cuda = input.as_subclass(torch.Tensor)
input_cpu = input_cuda.to("cpu")

actual = kernel(input_cuda, *args, **kwargs)
expected = kernel(input_cpu, *args, **kwargs)
with freeze_rng_state():
actual = kernel(input_cuda, *args, **kwargs)
with freeze_rng_state():
expected = kernel(input_cpu, *args, **kwargs)

assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)

@@ -1892,3 +1897,142 @@ def test_errors_warnings(self, make_input):
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype


class TestCutMixMixUp:
    class DummyDataset:
        def __init__(self, size, num_classes):
            self.size = size
            self.num_classes = num_classes
            assert size < num_classes

        def __getitem__(self, idx):
            img = torch.rand(3, 100, 100)
            label = idx  # This ensures all labels in a batch are unique and makes testing easier
            return img, label

        def __len__(self):
            return self.size

    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_supported_input_structure(self, T):

        batch_size = 32
        num_classes = 100

        dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)

        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        dl = DataLoader(dataset, batch_size=batch_size)

        # Input sanity checks
        img, target = next(iter(dl))
        input_img_size = img.shape[-3:]
        assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
        assert target.shape == (batch_size,)

        def check_output(img, target):
            assert img.shape == (batch_size, *input_img_size)
            assert target.shape == (batch_size, num_classes)
            torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
            num_non_zero_labels = (target != 0).sum(axis=-1)
            assert (num_non_zero_labels == 2).all()

        # After Dataloader, as unpacked input
        img, target = next(iter(dl))
        assert target.shape == (batch_size,)
        img, target = cutmix_mixup(img, target)
        check_output(img, target)

        # After Dataloader, as packed input
        packed_from_dl = next(iter(dl))
        assert isinstance(packed_from_dl, list)
        img, target = cutmix_mixup(packed_from_dl)
        check_output(img, target)

        # As collation function. We expect default_collate to be used by users.
        def collate_fn_1(batch):
            return cutmix_mixup(default_collate(batch))

        def collate_fn_2(batch):
            return cutmix_mixup(*default_collate(batch))

        for collate_fn in (collate_fn_1, collate_fn_2):
            dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
            img, target = next(iter(dl))
            check_output(img, target)

    @needs_cuda
    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_cpu_vs_gpu(self, T):
        num_classes = 10
        batch_size = 3
        H, W = 12, 12

        imgs = torch.rand(batch_size, 3, H, W)
        labels = torch.randint(0, num_classes, (batch_size,))
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        _check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_error(self, T):

        num_classes = 10
        batch_size = 9

        imgs = torch.rand(batch_size, 3, 12, 12)
        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

        for input_with_bad_type in (
            F.to_pil_image(imgs[0]),
            datapoints.Mask(torch.rand(12, 12)),
            datapoints.BoundingBox(torch.rand(2, 4), format="XYXY", spatial_size=12),
        ):
            with pytest.raises(ValueError, match="does not support PIL images, "):
                cutmix_mixup(input_with_bad_type)

        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            cutmix_mixup({"img": imgs, "Nothing_else": 3})

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
            # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
            cutmix_mixup(imgs)

        with pytest.raises(ValueError, match="When using the default labels_getter"):
            cutmix_mixup(imgs, "not_a_tensor")

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))

        with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
            cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

        with pytest.raises(ValueError, match="does not match the batch size of the labels"):
            cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

        with pytest.raises(ValueError, match="labels tensor should be of shape"):
            # The purpose of this check is more about documenting the current
            # behaviour of what happens on a Compose(), rather than actually
            # asserting the expected behaviour. We may support Compose() in the
            # future, e.g. for 2 consecutive CutMix?
            labels = torch.randint(0, num_classes, size=(batch_size,))
            transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, list, dict))
def test_labels_getter_default_heuristic(key, sample_type):
    labels = torch.arange(10)
    sample = {key: labels, "another_key": "whatever"}
    if sample_type is not dict:
        sample = sample_type((None, sample, "whatever_again"))
    assert transforms._utils._find_labels_default_heuristic(sample) is labels

    if key.lower() != "labels":
        # If "labels" is in the dict (case-insensitive),
        # it takes precedence over other keys which would otherwise be a match
        d = {key: "something_else", "labels": labels}
        assert transforms._utils._find_labels_default_heuristic(d) is labels
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
@@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import RandomErasing
from ._augment import Cutmix, Mixup, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,