Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aten.pixel_shuffle dynamo converter #2596

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,6 +2278,29 @@ def aten_ops_reshape(
)


@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pixel_shuffle(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.pixel_shuffle(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(
Expand Down
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Sequence, Union

import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR
Expand All @@ -19,3 +20,43 @@ def reshape(
layer.reshape_dims = tuple(shape)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def pixel_shuffle(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
upscale_factor: int,
) -> TRTTensor:
shape = input.shape
in_channels, in_height, in_width = shape[-3:]
out_channels = in_channels // (upscale_factor**2)
out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
new_shape = shape[:-3] + (
out_channels,
upscale_factor,
upscale_factor,
in_height,
in_width,
)
reshaped_tensor = reshape(
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
)
rank = len(shape)
permute_shape = list(range(rank))
permute_shape.insert(-2, rank)
permute_shape.insert(-1, rank + 1)
permuted_tensor = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
)
Comment on lines +33 to +54
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for the intermediate reshape and permute here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. This is because if we directly use reshape, the output shape will be correct, but the values won't. For example:

>>> x = torch.randn((8,2,3))
>>> x.reshape((2,4,6))
tensor([[[-1.4204, -0.4205, -1.3309, -1.1576, -1.8777, -1.3462],
         [-0.5689, -0.1234, -0.5276,  1.2325,  0.2859,  0.4005],
         [-0.7908, -0.4946, -0.7183,  0.2497, -0.6588, -1.0771],
         [ 0.5446, -0.0980,  0.9309, -2.9004,  1.9834, -0.2377]],

        [[ 1.3769,  0.5741, -0.3463,  0.6038, -0.9376,  1.1402],
         [-0.1754,  0.4850, -3.5597, -0.5911,  1.7931, -1.7492],
         [ 0.9871, -0.2294,  0.7445, -0.0991,  0.0278,  0.6699],
         [-0.1543, -1.4414, -0.6795, -0.0403,  0.4620, -1.2007]]])
>>> torch.nn.functional.pixel_shuffle(x, upscale_factor=2)
tensor([[[-1.4204, -0.5689, -0.4205, -0.1234, -1.3309, -0.5276],
         [-0.7908,  0.5446, -0.4946, -0.0980, -0.7183,  0.9309],
         [-1.1576,  1.2325, -1.8777,  0.2859, -1.3462,  0.4005],
         [ 0.2497, -2.9004, -0.6588,  1.9834, -1.0771, -0.2377]],

        [[ 1.3769, -0.1754,  0.5741,  0.4850, -0.3463, -3.5597],
         [ 0.9871, -0.1543, -0.2294, -1.4414,  0.7445, -0.6795],
         [ 0.6038, -0.5911, -0.9376,  1.7931,  1.1402, -1.7492],
         [-0.0991, -0.0403,  0.0278,  0.4620,  0.6699, -1.2007]]])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the example - that's very helpful

return reshape(
ctx,
target,
source_ir,
f"{name}_reshape2",
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)
31 changes: 31 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_shuffle_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestPixelShuffleConverter(DispatchTestCase):
@parameterized.expand(
[
((1, 1, 1), 1),
((12, 3, 4), 2),
((1, 9, 4, 4), 3),
((2, 32, 2, 3), 4),
((1, 10, 36, 2, 4), 6),
]
)
def test_pixel_shuffle(self, shape, upscale_factor):
class PixelShuffle(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)

inputs = [torch.randn(shape)]
self.run_test(
PixelShuffle(),
inputs,
)


if __name__ == "__main__":
run_tests()
Loading