Skip to content

Commit

Permalink
Add support for aten.pixel_unshuffle dynamo converter
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Mar 23, 2024
1 parent 7f14221 commit 0398f48
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,29 @@ def aten_ops_pixel_shuffle(
)


@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_pixel_unshuffle(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.shuffle.pixel_unshuffle(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(
Expand Down
44 changes: 44 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,47 @@ def pixel_shuffle(
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)


def pixel_unshuffle(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
downscale_factor: int,
) -> TRTTensor:
shape = input.shape
in_channels, in_height, in_width = shape[-3:]
out_channels = in_channels * (downscale_factor**2)
out_height = in_height // downscale_factor
out_width = in_width // downscale_factor
new_shape = shape[:-3] + (
in_channels,
out_height,
downscale_factor,
out_width,
downscale_factor,
)
reshaped_tensor = reshape(
ctx, target, source_ir, f"{name}_reshape1", input, new_shape
)
rank = len(new_shape)
permute_shape = tuple(range(rank - 5)) + (
rank - 5, # in_channels
rank - 3, # downscale_factor
rank - 1, # downscale_factor
rank - 4, # out_height
rank - 2, # out_width
)
permuted_tensor = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
)
return reshape(
ctx,
target,
source_ir,
f"{name}_reshape2",
permuted_tensor,
shape[:-3] + (out_channels, out_height, out_width),
)
29 changes: 29 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestPixelUnshuffleConverter(DispatchTestCase):
@parameterized.expand(
[
((1, 1, 1), 1),
((1, 1, 12, 12), 3),
((2, 3, 4, 25, 30), 5),
]
)
def test_pixel_unshuffle(self, shape, downscale_factor):
class PixelUnshuffle(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.pixel_unshuffle.default(x, downscale_factor)

inputs = [torch.randn(shape)]
self.run_test(
PixelUnshuffle(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 0398f48

Please sign in to comment.