Simplify _NO_WRAPPING_EXCEPTIONS #7806

Merged (3 commits) on Aug 10, 2023
1 change: 0 additions & 1 deletion test/test_datapoints.py
@@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad):
 
     assert type(datapoint_deepcopied) is type(datapoint)
     assert datapoint_deepcopied.requires_grad is requires_grad
-    assert datapoint_deepcopied.is_leaf
Member Author:
the deepcopy isn't a leaf anymore because it went through wrap_like(), so it's got an "ancestor".
I don't think is_leaf is part of the deepcopy contract anyway? I don't think we really need to enforce this.

Collaborator:

> the deepcopy isn't a leaf anymore because it went through wrap_like(), so it's got an "ancestor". I don't think is_leaf is part of the deepcopy contract anyway? I don't think we really need to enforce this.

I don't think it is specified anywhere, so I'm ok with removing this check. Might be surprising to users though if they bank on this. Let's find out though 🤷
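For reference, a minimal sketch of the behavior under discussion, assuming `wrap_like` reduces to `Tensor.as_subclass` as in the diff below. Per the PyTorch docs, `as_subclass` keeps its output attached to the autograd graph, so the copy picks up a `grad_fn` and is no longer a leaf:

```python
import torch

t = torch.rand(2, requires_grad=True)
assert t.is_leaf  # tensors created directly by the user are leaves

# `as_subclass` (what `wrap_like` boils down to) keeps its output attached
# to the autograd graph, so the result has a grad_fn and is not a leaf.
s = t.as_subclass(torch.Tensor)
assert s.grad_fn is not None
assert not s.is_leaf
```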

28 changes: 11 additions & 17 deletions torchvision/datapoints/_datapoint.py
@@ -36,14 +36,9 @@ def _to_tensor(
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
-    _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
-        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
-        # retains the type automatically
-        torch.Tensor.requires_grad_: lambda cls, input, output: output,
-    }
+    # The ops in this set are those that should *preserve* the Datapoint type,
+    # i.e. they are exceptions to the "no wrapping" rule.
+    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
 
     @classmethod
     def __torch_function__(
@@ -79,22 +74,21 @@ def __torch_function__(
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
-            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+            # We also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-            if wrapper and isinstance(args[0], cls):
-                return wrapper(cls, args[0], output)
+            return cls.wrap_like(args[0], output)
 
-            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
-            # will retain the input type. Thus, we need to unwrap here.
-            if isinstance(output, cls):
-                return output.as_subclass(torch.Tensor)
+        if isinstance(output, cls):
+            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
+            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
Comment on lines +83 to +85

Member Author:

I also moved most of the content out of the DisableTorchFunctionSubclass block. The only part that needs to run inside it is the call to func; the rest can live outside.
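In other words, a sketch of the restructuring being described (not the literal diff):

```python
with DisableTorchFunctionSubclass():
    output = func(*args, **(kwargs or {}))  # only the dispatch needs the guard

# the wrapping/unwrapping decisions can all happen outside the block
```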

+            return output.as_subclass(torch.Tensor)
 
-            return output
+        return output

     def _make_repr(self, **kwargs: Any) -> str:
         # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
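To see the new dispatch logic end to end, here is a self-contained sketch that replays the same pattern on a toy subclass. The `Image` stand-in and the asserts are illustrative, not torchvision's actual classes or test suite:

```python
import torch
from torch._C import DisableTorchFunctionSubclass


class Image(torch.Tensor):
    """Toy stand-in for a Datapoint subclass (illustration only)."""

    @classmethod
    def wrap_like(cls, other, tensor):
        return tensor.as_subclass(cls)

    # Ops whose output should keep the subclass type.
    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Only the dispatch itself needs to be shielded from re-entering
        # this method; the wrapping logic below runs outside the block.
        with DisableTorchFunctionSubclass():
            output = func(*args, **(kwargs or {}))

        # Wrap only if the primary operand is an instance of `cls`; otherwise
        # `torch.rand(2).to(img)` would come back wrapped as an Image.
        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
            return cls.wrap_like(args[0], output)

        if isinstance(output, cls):
            # Inplace ops ignore DisableTorchFunctionSubclass, so their
            # output is still an Image here; unwrap it manually.
            return output.as_subclass(torch.Tensor)

        return output


img = torch.rand(3, 4, 4).as_subclass(Image)
assert type(img.clone()) is Image                    # exception: type preserved
assert type(img + 1) is torch.Tensor                 # everything else unwraps
assert type(img.add_(0)) is torch.Tensor             # inplace output unwrapped manually
assert type(torch.rand(2).to(img)) is torch.Tensor   # guard on args[0]
```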