Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

promote Mixup and Cutmix from prototype to transforms v2 #7731

Merged
merged 22 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll validate the changes made to the references once this is merged.



def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
Expand Down Expand Up @@ -218,18 +218,17 @@ def main(args):
val_dir = os.path.join(args.data_path, "val")
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

collate_fn = None
num_classes = len(dataset.classes)
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:

def collate_fn(batch):
return mixupcutmix(*default_collate(batch))
return mixup_cutmix(*default_collate(batch))

else:
collate_fn = default_collate

data_loader = torch.utils.data.DataLoader(
dataset,
Expand Down
23 changes: 23 additions & 0 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,33 @@
from typing import Tuple

import torch
from presets import get_module
from torch import Tensor
from torchvision.transforms import functional as F


def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
    """Build the batch-level Mixup/Cutmix transform used by the training script.

    Args:
        mixup_alpha: Alpha passed to the Mixup transform. Mixup is only
            enabled when this value is strictly positive.
        cutmix_alpha: Alpha passed to the Cutmix transform. Cutmix is only
            enabled when this value is strictly positive.
        num_categories: Number of classes, forwarded to each transform.
        use_v2: When truthy, the ``Mixup`` / ``Cutmix`` classes of the module
            returned by ``get_module(use_v2)`` are used; otherwise the
            ``RandomMixup`` / ``RandomCutmix`` classes of this file are used
            with ``p=1.0``.

    Returns:
        ``None`` when neither transform is enabled, otherwise a
        ``RandomChoice`` (taken from the selected transforms module) over the
        enabled transforms.
    """
    transforms_module = get_module(use_v2)

    mixup_cutmix = []
    if mixup_alpha > 0:
        mixup_cutmix.append(
            transforms_module.Mixup(alpha=mixup_alpha, num_categories=num_categories)
            if use_v2
            else RandomMixup(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
        )
    if cutmix_alpha > 0:
        # Cutmix must be parameterized by its own alpha. Previously mixup_alpha
        # was forwarded here, so cutmix_alpha only acted as an on/off switch.
        mixup_cutmix.append(
            transforms_module.Cutmix(alpha=cutmix_alpha, num_categories=num_categories)
            if use_v2
            else RandomCutmix(num_classes=num_categories, p=1.0, alpha=cutmix_alpha)
        )
    if not mixup_cutmix:
        return None

    return transforms_module.RandomChoice(mixup_cutmix)


class RandomMixup(torch.nn.Module):
"""Randomly apply Mixup to the provided batch and targets.
The class implements the data augmentations as described in the paper
Expand Down
143 changes: 143 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
make_image_tensor,
make_segmentation_mask,
make_video,
needs_cuda,
set_rng_seed,
)
from torch.testing import assert_close
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints
from torchvision.datasets import FakeData

from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
Expand Down Expand Up @@ -1634,3 +1637,143 @@ def test_transform_negative_degrees_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCutMixMixUp:
    """Tests for the v2 ``Cutmix`` and ``Mixup`` batch transforms."""

    # TODO: Does it work when labels are already dirichlet-distributed? like does Compose([Mixup(), CutMix()]) work?
    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup, "Compose"])
    @pytest.mark.parametrize("one_hot", [True, False])
    def test_supported_input_structure(self, T, one_hot):
        """Check the transforms on unpacked, packed and collate_fn-style inputs."""

        num_categories = 10
        batch_size = 32
        H, W = 12, 12

        preproc = transforms.Compose([transforms.PILToTensor(), transforms.ToDtype(torch.float32)])
        if one_hot:

            class ToOneHot(torch.nn.Module):
                def forward(self, inpt):
                    img, label = inpt
                    return img, torch.nn.functional.one_hot(label, num_classes=num_categories)

            preproc = transforms.Compose([preproc, ToOneHot()])

        dataset = FakeData(size=batch_size, image_size=(3, H, W), num_classes=num_categories, transforms=preproc)
        if T == "Compose":
            cutmix = transforms.Cutmix(alpha=0.5, num_categories=num_categories)
            mixup = transforms.Mixup(alpha=0.5, num_categories=num_categories)
            cutmix_mixup = transforms.Compose([cutmix, mixup])
            expected_num_non_zero_labels = 3  # see common_checks
        else:
            cutmix_mixup = T(alpha=0.5, num_categories=num_categories)
            expected_num_non_zero_labels = 2  # see common_checks

        # The conditional expression binds looser than ``==``, so it has to be
        # parenthesized; otherwise the index-label case asserts a non-empty
        # tuple and can never fail.
        expected_input_target_shape = (batch_size, num_categories) if one_hot else (batch_size,)

        dl = DataLoader(dataset, batch_size=batch_size)

        # Input sanity checks
        img, target = next(iter(dl))
        assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
        assert target.shape == expected_input_target_shape

        def check_output(img, target):
            assert img.shape == (batch_size, 3, H, W)
            assert target.shape == (batch_size, num_categories)
            torch.testing.assert_close(target.sum(axis=-1), torch.ones(batch_size))
            # Below we check the number of non-zero values in the target tensor.
            # When just CutMix() (or just MixUp()) is called, we should expect 2
            # non-zero label values per sample. Although, it may happen that
            # only 1 non-zero value is present, basically if the transform had
            # no effect. Here we make sure that:
            # - there is at least one sample with 2 non-zero values
            # - there is no sample with more than 2 non-zero values
            # When CutMix() and MixUp() are called in sequence together, we
            # should expect 3 instead of 2. That's the
            # expected_num_non_zero_labels threshold.
            num_non_zero_values = (target != 0).sum(axis=-1)
            assert (num_non_zero_values == expected_num_non_zero_labels).any()
            assert (num_non_zero_values <= expected_num_non_zero_labels).all()
            assert (num_non_zero_values > 0).all()  # Note: we already know that from target.sum(axis=-1) check above

        # After Dataloader, as unpacked input
        img, target = next(iter(dl))
        assert target.shape == expected_input_target_shape
        img, target = cutmix_mixup(img, target)
        check_output(img, target)

        # After Dataloader, as packed input
        packed_from_dl = next(iter(dl))
        assert isinstance(packed_from_dl, list)
        img, target = cutmix_mixup(packed_from_dl)
        check_output(img, target)

        # As collation function. We expect default_collate to be used by users.
        def collate_fn_1(batch):
            return cutmix_mixup(default_collate(batch))

        def collate_fn_2(batch):
            return cutmix_mixup(*default_collate(batch))

        for collate_fn in (collate_fn_1, collate_fn_2):
            dl = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
            img, target = next(iter(dl))
            check_output(img, target)

    @needs_cuda
    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_cpu_vs_gpu(self, T):
        """Run the CUDA-vs-CPU consistency helper on each transform."""
        num_categories = 10
        batch_size = 3
        H, W = 12, 12

        imgs = torch.rand(batch_size, 3, H, W).to("cuda")
        labels = torch.randint(0, num_categories, (batch_size,)).to("cuda")
        cutmix_mixup = T(alpha=0.5, num_categories=num_categories)

        _check_kernel_cuda_vs_cpu(cutmix_mixup, input=(imgs, labels), rtol=None, atol=None)

    @pytest.mark.parametrize("T", [transforms.Cutmix, transforms.Mixup])
    def test_error(self, T):
        """Check that unsupported or inconsistent inputs raise ``ValueError``."""

        num_categories = 10
        batch_size = 9

        # imgs = torch.randint(0, 256, (batch_size, 3, 12, 12), dtype=torch.uint8)
        imgs = torch.rand(batch_size, 3, 12, 12)

        for input_with_bad_type in (
            F.to_pil_image(imgs[0]),
            datapoints.Mask(torch.rand(12, 12)),
            datapoints.BoundingBox(torch.rand(2, 4), format="XYXY", spatial_size=12),
        ):
            with pytest.raises(ValueError, match="does not support PIL images, "):
                T(alpha=0.5)(input_with_bad_type)

        with pytest.raises(ValueError, match="Could not infer where the labels are"):
            T(alpha=0.5)({"img": imgs, "Nothing_else": 3})

        with pytest.raises(ValueError, match="labels should be index based"):
            # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
            # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
            T(alpha=0.5)(imgs)

        with pytest.raises(ValueError, match="When using the default labels_getter"):
            T(alpha=0.5)(imgs, "not_a_tensor")

        with pytest.raises(ValueError, match="When passing 2D labels"):
            wrong_num_categories = num_categories + 1
            T(alpha=0.5, num_categories=num_categories)(
                imgs, torch.randint(0, 2, size=(batch_size, wrong_num_categories))
            )

        with pytest.raises(ValueError, match="but got a tensor of shape"):
            T(alpha=0.5)(imgs, torch.randint(0, 2, size=(2, 3, 4)))

        with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
            T(alpha=0.5)(imgs[None, None], torch.randint(0, num_categories, size=(batch_size,)))

        with pytest.raises(ValueError, match="does not match the batch size of the labels"):
            T(alpha=0.5)(imgs, torch.randint(0, num_categories, size=(batch_size + 1,)))

        with pytest.raises(ValueError, match="num_categories must be passed"):
            T(alpha=0.5)(imgs, torch.randint(0, num_categories, size=(batch_size,)))
6 changes: 5 additions & 1 deletion torchvision/datasets/fakedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def __init__(
image_size: Tuple[int, int, int] = (3, 224, 224),
num_classes: int = 10,
transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
random_offset: int = 0,
) -> None:
super().__init__(None, transform=transform, target_transform=target_transform) # type: ignore[arg-type]
super().__init__(None, transform=transform, transforms=transforms, target_transform=target_transform) # type: ignore[arg-type]
self.size = size
self.num_classes = num_classes
self.image_size = image_size
Expand Down Expand Up @@ -60,6 +61,9 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target # We don't want to call item() on arbitrarily transformed targets

return img, target.item()

Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._transform import Transform # usort: skip

from ._augment import RandomErasing
from ._augment import Cutmix, Mixup, RandomErasing
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Expand Down
Loading