-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
promote Mixup and Cutmix from prototype to transforms v2 #7731
Changes from all commits
6e9eb90
de92eb6
c160ae7
1cd7c7a
7934566
d5bb664
b4e6d43
fa97d52
f3708be
50fa4d2
26f55de
e91e879
45bf28c
9f4a9e6
0505f24
4d5890d
4538c10
6542fd0
0c3b932
acc7a98
5e02675
993f693
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
""" | ||
=========================== | ||
How to use Cutmix and Mixup | ||
=========================== | ||
|
||
TODO | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,12 +8,12 @@ | |
import torch.utils.data | ||
import torchvision | ||
import torchvision.transforms | ||
import transforms | ||
import utils | ||
from sampler import RASampler | ||
from torch import nn | ||
from torch.utils.data.dataloader import default_collate | ||
from torchvision.transforms.functional import InterpolationMode | ||
from transforms import get_mixup_cutmix | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll validate the changes made to the references once this is merged. |
||
|
||
|
||
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): | ||
|
@@ -218,18 +218,17 @@ def main(args): | |
val_dir = os.path.join(args.data_path, "val") | ||
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) | ||
|
||
collate_fn = None | ||
num_classes = len(dataset.classes) | ||
mixup_transforms = [] | ||
if args.mixup_alpha > 0.0: | ||
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) | ||
if args.cutmix_alpha > 0.0: | ||
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) | ||
if mixup_transforms: | ||
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) | ||
mixup_cutmix = get_mixup_cutmix( | ||
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2 | ||
) | ||
if mixup_cutmix is not None: | ||
|
||
def collate_fn(batch): | ||
return mixupcutmix(*default_collate(batch)) | ||
return mixup_cutmix(*default_collate(batch)) | ||
|
||
else: | ||
collate_fn = default_collate | ||
|
||
data_loader = torch.utils.data.DataLoader( | ||
dataset, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1558,9 +1558,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): | |
|
||
|
||
@pytest.mark.parametrize("min_size", (1, 10)) | ||
@pytest.mark.parametrize( | ||
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) | ||
) | ||
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) | ||
@pytest.mark.parametrize("sample_type", (tuple, dict)) | ||
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): | ||
|
||
|
@@ -1648,22 +1646,6 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): | |
assert out_labels.tolist() == valid_indices | ||
|
||
|
||
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) | ||
@pytest.mark.parametrize("sample_type", (tuple, dict)) | ||
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type): | ||
labels = torch.arange(10) | ||
sample = {key: labels, "another_key": "whatever"} | ||
if sample_type is tuple: | ||
sample = (None, sample, "whatever_again") | ||
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels | ||
|
||
if key.lower() != "labels": | ||
# If "labels" is in the dict (case-insensitive), | ||
# it takes precedence over other keys which would otherwise be a match | ||
d = {key: "something_else", "labels": labels} | ||
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels | ||
|
||
|
||
def test_sanitize_bounding_boxes_errors(): | ||
|
||
good_bbox = datapoints.BoundingBox( | ||
|
@@ -1674,17 +1656,13 @@ def test_sanitize_bounding_boxes_errors(): | |
|
||
with pytest.raises(ValueError, match="min_size must be >= 1"): | ||
transforms.SanitizeBoundingBox(min_size=0) | ||
with pytest.raises(ValueError, match="labels_getter should either be a str"): | ||
with pytest.raises(ValueError, match="labels_getter should either be 'default'"): | ||
transforms.SanitizeBoundingBox(labels_getter=12) | ||
|
||
with pytest.raises(ValueError, match="Could not infer where the labels are"): | ||
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} | ||
transforms.SanitizeBoundingBox()(bad_labels_key) | ||
|
||
with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note that I had to delete this. I feel like it was a valid error to raise. We could put it back if we were to have a different […]. This is OK for now. |
||
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0])) | ||
transforms.SanitizeBoundingBox()(not_a_dict) | ||
|
||
with pytest.raises(ValueError, match="must be a tensor"): | ||
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()} | ||
transforms.SanitizeBoundingBox()(not_a_tensor) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: technically, the paper names of these techniques are CutMix and MixUp.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, saw this after I merged. I'll address via #7766