Add tuto for custom transforms and custom datapoints in gallery example #7795
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7795
Note: Links to docs will display an error until the docs builds have been completed.
✅ 3 Unrelated Failures
As of commit 7be27e8: BROKEN TRUNK - the following jobs failed but were also present on the merge base bf03f4e.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
#
# TODO: Move this in another guide - this is user-facing, not dev-facing.
```
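For context, a minimal sketch of the user-facing answer as I understand it, using the v2 dataset wrapper; the dataset choice and paths are placeholders, not taken from this PR:

```python
import torchvision
from torchvision.datasets import wrap_dataset_for_transforms_v2

# Built-in datasets return PIL images and plain dicts/tensors by default.
# The v2 wrapper re-packages their samples as datapoints so that
# transforms v2 can dispatch on the types.
dataset = torchvision.datasets.CocoDetection(
    root="path/to/images",  # placeholder path
    annFile="path/to/annotations.json",  # placeholder path
)
dataset = wrap_dataset_for_transforms_v2(dataset)

img, target = dataset[0]
# img is now a datapoints.Image, and the target entries (e.g. "boxes")
# are datapoints such as BoundingBoxes.
```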
Will address eventually and open an issue so we don't forget. Just LMK if this is OK in principle.
Maybe:

vision/gallery/plot_transforms_v2.py, line 3 in 9ebf10a: "Getting started with transforms v2"
```diff
@@ -25,9 +32,13 @@ def _to_tensor(
     requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
     return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
 
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
```
I ended up moving `_wrap` and `wrap_like` into the base class. They were implemented in exactly the same way for all datapoints except for `BoundingBoxes`, because that one requires metadata.

I think it's OK to assume that these are good default implementations for most datapoints, i.e. those that don't require metadata. What can go wrong anyway?

It makes the example much easier to write and explain, and more importantly it makes the Datapoint authoring experience much smoother. So I think we should do that. I can do it in another PR if you prefer, though.
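For readers of this thread, here's a minimal sketch of the shape this takes; the `BoundingBoxes` override and its attributes are simplified for illustration, not the real implementation:

```python
from typing import Tuple, Type, TypeVar

import torch

D = TypeVar("D", bound="Datapoint")


class Datapoint(torch.Tensor):
    @classmethod
    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
        # Re-type a plain tensor as this datapoint subclass (no copy).
        return tensor.as_subclass(cls)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        # Good default for datapoints without metadata: `other` is unused.
        return cls._wrap(tensor)


class BoundingBoxes(Datapoint):
    # Simplified metadata for illustration.
    format: str
    canvas_size: Tuple[int, int]

    @classmethod
    def wrap_like(cls, other, tensor, *, format=None, canvas_size=None):
        # Metadata comes from `other` unless explicitly overridden.
        wrapped = tensor.as_subclass(cls)
        wrapped.format = format if format is not None else other.format
        wrapped.canvas_size = canvas_size if canvas_size is not None else other.canvas_size
        return wrapped
```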
I'm ok leaving it here.
```diff
 @classmethod
 def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-    raise NotImplementedError
+    return cls._wrap(tensor)
```
Wait, so now I'm a bit confused. We clearly don't use `other`, so we should remove it from there. BUT then:

- Don't we want to allow `wrap_like(other, t)` so that `t` gets the same device (and dtype??) as `other`? We don't seem to do that.
- Do we even need `wrap_like` for those datapoints that don't have metadata? What's the point of doing `Image.wrap_like(other, t)` or even `Image.wrap_like(t)` (not supported) when one can just do `Image(t)`?
Anyway, I'd suggest keeping things as-is for now, and addressing that stuff after this PR is merged.
The reason we have `wrap_like` the way it is, is that we use it inside `__torch_function__`:

vision/torchvision/datapoints/_datapoint.py, lines 32 to 39 in 9ebf10a:
```python
_NO_WRAPPING_EXCEPTIONS = {
    torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
    torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
    torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
    # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
    # retains the type automatically
    torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
```
Basically, we need a generic way to say: "Here is a datapoint input, possibly with some metadata. Given a plain tensor, make it the same type as the datapoint and copy the metadata if available."

So the reason we have `other` in all of them is that, if we need metadata and nothing is passed explicitly, we need to take it from `other`. This also means that on datapoints that have no metadata, we never use `other`.
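To make the effect concrete, a small usage sketch (assuming the 0.16-era `torchvision.datapoints` namespace and constructor arguments):

```python
import torch
from torchvision import datapoints

img = datapoints.Image(torch.rand(3, 32, 32))

# The wrap_like entries above keep these outputs typed as Image
# instead of decaying to plain torch.Tensor:
assert isinstance(img.clone(), datapoints.Image)
assert isinstance(img.to(torch.float64), datapoints.Image)

# For metadata-carrying datapoints, wrap_like copies the metadata
# from `other` onto the freshly wrapped tensor:
boxes = datapoints.BoundingBoxes(
    [[0, 0, 10, 10]],
    format="XYXY",
    canvas_size=(32, 32),
)
shifted = datapoints.BoundingBoxes.wrap_like(boxes, boxes + 1)
assert shifted.format == boxes.format
assert shifted.canvas_size == boxes.canvas_size
```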
OK. I understand this is useful for us internally, but I'm still questioning whether we should expose it as public API for datapoints that don't have metadata. Let's just keep things as-is in this PR so we can merge it soon, and revisit right after.
Looking good, thanks Nicolas! I left a bunch of comments.
Thanks Philip, I addressed most comments or left some thoughts on the rest. LMK.
target["bboxes"] = datapoints.BoundingBoxes( | ||
bboxes, |
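For readers skimming this thread, a self-contained sketch of what this suggestion amounts to; the box format, sizes, and values here are placeholders, not taken from the tutorial:

```python
import torch
from torchvision import datapoints

# Placeholder stand-ins for values the tutorial derives from the dataset.
bboxes = torch.tensor([[10, 10, 50, 50], [20, 30, 60, 80]])
canvas_size = (224, 224)  # (height, width) of the underlying image

target = {
    "bboxes": datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=canvas_size),
    "labels": torch.tensor([1, 2]),
}
```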
I'm not sure; I think we want to be consistent with the names used in this tutorial. What's being used in another tutorial is a lower priority when it comes to consistency.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Thanks Nicolas, awesome to have better docs now!
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
…ery example (#7795)

Summary:
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

Reviewed By: matteobettini

Differential Revision: D48642308

fbshipit-source-id: efe98a8405393c0bf0ca84141b895265d93636d0
Bunch of docs.