From 8f8f93663d7a50bf8ce1769f7e406144d01caedc Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 7 Aug 2023 15:55:58 +0100
Subject: [PATCH 1/2] move stuff out of CM

---
 torchvision/datapoints/_datapoint.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py
index 9b1c648648d..bf78c230504 100644
--- a/torchvision/datapoints/_datapoint.py
+++ b/torchvision/datapoints/_datapoint.py
@@ -37,6 +37,8 @@ def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
     _NO_WRAPPING_EXCEPTIONS = {
+        # The ops in this dict are those that should *preserve* the Datapoint
+        # type, i.e. they are exceptions to the "no wrapping" rule.
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
         torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
         torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
@@ -79,22 +81,22 @@ def __torch_function__(
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
+        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
+        if wrapper and isinstance(args[0], cls):
             # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-            if wrapper and isinstance(args[0], cls):
-                return wrapper(cls, args[0], output)
+            return wrapper(cls, args[0], output)
 
-            # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
-            # will retain the input type. Thus, we need to unwrap here.
-            if isinstance(output, cls):
-                return output.as_subclass(torch.Tensor)
+        if isinstance(output, cls):
+            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
+            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
+            return output.as_subclass(torch.Tensor)
 
-            return output
+        return output
 
     def _make_repr(self, **kwargs: Any) -> str:
         # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
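
Reviewer note: a minimal behavior sketch (not part of the patch) of what the
dispatch above implements, assuming the torchvision.datapoints API of this
branch; `datapoints.Image` stands in for any Datapoint subclass.

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 4, 4))

    # Ops in _NO_WRAPPING_EXCEPTIONS preserve the Datapoint type ...
    assert isinstance(img.clone(), datapoints.Image)
    assert isinstance(img.to(torch.float64), datapoints.Image)

    # ... while every other op unwraps to a plain tensor. Inplace ops like
    # `.add_` ignore DisableTorchFunctionSubclass, so the final isinstance
    # check unwraps the returned handle manually (img itself stays an Image).
    assert type(img + 1) is torch.Tensor
    assert type(img.add_(1)) is torch.Tensor

The point of this first patch is only to hoist the wrap/unwrap logic out of
the DisableTorchFunctionSubclass context manager (the "CM" in the subject),
so that it runs with normal subclass dispatch re-enabled; the observable
behavior above is unchanged.
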
From b1018a98fed3f4efde10dc7ac1620b318d002328 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 7 Aug 2023 16:19:03 +0100
Subject: [PATCH 2/2] Call wrap_like for all exceptions

---
 test/test_datapoints.py              |  1 -
 torchvision/datapoints/_datapoint.py | 20 ++++++--------------
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/test/test_datapoints.py b/test/test_datapoints.py
index 25a2182e050..ded9a771e14 100644
--- a/test/test_datapoints.py
+++ b/test/test_datapoints.py
@@ -203,4 +203,3 @@ def test_deepcopy(datapoint, requires_grad):
 
     assert type(datapoint_deepcopied) is type(datapoint)
     assert datapoint_deepcopied.requires_grad is requires_grad
-    assert datapoint_deepcopied.is_leaf
diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py
index bf78c230504..2faa4e3716a 100644
--- a/torchvision/datapoints/_datapoint.py
+++ b/torchvision/datapoints/_datapoint.py
@@ -36,16 +36,9 @@ def _to_tensor(
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
-    _NO_WRAPPING_EXCEPTIONS = {
-        # The ops in this dict are those that should *preserve* the Datapoint
-        # type, i.e. they are exceptions to the "no wrapping" rule.
-        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
-        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
-        # retains the type automatically
-        torch.Tensor.requires_grad_: lambda cls, input, output: output,
-    }
+    # The ops in this set are those that should *preserve* the Datapoint type,
+    # i.e. they are exceptions to the "no wrapping" rule.
+    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
 
     @classmethod
     def __torch_function__(
@@ -81,15 +74,14 @@ def __torch_function__(
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
-        if wrapper and isinstance(args[0], cls):
-            # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+            # We also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-            return wrapper(cls, args[0], output)
+            return cls.wrap_like(args[0], output)
 
         if isinstance(output, cls):
             # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
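
Reviewer note: for reference, a self-contained sketch of the pattern this
series converges on. The subclass name `MyDatapoint` is hypothetical and not
torchvision code; the sketch assumes torch >= 2.0, where
torch._C.DisableTorchFunctionSubclass is available.

    import torch
    from torch._C import DisableTorchFunctionSubclass


    class MyDatapoint(torch.Tensor):
        # Ops that should preserve the subclass type; everything else unwraps.
        _NO_WRAPPING_EXCEPTIONS = {
            torch.Tensor.clone,
            torch.Tensor.to,
            torch.Tensor.detach,
            torch.Tensor.requires_grad_,
        }

        @classmethod
        def wrap_like(cls, other, tensor):
            # `other` is unused here; subclasses with metadata would copy it over.
            return tensor.as_subclass(cls)

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # Run the op with subclass dispatch disabled, so `output` is
            # (usually) a plain tensor.
            with DisableTorchFunctionSubclass():
                output = func(*args, **(kwargs or {}))
            # Exceptions are re-wrapped, but only when the primary operand is
            # an instance of `cls` (the guard explained in the diff above).
            if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
                return cls.wrap_like(args[0], output)
            # Inplace ops ignore DisableTorchFunctionSubclass; unwrap manually.
            if isinstance(output, cls):
                return output.as_subclass(torch.Tensor)
            return output


    dp = torch.rand(2).as_subclass(MyDatapoint)
    assert type(dp.clone()) is MyDatapoint  # exception: type preserved
    assert type(dp + 1) is torch.Tensor     # default: unwrapped
    assert type(dp.mul_(2)) is torch.Tensor # inplace: manually unwrapped

Turning _NO_WRAPPING_EXCEPTIONS into a plain set works because every entry can
now go through the same wrap_like call: for `requires_grad_`, which is inplace
and already returns the subclass, the extra as_subclass is simply a no-op, so
the per-op lambdas were dead weight.
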