-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify _NO_WRAPPING_EXCEPTIONS #7806
Conversation
@@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad): | |||
|
|||
assert type(datapoint_deepcopied) is type(datapoint) | |||
assert datapoint_deepcopied.requires_grad is requires_grad | |||
assert datapoint_deepcopied.is_leaf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the deepcopy isn't a leaf anymore because it went through wrap_like()
, so it's got an "ancestor".
I don't think is_leaf
is part of the deepcopy
contract anyway? I don't think we really need to enforce this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the deepcopy isn't a leaf anymore because it went through
wrap_like()
, so it's got an "ancestor". I don't thinkis_leaf
is part of thedeepcopy
contract anyway? I don't think we really need to enforce this.
I don't think it is specified anywhere, so I'm ok with removing this check. Might be surprising to users though if they bank on this. Let's find out though 🤷
if isinstance(output, cls): | ||
# DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`, | ||
# so for those, the output is still a Datapoint. Thus, we need to manually unwrap. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also move most of the content of the DisableTorchFunctionSubclass
out of it. The only part that matters is the call to func
, the rest can be outside.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM if CI is green. Thanks Nicolas!
@@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad): | |||
|
|||
assert type(datapoint_deepcopied) is type(datapoint) | |||
assert datapoint_deepcopied.requires_grad is requires_grad | |||
assert datapoint_deepcopied.is_leaf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the deepcopy isn't a leaf anymore because it went through
wrap_like()
, so it's got an "ancestor". I don't thinkis_leaf
is part of thedeepcopy
contract anyway? I don't think we really need to enforce this.
I don't think it is specified anywhere, so I'm ok with removing this check. Might be surprising to users though if they bank on this. Let's find out though 🤷
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
Reviewed By: matteobettini Differential Revision: D48642278 fbshipit-source-id: b74f5744ca32672d70f89dab2e8ef01b073c3be0
by calling
wrap_like
for all the ops in_NO_WRAPPING_EXCEPTIONS
insteaf of having a custom logic for each one.We don't need to save a few us on calls to
requires_grad_()
.cc @vfdev-5