Skip to content

Commit

Permalink
promote Mixup and Cutmix from prototype to transforms v2 (#7731)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nicolashug@meta.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
  • Loading branch information
3 people authored Jul 28, 2023
1 parent 8071c17 commit 3591371
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 105 deletions.
16 changes: 16 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
v2.AugMix

Cutmix - Mixup
--------------

Cutmix and Mixup are special transforms that
are meant to be used on batches rather than on individual images, because they
are combining pairs of images together. These can be used after the dataloader,
or part of a collation function. See
:ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

.. autosummary::
:toctree: generated/
:template: class.rst

v2.Cutmix
v2.Mixup

.. _functional_transforms:

Functional Transforms
Expand Down
8 changes: 8 additions & 0 deletions gallery/plot_cutmix_mixup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

"""
===========================
How to use Cutmix and Mixup
===========================
TODO
"""
19 changes: 9 additions & 10 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
Expand Down Expand Up @@ -218,18 +218,17 @@ def main(args):
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

collate_fn = None
num_classes = len(dataset.classes)
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:

def collate_fn(batch):
return mixupcutmix(*default_collate(batch))
return mixup_cutmix(*default_collate(batch))

else:
collate_fn = default_collate

data_loader = torch.utils.data.DataLoader(
dataset,
Expand Down
23 changes: 23 additions & 0 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,33 @@
from typing import Tuple

import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F


def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
transforms_module = get_module(use_v2)

mixup_cutmix = []
if mixup_alpha > 0:
mixup_cutmix.append(
transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
if use_v2
else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.Cutmix(alpha=mixup_alpha, num_categories=num_categories)
if use_v2
else RandomCutmix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
)
if not mixup_cutmix:
return None

return transforms_module.RandomChoice(mixup_cutmix)


class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
Expand Down
26 changes: 2 additions & 24 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,9 +1558,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):

Expand Down Expand Up @@ -1648,22 +1646,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10)
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBox(
Expand All @@ -1674,17 +1656,13 @@ def test_sanitize_bounding_boxes_errors():

with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
transforms.SanitizeBoundingBox(labels_getter=12)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBox()(bad_labels_key)

with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBox()(not_a_dict)

with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBox()(not_a_tensor)
Expand Down
148 changes: 146 additions & 2 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
assert_no_warnings,
cache,
cpu_and_cuda,
freeze_rng_state,
ignore_jit_no_profile_information_warning,
make_bounding_box,
make_detection_mask,
Expand All @@ -25,12 +26,14 @@
make_image_tensor,
make_segmentation_mask,
make_video,
needs_cuda,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints

from torchvision.transforms._functional_tensor import _max_value as get_max_value
Expand Down Expand Up @@ -61,8 +64,10 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
input_cuda = input.as_subclass(torch.Tensor)
input_cpu = input_cuda.to("cpu")

actual = kernel(input_cuda, *args, **kwargs)
expected = kernel(input_cpu, *args, **kwargs)
with freeze_rng_state():
actual = kernel(input_cuda, *args, **kwargs)
with freeze_rng_state():
expected = kernel(input_cpu, *args, **kwargs)

assert_close(actual, expected, check_device=False, rtol=rtol, atol=atol)

Expand Down Expand Up @@ -1892,3 +1897,142 @@ def test_errors_warnings(self, make_input):
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype


class TestCutMixMixUp:
class DummyDataset:
def __init__(self, size, num_classes):
self.size = size
self.num_classes = num_classes
assert size < num_classes

def __getitem__(self, idx):
img = torch.rand(3, 100, 100)
label = idx # This ensures all labels in a batch are unique and makes testing easier
return img, label

def __len__(self):
return self.size

@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
def test_supported_input_structure(self, T):

batch_size = 32
num_classes = 100

dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)

cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

dl = DataLoader(dataset, batch_size=batch_size)

# Input sanity checks
img, target = next(iter(dl))
input_img_size = img.shape[-3:]
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
assert target.shape == (batch_size,)

def check_output(img, target):
assert img.shape == (batch_size, *input_img_size)
assert target.shape == (batch_size, num_classes)
torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
num_non_zero_labels = (target != 0).sum(axis=-1)
assert (num_non_zero_labels == 2).all()

# After Dataloader, as unpacked input
img, target = next(iter(dl))
assert target.shape == (batch_size,)
img, target = cutmix_mixup(img, target)
check_output(img, target)

# After Dataloader, as packed input
packed_from_dl = next(iter(dl))
assert isinstance(packed_from_dl, list)
img, target = cutmix_mixup(packed_from_dl)
check_output(img, target)

# As collation function. We expect default_collate to be used by users.
def collate_fn_1(batch):
return cutmix_mixup(default_collate(batch))

def collate_fn_2(batch):
return cutmix_mixup(*default_collate(batch))

for collate_fn in (collate_fn_1, collate_fn_2):
dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
img, target = next(iter(dl))
check_output(img, target)

@needs_cuda
@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
def test_cpu_vs_gpu(self, T):
num_classes = 10
batch_size = 3
H, W = 12, 12

imgs = torch.rand(batch_size, 3, H, W)
labels = torch.randint(0, num_classes, (batch_size,))
cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

_check_kernel_cuda_vs_cpu(cutmix_mixup, imgs, labels, rtol=None, atol=None)

@pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
def test_error(self, T):

num_classes = 10
batch_size = 9

imgs = torch.rand(batch_size, 3, 12, 12)
cutmix_mixup = T(alpha=0.5, num_classes=num_classes)

for input_with_bad_type in (
F.to_pil_image(imgs[0]),
datapoints.Mask(torch.rand(12, 12)),
datapoints.BoundingBox(torch.rand(2, 4), format="XYXY", spatial_size=12),
):
with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
cutmix_mixup({"img": imgs, "Nothing_else": 3})

with pytest.raises(ValueError, match="labels tensor should be of shape"):
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
cutmix_mixup(imgs)

with pytest.raises(ValueError, match="When using the default labels_getter"):
cutmix_mixup(imgs, "not_a_tensor")

with pytest.raises(ValueError, match="labels tensor should be of shape"):
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))

with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

with pytest.raises(ValueError, match="does not match the batch size of the labels"):
cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

with pytest.raises(ValueError, match="labels tensor should be of shape"):
# The purpose of this check is more about documenting the current
# behaviour of what happens on a Compose(), rather than actually
# asserting the expected behaviour. We may support Compose() in the
# future, e.g. for 2 consecutive CutMix?
labels = torch.randint(0, num_classes, size=(batch_size,))
transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
@pytest.mark.parametrize("sample_type", (tuple, list, dict))
def test_labels_getter_default_heuristic(key, sample_type):
labels = torch.arange(10)
sample = {key: labels, "another_key": "whatever"}
if sample_type is not dict:
sample = sample_type((None, sample, "whatever_again"))
assert transforms._utils._find_labels_default_heuristic(sample) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms._utils._find_labels_default_heuristic(d) is labels
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import RandomErasing
from ._augment import Cutmix, Mixup, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Expand Down
Loading

0 comments on commit 3591371

Please sign in to comment.