promote Mixup and Cutmix from prototype to transforms v2 #7731
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7731
Note: Links to docs will display an error until the docs builds have been completed. ❌ 8 New Failures as of commit 993f693.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks Philip, I've only made minor comments for now; overall this looks great. I'll give a more thorough look once we have the tests.
references/classification/train.py
Outdated
if batch_transform:
    image, target = batch_transform(image, target)
I failed to notice this when we discussed it offline, but we should keep those transforms as collate_fn: calling them after the dataloader as done here means we can't leverage multi-processing.
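To illustrate the multi-processing point, here is a minimal pure-Python sketch (no torch; the collate helper and fake_mixup are hypothetical stand-ins, not the reference code): wrapping the batch transform into the collate_fn means it runs inside the DataLoader worker processes, whereas applying it after the loader runs it serially in the main process.

```python
# Pure-Python sketch of the pattern discussed above (hypothetical names).

def default_collate(samples):
    # stand-in for torch.utils.data.default_collate on (image, label) pairs
    images, labels = zip(*samples)
    return list(images), list(labels)

def make_collate_fn(batch_transform):
    # Passing the returned function as DataLoader(collate_fn=...) runs
    # batch_transform inside the worker processes, parallelized for free.
    def collate_fn(samples):
        return batch_transform(*default_collate(samples))
    return collate_fn

# hypothetical batch transform: reverse the labels, MixUp-style pairing
def fake_mixup(images, labels):
    return images, labels[::-1]

collate_fn = make_collate_fn(fake_mixup)
images, labels = collate_fn([("img0", 0), ("img1", 1)])
```

The same transform applied after iterating the loader would produce the same batches, just without the worker-side parallelism.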
references/classification/train.py
Outdated
from torchvision.transforms.functional import InterpolationMode
from transforms import get_batch_transform
Sooooo, to avoid bikeshedding on how we should call those (batch transforms vs pairwise transforms vs something else), maybe we should just rename that to get_cutmix_mixup?
msg = "Couldn't find a label in the inputs."
if self.labels_getter == "default":
    msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
Maybe we can write that entire message regardless of whether "default" was passed. It would simplify the logic a bit and avoid storing self.labels_getter.
msg = "Couldn't find a label in the inputs."
if self.labels_getter == "default":
    msg = f"{msg} To overwrite the default find behavior, pass a callable for labels_getter."
raise RuntimeError(msg)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically this could qualify as a ValueError as well?
# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor, but coming
# after an image or video. However, since we want to handle them in _transform, we set their entry to True:
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
We used a different strategy in SanitizeBoundingBox, where we called _transform() on all inputs and just handled that filtering logic within _transform(). I don't have a preference right now (haven't thought about it much). But maybe we should align both transforms to follow the same strategy? (We could do it in another PR.)
We can't use the same strategy here. SanitizeBoundingBox does not affect images or videos, so we don't care about needs_transform_list there. However, here we transform images, meaning we need to use needs_transform_list to make use of the heuristic about which image to transform. This cannot be done in _transform, since in there we have no concept of whether an image should be transformed or not.
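For readers following along, the mechanism under discussion can be sketched like this (a simplified pure-Python stand-in with hypothetical values, not the actual torchvision code): needs_transform_list is a boolean mask parallel to the flattened inputs that says which entries _transform() should touch.

```python
# Simplified stand-in for the needs_transform_list idea (hypothetical data).
flat_inputs = ["image_batch", "labels_tensor", "metadata_string"]
labels = flat_inputs[1]

# Heuristic default: only the first tensor-like entry is transformed; a
# plain tensor coming after an image/video is passed through (False).
needs_transform_list = [True, False, False]

# CutMix/MixUp force-enable the labels entry so _transform() receives it too.
needs_transform_list[
    next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)
] = True
```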
Shouldn't we transform all images (each image is collated as (N, C, H, W))?
I think I understand what you mean: we'd need to re-implement the "tensor pass-through heuristic" in _transform() if we were to do something like in SanitizeBoundingBox(), and we don't want to do that. I feel like we could use the same strategy used here in SanitizeBoundingBox() though. But that's OK.
Shouldn't we transform all images (each image is collated as (N, C, H, W)) ?
We are transforming all images, yes.
I feel like we could use the same strategy used here in SanitizeBoundingBox() though. But that's OK.
Yes, we can certainly also use needs_transform_list there. I'm OK with that. Up to you.
@@ -93,3 +96,61 @@ def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
I wonder if this method could be a class that fine-tunes itself on the input type after the first iteration, e.g. skipping the tuple check if the batch is a dict and caching the label key? Another idea could be to provide predefined labels_getters for these two situations...
Good point, it'd be interesting to figure out whether this makes things faster. It might be best to leave this out as a future improvement though, to keep this PR simpler.
One thing to note: doing this would tie the transform instance to a specific dataset [format]. IDK whether this is a problem in practice, but it's worth keeping in mind.
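A rough sketch of the "fine-tune itself after the first iteration" idea (hypothetical code, not part of the PR): resolve the lookup strategy once on the first call, then reuse the cached fast path.

```python
# Hypothetical self-specializing labels getter: inspects the input type on
# the first call, then caches a fast lookup for all later calls.

class CachingLabelsGetter:
    def __init__(self):
        self._getter = None  # resolved lazily on the first call

    def _resolve(self, inputs):
        if isinstance(inputs, dict):
            key = next(k for k in inputs if "label" in k.lower())
            return lambda d: d[key]  # dict path: remember the key
        if isinstance(inputs, (tuple, list)):
            return lambda seq: seq[-1]  # sequence path: last element
        raise RuntimeError("Couldn't find a label in the inputs.")

    def __call__(self, inputs):
        if self._getter is None:
            self._getter = self._resolve(inputs)
        return self._getter(inputs)

getter = CachingLabelsGetter()
first = getter({"image": "img", "labels": [0, 1]})    # resolves on a dict
second = getter({"image": "img2", "labels": [2, 3]})  # reuses the cached key
```

Note the caveat raised above: once resolved, the instance is tied to that input format, so feeding it a tuple afterwards would break.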
Quick update: after chatting with @pmeier, we decided to remove support for 2D labels. There isn't a strong need for it, considering CutMix and MixUp should probably never be called consecutively - it's either one or the other. Should users request this to be supported for whatever reason (maybe Compose(CutMix(), CutMix()) makes sense??), we can add this back. Support was removed in 9f4a9e6.
torchvision/transforms/v2/_utils.py
Outdated
contains no "label-like" key.
"""
# TODO: Document list and why
Flag
Sooo I've decided not to document it in the code because this would probably just add more confusion. But the reason we need to add support for list is that this is what the DataLoader actually returns (in its most default setting):

for x in DataLoader(...):
    # x is a list [img_batch, labels_batch] and we want to support
    CutMix(x)
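A runnable pure-Python stand-in (hypothetical helper names, no torch) for why the list case matters: the default collation produces a list of batched fields, so any "find the labels" logic has to treat lists the same as tuples.

```python
# Stand-in: the default DataLoader collation turns a batch of
# (image, label) samples into a *list* [img_batch, labels_batch].

def default_collate(samples):
    # like torch.utils.data.default_collate here: returns a list, not a tuple
    return [list(field) for field in zip(*samples)]

batch = default_collate([("img0", 3), ("img1", 7)])

# A labels heuristic therefore has to accept lists as well as tuples:
def find_labels(inputs):
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
        return inputs[1]
    raise RuntimeError("Couldn't find a label in the inputs.")

labels = find_labels(batch)
```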
the key whose value corresponds to the labels. It can also be a callable that takes the same input
as the transform, and returns the labels.
By default, this will try to find a "labels" key in the input (case-insensitive), if
I chose not to document the "labels" key matching in finer detail. We can revisit (or point users to check the code?).
How to use Cutmix and Mixup
===========================

TODO
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
from transforms import get_mixup_cutmix
I'll validate the changes made to the references once this is merged.
transforms.SanitizeBoundingBox(labels_getter=12)

with pytest.raises(ValueError, match="Could not infer where the labels are"):
    bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
    transforms.SanitizeBoundingBox()(bad_labels_key)

with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
Note that I had to delete this. I feel like it was a valid error to raise. We could put it back if we were to have different labels_getter logic for SanitizeBBox and the Cutmix/Mixup ones (which is probably going to be needed eventually anyway).
This is OK for now.
There are conflicts to resolve.
They're trivial, I'll address them at the next (and hopefully last) review.
LGTM, thanks Nicolas for finishing this. I can't approve though, since it is technically my PR.
@@ -261,6 +261,22 @@ The new transform can be used standalone or mixed-and-matched with existing transforms
AugMix
v2.AugMix

Cutmix - Mixup
Nit: technically, the paper names of these techniques are CutMix and MixUp.
Sorry, saw this after I merged. I'll address via #7766
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
TL;DR: this PR promotes Mixup and Cutmix from torchvision.prototype.transforms to torchvision.transforms.v2 by reusing the labels_getter functionality that we have for SanitizeBoundingBoxes.

To achieve this, the following things are implemented here:

- Factor out the static labels_getter methods from SanitizeBoundingBoxes. While doing that, we also change the functionality slightly to make the handling a little easier: labels_getter
, since that is just slightly more convenient than passing a callable directly, while making the handling harder for us.
- Remove the p parameter. We have this parameter in our references, since we have based them on a research implementation (vision/references/classification/transforms.py, line 22 in 08c9938). However, by design, a research implementation will have more knobs than a stable library. In fact, we are hardcoding the parameter in our references (vision/references/classification/train.py, lines 224 to 227 in 08c9938). By removing the p parameter in this PR, we get the same behavior that we currently have in our references. If users need the more flexible behavior back, they can always wrap the transform like RandomApply(Mixup(...), p=...).
- Since the existence of the p parameter was the reason to prefix "Random" before the "canonical" names Mixup and Cutmix, I've dropped the prefix here as well.
- Follow-up to "Add --use-v2 support to classification references" (#7724). The implementation of this PR will be available in the classification references with --use-v2. I also refactored the training script to just use the transform inside the training loop rather than putting it inside the collate_fn. This makes it clearer that this transform needs batching, but is otherwise independent of the data loader.

ToDo
- SanitizeBoundingBoxes
- Mixup and cutmix
- RandomMixup and RandomCutmix

cc @vfdev-5
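The RandomApply(Mixup(...), p=...) wrapping mentioned in the description can be sketched in plain Python like this (a hypothetical stand-in, not the torchvision implementation): apply the batch transform with probability p, otherwise return the batch unchanged.

```python
import random

# Hypothetical stand-in for the RandomApply wrapping discussed above.
def random_apply(transform, p, rng=None):
    rng = rng or random.Random()
    def wrapper(batch):
        # random() is uniform in [0.0, 1.0), so p=1.0 always applies
        # the transform and p=0.0 never does.
        if rng.random() < p:
            return transform(batch)
        return batch
    return wrapper

always = random_apply(lambda b: "mixed", p=1.0)  # always applies
never = random_apply(lambda b: "mixed", p=0.0)   # never applies
```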