
promote Mixup and Cutmix from prototype to transforms v2 #7731

Merged: 22 commits, Jul 28, 2023

Conversation

@pmeier (Collaborator) commented Jul 10, 2023:

TL;DR: this PR promotes Mixup and Cutmix from torchvision.prototype.transforms to torchvision.transforms.v2 by reusing the labels_getter functionality that we have for SanitizeBoundingBoxes.

To achieve this, the following changes are implemented here:

  1. Factor out the static "labels_getter" methods from SanitizeBoundingBoxes. While doing that, we also change the functionality slightly to make the handling a little easier:

    • Remove the option of passing a string to labels_getter, since it is only slightly more convenient than passing a callable directly, while making the handling harder for us.
    • Add the ability to return labels if we find a tensor as the second element of a tuple. We need this for Mixup and Cutmix.
  2. Remove the p parameter. We have this parameter in our references, since we have based them on a research implementation:

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:

    However, by design, a research implementation will have more knobs than a stable library. In fact, we are hardcoding the parameter in our references:
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))

    By removing the p parameter in this PR, we get the same behavior that we currently have in our references. If users need the more flexible behavior back, they can always wrap the transform, e.g. RandomApply([Mixup(...)], p=...).

    Since the existence of the p parameter was the reason to prefix "Random" before the "canonical" names Mixup and Cutmix, I've dropped the prefix here as well.

  3. Follow-up to "Add --use-v2 support to classification references" #7724. The implementation of this PR will be available in the classification references with --use-v2. I also refactored the training script to just use the transform inside the training loop rather than putting it inside the collate_fn. This makes it clearer that this transform needs batching, but is otherwise independent of the data loader.

ToDo

  • Fix old tests for SanitizeBoundingBoxes
  • Write new tests for Mixup and Cutmix
  • Remove prototype transforms RandomMixup and RandomCutmix

cc @vfdev-5

@pytorch-bot bot commented Jul 10, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7731

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures

As of commit 993f693:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@NicolasHug (Member) left a comment:

Thanks Philip, I only made minor comments for now; overall this looks great. I'll give it a more thorough look once we have the tests.

Comment on lines 28 to 29:

    if batch_transform:
        image, target = batch_transform(image, target)
Member:

I failed to notice this when we discussed it offline, but we should keep those transforms as collate_fn: calling them after the dataloader as done here means we can't leverage multi-processing.

from torchvision.transforms.functional import InterpolationMode
from transforms import get_batch_transform
Member:

Sooooo to avoid bikeshedding on how we should call those (batch transforms vs pairwise transforms vs something else), maybe we should just rename that to get_cutmix_mixup?

Comment on lines 164 to 166:

    msg = "Couldn't find a label in the inputs."
    if self.labels_getter == "default":
        msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
Member:

Maybe we can write that entire message regardless of whether "default" was passed. It would simplify the logic a bit and avoid storing self.labels_getter.

    msg = "Couldn't find a label in the inputs."
    if self.labels_getter == "default":
        msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
    raise RuntimeError(msg)
Member:

Technically this could qualify as a ValueError as well?


# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
# after an image or video. However, since we want to handle them in _transform, we flip their entry to True here.
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
Member:

We used a different strategy in SanitizeBoundingBoxes where we called _transform() on all inputs and just handled the filtering logic within _transform(). I don't have a preference right now (haven't thought about it much). But maybe we should align both transforms to follow the same strategy? (We could do it in another PR.)

@pmeier (Collaborator, author):

We can't use the same strategy here. SanitizeBoundingBoxes does not affect images or videos, so we don't care about needs_transform_list there. However, here we transform images. Meaning, we need to use needs_transform_list to make use of the heuristic about which image to transform. This cannot be done in _transform since there we have no way of knowing whether an image should be transformed or not.

@vfdev-5 (Collaborator) commented Jul 11, 2023:

Shouldn't we transform all images (each image is collated as (N, C, H, W)) ?

Member:

I think I understand what you mean: we'd need to re-implement the "tensor pass-through heuristic" in _transform() if we were to do something like in SanitizeBoundingBoxes(), and we don't want to do that. I feel like we could use the same strategy used here in SanitizeBoundingBoxes() though. But that's OK.

Shouldn't we transform all images (each image is collated as (N, C, H, W)) ?

We are transforming all images yes

@pmeier (Collaborator, author):

I feel like we could use the same strategy used here in SanitizeBoundingBoxes() though. But that's OK.

Yes, we can certainly also use needs_transform_list there. I'm ok with that. Up to you.

@pmeier pmeier requested a review from vfdev-5 July 11, 2023 09:57
@@ -93,3 +96,61 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
@vfdev-5 (Collaborator) commented Jul 11, 2023:

I wonder if this method could be a class that fine-tunes itself on the input type after the first iteration: skip checking for a tuple if the batch is a dict, and cache the label key?

Another idea could be to provide predefined labels_getter for these two situations...

Member:

Good point, it'd be interesting to figure out whether this makes things faster. Might be best to leave this as a future improvement though, to keep this PR simpler.
One thing to note: doing this would tie the transform instance to a specific dataset format. I don't know whether this is a problem in practice, but worth keeping in mind.
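The idea above can be sketched in plain Python (this is illustrative, not torchvision code; the class name and matching rules are assumptions):

```python
from typing import Any, Callable, Optional

class CachingLabelsGetter:
    """Detects the batch layout once, then reuses a specialized accessor."""

    def __init__(self) -> None:
        self._getter: Optional[Callable[[Any], Any]] = None

    def _specialize(self, inputs: Any) -> Callable[[Any], Any]:
        # Dict batch: remember which key held the labels.
        if isinstance(inputs, dict):
            key = next(k for k in inputs if k.lower() == "labels")
            return lambda batch: batch[key]
        # Tuple/list batch: assume (images, labels) and take the second element.
        if isinstance(inputs, (tuple, list)):
            return lambda batch: batch[1]
        raise ValueError("Could not infer where the labels are")

    def __call__(self, inputs: Any) -> Any:
        if self._getter is None:  # pay the detection cost only on the first batch
            self._getter = self._specialize(inputs)
        return self._getter(inputs)
```

As the reply notes, the cached accessor ties the instance to one batch layout, so feeding it batches of mixed formats would silently misbehave.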

@vfdev-5 vfdev-5 requested review from NicolasHug and vfdev-5 July 11, 2023 10:45
@NicolasHug (Member):

Quick update: after chatting with @pmeier we decided to remove support for 2D labels. There isn't a strong need for it, considering CutMix and MixUp should probably never be called consecutively: it's either one or the other.

Should users request this to be supported for whatever reason (maybe Compose(CutMix(), CutMix()) makes sense?), then we can add it back.

Support was removed in 9f4a9e6.

contains no "label-like" key.

"""
# TODO: Document list and why
@pmeier (Collaborator, author):

Flag

Member:

Sooo I've decided not to document it in the code because this will probably just add more confusion. But the reason we need to add support for list is because this is what the DataLoader actually returns (in its most default setting):

for x in DataLoader(...):
    # x is a list [img_batch, labels_batch] and we want to support
    CutMix(num_classes=...)(x)

the key whose value corresponds to the labels. It can also be a callable that takes the same input
as the transform, and returns the labels.
By default, this will try to find a "labels" key in the input (case-insensitive), if
Member:

I chose not to document the "labels" key matching in the finer details. We can revisit (or point users to check the code?)

How to use Cutmix and Mixup
===========================

TODO

import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
Member:

I'll validate the changes made to the references once this is merged.

    transforms.SanitizeBoundingBox(labels_getter=12)

    with pytest.raises(ValueError, match="Could not infer where the labels are"):
        bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
        transforms.SanitizeBoundingBox()(bad_labels_key)

    with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
Member:

Note that I had to delete this. I feel like it was a valid error to raise. We could put it back if we were to have a different labels_getter logic for SanitizeBBox and the Cutmix/Mixup ones (which is probably going to be needed eventually anyway).

This is OK for now.

@NicolasHug NicolasHug marked this pull request as ready for review July 27, 2023 11:34
@NicolasHug (Member) left a comment:

I think this LGTM. Since I made a bunch of changes, it should get another round of reviews from @pmeier and @vfdev-5 before merging.

@vfdev-5 (Collaborator) commented Jul 27, 2023:

There are conflicts to resolve

@NicolasHug (Member):

They're trivial, I'll address at the next (and hopefully last) review

@pmeier (Collaborator, author) left a comment:

LGTM, thanks Nicolas for finishing this. I can't approve though, since it is technically my PR.

@@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AugMix
v2.AugMix

Cutmix - Mixup
Collaborator:

Nit: technically, the paper names of these techniques are CutMix and MixUp.

Member:

Sorry, saw this after I merged. I'll address via #7766

@NicolasHug NicolasHug merged commit 3591371 into pytorch:main Jul 28, 2023
@github-actions
Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@pmeier pmeier deleted the cutmix-mixup branch July 28, 2023 09:45
facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2023
)

Reviewed By: matteobettini

Differential Revision: D48642303

fbshipit-source-id: e1e379d7dc99fee094fb5a3f7f97e0cd1eb93028

Co-authored-by: Nicolas Hug <nicolashug@meta.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>