feat: support aten.pixel_shuffle dynamo converter #2596

Merged
zewenli98 merged 1 commit into pytorch:main from pixel_shuffle_dynamo_converter
Jan 24, 2024

Conversation

zewenli98
Collaborator

Description

Support aten.pixel_shuffle dynamo converter.

Fixes #2594

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added the labels component: tests, component: conversion, component: converters, component: api [Python], and component: dynamo Jan 15, 2024
@github-actions github-actions bot requested a review from narendasan January 15, 2024 16:58
@zewenli98 zewenli98 self-assigned this Jan 16, 2024
@zewenli98 zewenli98 force-pushed the pixel_shuffle_dynamo_converter branch from 004810f to e2c7081 Compare January 21, 2024 15:14
Collaborator

@gs-olive gs-olive left a comment

Looks good overall - just asked a clarifying question about the implementation details.

Comment on lines +33 to +54
shape = input.shape
in_channels, in_height, in_width = shape[-3:]
out_channels = in_channels // (upscale_factor**2)
out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
new_shape = shape[:-3] + (
    out_channels,
    upscale_factor,
    upscale_factor,
    in_height,
    in_width,
)
reshaped_tensor = reshape(
    ctx, target, source_ir, f"{name}_reshape1", input, new_shape
)
rank = len(shape)
permute_shape = list(range(rank))
permute_shape.insert(-2, rank)
permute_shape.insert(-1, rank + 1)
permuted_tensor = impl.permutation.permute(
    ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
)
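
To make the two `insert` calls in the excerpt easier to follow, here is a small plain-Python sketch that rebuilds only the permutation order for a given input rank. The helper name `pixel_shuffle_permute_order` is hypothetical and not part of the PR; the three list operations are copied from the excerpt.

```python
def pixel_shuffle_permute_order(rank):
    """Rebuild the permutation order from the excerpt for an input of this rank.

    After the first reshape the tensor has rank + 2 dims:
    (..., out_channels, r, r, in_height, in_width).
    Indices rank and rank + 1 are therefore in_height and in_width.
    """
    order = list(range(rank))
    order.insert(-2, rank)      # move in_height in front of the first r
    order.insert(-1, rank + 1)  # move in_width in front of the second r
    return order


# Unbatched input (C, H, W): reshaped dims are (oc, r, r, H, W)
print(pixel_shuffle_permute_order(3))  # [0, 3, 1, 4, 2] -> (oc, H, r, W, r)

# Batched input (N, C, H, W): reshaped dims are (N, oc, r, r, H, W)
print(pixel_shuffle_permute_order(4))  # [0, 1, 4, 2, 5, 3] -> (N, oc, H, r, W, r)
```

In both cases each spatial dim ends up directly before its upscale factor, so a final reshape can merge them into `H * r` and `W * r`.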
Collaborator

What is the reason for the intermediate reshape and permute here?

Collaborator Author

Thanks for the comment. A direct reshape would give the correct output shape but the wrong values, because pixel_shuffle interleaves elements from different channels instead of keeping them in their original flat order. For example:

>>> x = torch.randn((8,2,3))
>>> x.reshape((2,4,6))
tensor([[[-1.4204, -0.4205, -1.3309, -1.1576, -1.8777, -1.3462],
         [-0.5689, -0.1234, -0.5276,  1.2325,  0.2859,  0.4005],
         [-0.7908, -0.4946, -0.7183,  0.2497, -0.6588, -1.0771],
         [ 0.5446, -0.0980,  0.9309, -2.9004,  1.9834, -0.2377]],

        [[ 1.3769,  0.5741, -0.3463,  0.6038, -0.9376,  1.1402],
         [-0.1754,  0.4850, -3.5597, -0.5911,  1.7931, -1.7492],
         [ 0.9871, -0.2294,  0.7445, -0.0991,  0.0278,  0.6699],
         [-0.1543, -1.4414, -0.6795, -0.0403,  0.4620, -1.2007]]])
>>> torch.nn.functional.pixel_shuffle(x, upscale_factor=2)
tensor([[[-1.4204, -0.5689, -0.4205, -0.1234, -1.3309, -0.5276],
         [-0.7908,  0.5446, -0.4946, -0.0980, -0.7183,  0.9309],
         [-1.1576,  1.2325, -1.8777,  0.2859, -1.3462,  0.4005],
         [ 0.2497, -2.9004, -0.6588,  1.9834, -1.0771, -0.2377]],

        [[ 1.3769, -0.1754,  0.5741,  0.4850, -0.3463, -3.5597],
         [ 0.9871, -0.1543, -0.2294, -1.4414,  0.7445, -0.6795],
         [ 0.6038, -0.5911, -0.9376,  1.7931,  1.1402, -1.7492],
         [-0.0991, -0.0403,  0.0278,  0.4620,  0.6699, -1.2007]]])
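
The same point can be checked without TensorRT. The sketch below uses NumPy as a stand-in (an assumption; the PR itself works on TensorRT tensors) and covers only the unbatched `(C, H, W)` case. Both function names are hypothetical. It compares the reshape, permute, reshape sequence against an index-by-index definition of pixel shuffle, and against a plain reshape.

```python
import numpy as np


def pixel_shuffle_reference(x, r):
    """Direct definition: out[c, h*r + i, w*r + j] = x[c*r*r + i*r + j, h, w]."""
    c_in, h, w = x.shape
    c_out = c_in // (r * r)
    out = np.empty((c_out, h * r, w * r), dtype=x.dtype)
    for c in range(c_out):
        for i in range(r):
            for j in range(r):
                # Input channel (c, i, j) fills every r-th row and column,
                # starting at offset (i, j).
                out[c, i::r, j::r] = x[c * r * r + i * r + j]
    return out


def pixel_shuffle_via_permute(x, r):
    """Reshape -> permute -> reshape, mirroring the steps in the converter excerpt."""
    c_in, h, w = x.shape
    c_out = c_in // (r * r)
    y = x.reshape(c_out, r, r, h, w)
    y = y.transpose(0, 3, 1, 4, 2)  # (c_out, h, r, w, r)
    return y.reshape(c_out, h * r, w * r)


x = np.arange(8 * 2 * 3, dtype=np.float32).reshape(8, 2, 3)
naive = x.reshape(2, 4, 6)
shuffled = pixel_shuffle_via_permute(x, 2)
reference = pixel_shuffle_reference(x, 2)

print(naive.shape == shuffled.shape)         # True: the shapes agree
print(np.array_equal(shuffled, reference))   # True: permute path matches the definition
print(np.array_equal(naive, shuffled))       # False: a plain reshape misplaces values
```

Using `np.arange` as input makes the misplacement easy to see: the plain reshape keeps the values in counting order, while the shuffled result alternates between neighbouring input channels along each row.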

Collaborator

Thanks for the example - that's very helpful

@zewenli98 zewenli98 requested a review from gs-olive January 23, 2024 14:37
Collaborator

@gs-olive gs-olive left a comment

Looks good to me

@zewenli98 zewenli98 merged commit 593ff44 into pytorch:main Jan 24, 2024
21 of 23 checks passed
Labels
cla signed, component: api [Python], component: conversion, component: converters, component: dynamo, component: tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

aten.pixel_shuffle
3 participants