-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow register_kernel() to take dispatcher name as input #7796
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7796
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 Unrelated FailuresAs of commit 12fb8e1: NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base f3c89cc:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
import torchvision.transforms.v2.functional # noqa | ||
|
||
try: | ||
return next( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not the most efficient, we could store the mapping somewhere. I don't think we care anyway since this is just executed during registration, not when calling the dispatcher.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand the iteration here. Can't we just do
try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(...) from None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes much better!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm ok with this.
Side note: I thought about allowing the transform name as well e.g.
@register_kernel("Resize", MyDatapoint)
, but I don't think we should do that because that won't work in general anyway, as some transforms rely on more than one dispatcher.
👍 for the reason you mentioned.
import torchvision.transforms.v2.functional # noqa | ||
|
||
try: | ||
return next( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand the iteration here. Can't we just do
try:
return getattr(torchvision.transforms.v2.functional, name)
except AttributeError:
raise ValueError(...) from None
plus some minor tests for
register_kernel
.Writing the tutorials in #7795 I thought that we should allow to register kernels by name, not just by callables. Typically if I'm a dev and I just care about
Resize()
(the transform class), I should just be able to register my kernel with@register_kernel("resize", MyDatapoint)
instead of having to do@register_kernel(resize, MyDatapoint)
which forces me to import the functional dispatcher for no obvious reason.Side note: I thought about allowing the transform name as well e.g.
@register_kernel("Resize", MyDatapoint)
, but I don't think we should do that because that won't work in general anyway, as some transforms rely on more than one dispatcher.cc @vfdev-5