
fix/feat: Move convolution core to impl + add feature (FX converter refactor) #1972

Merged (1 commit) Jun 30, 2023

Conversation

gs-olive (Collaborator) commented Jun 2, 2023

Description

  • Centralize the convolution implementation in FX across all source IR variants, including support for conv1d, quantized convolution, and other configurations (a hedged sketch of this centralization appears after this description)
  • Update reference implementations across the stack to use the centralized utility and remove the individually replicated implementations
  • Allow conv layers to take bias inputs in FX, per new functionality from TRT
  • Enable pass-through of build errors in Dynamo e2e tests so that errors are not hidden (this PR fixes a bug which prevented that pass-through)

Fixes #1954
Addresses first bug in #1565
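
For illustration only, a minimal sketch of what a single shared convolution helper might look like. This is not the code added by the PR: the function name, signature, and the ITensor-bias handling are assumptions, and the real implementation uses the repo's own helpers (to_numpy, get_trt_tensor, TRTTensor) quoted in the review thread below.

```python
from typing import Optional, Sequence, Union

import numpy as np
import tensorrt as trt
import torch


def convolution_impl(
    network: trt.INetworkDefinition,
    name: str,
    input_val: trt.ITensor,
    weight: Union[torch.Tensor, np.ndarray],
    bias: Optional[Union[torch.Tensor, np.ndarray, trt.ITensor]] = None,
    stride: Sequence[int] = (1, 1),
    padding: Sequence[int] = (0, 0),
    dilation: Sequence[int] = (1, 1),
    groups: int = 1,
) -> trt.ITensor:
    """Hedged sketch of a centralized N-dimensional convolution converter.

    Conv1d handling (unsqueezing inputs/weights to 2D) is omitted for brevity.
    """
    # Constant weights are handed to TRT as NumPy arrays.
    if isinstance(weight, torch.Tensor):
        weight = weight.detach().cpu().numpy()

    # Bias may be a constant (folded into the layer) or, per the new TRT
    # functionality this PR enables, a runtime ITensor.
    bias_const = None
    bias_tensor = None
    if isinstance(bias, torch.Tensor):
        bias_const = bias.detach().cpu().numpy()
    elif isinstance(bias, np.ndarray):
        bias_const = bias
    elif isinstance(bias, trt.ITensor):
        bias_tensor = bias

    layer = network.add_convolution_nd(
        input=input_val,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=weight,
        bias=bias_const if bias_const is not None else trt.Weights(),
    )
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation
    layer.num_groups = groups
    layer.name = name

    if bias_tensor is not None:
        # Assumption: an ITensor bias is attached as a layer input rather than
        # a constant; consult the TensorRT docs for the correct input index.
        layer.set_input(2, bias_tensor)

    return layer.get_output(0)
```

Each per-IR converter (aten, acc, nn) would then only marshal its arguments and call the shared helper.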

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that the relevant reviewers are notified

@gs-olive gs-olive self-assigned this Jun 2, 2023
@github-actions github-actions bot requested a review from yinghai June 2, 2023 16:19
@gs-olive gs-olive added the `component: dynamo` and `Story: Dynamo Compile Improvements` labels Jun 2, 2023
@gs-olive gs-olive requested review from narendasan and frank-wei and removed request for yinghai June 2, 2023 16:29
@github-actions github-actions bot requested a review from yinghai June 2, 2023 16:29
@gs-olive gs-olive added the `WIP` (work is in progress, pull request should not be merged yet) label Jun 4, 2023
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch 7 times, most recently from 5348ac2 to 075a028 Compare June 5, 2023 02:07
@gs-olive gs-olive removed the `WIP` label Jun 5, 2023
@gs-olive gs-olive requested review from wushirong and removed request for yinghai June 5, 2023 15:36
@gs-olive gs-olive changed the title fix: Allow FX convolution layers to take bias inputs fix/feat: Move convolution core to impl + add feature (FX converter refactor) Jun 5, 2023
@gs-olive gs-olive requested a review from apbose June 5, 2023 20:17
Comment on lines +60 to +64
```python
# Process bias terms
if isinstance(bias, torch.Tensor):
    # Transform the bias constant into a Numpy array
    bias = to_numpy(bias)

elif isinstance(bias, TRTTensor):
    bias = get_trt_tensor(network, bias, f"{name}_bias")
```

gs-olive (Collaborator, Author):

I did not add an unsqueeze operation to the bias term, since the TRT requirement for the bias is that it have a number of elements equal to the number of output features of the convolution; the same bias used for Conv1D would therefore work for Conv2D, as the number of output features is fixed.

Contributor:

Just to clarify, do you mean 1D or 2D conv? For 1D, we need the bias to be unsqueezed.

gs-olive (Collaborator, Author) commented Jun 9, 2023:

I initially meant all conv layers, since this documentation seems to indicate we just need the number of elements in the bias Tensor to be correct, not necessarily the dimensions; but if the bias needs to be unsqueezed for 1D, I can add that functionality back. I am wondering whether the intended unsqueeze should be in the first dimension (torch.unsqueeze(bias, 0)) or the last dimension (torch.unsqueeze(bias, -1))?

Note: I think initially, it was torch.unsqueeze(bias, 0), while the weights and inputs were unsqueezed in the last dimension
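
For reference, the two options produce different shapes but the same element count; a small illustration (the out_channels value of 6 is arbitrary and chosen only for this example):

```python
import torch

# Hypothetical Conv1d bias with out_channels = 6.
bias = torch.randn(6)                    # shape: torch.Size([6])
print(torch.unsqueeze(bias, 0).shape)    # torch.Size([1, 6])
print(torch.unsqueeze(bias, -1).shape)   # torch.Size([6, 1])
# Either way the element count stays at 6, matching the number of output
# features, which is the TRT requirement cited above.
```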

Contributor:

I am wondering if TRT does the broadcast internally, since the unit test for 1D works well even though you did not unsqueeze it.

gs-olive (Collaborator, Author):

I verified on a small sample that Conv1D with bias compiles and runs inference successfully without unsqueezing the bias term.
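
The exact sample is not included in the thread; a minimal check along these lines, assuming the FX lowering entry point torch_tensorrt.fx.compile with default lowering settings, would be:

```python
import torch
import torch_tensorrt

# Assumed minimal reproduction, not the author's script: a Conv1d with bias,
# lowered through the FX path and compared against eager PyTorch.
model = torch.nn.Conv1d(3, 6, kernel_size=3, bias=True).eval().cuda()
inputs = [torch.randn(1, 3, 32).cuda()]

trt_model = torch_tensorrt.fx.compile(model, inputs)

with torch.no_grad():
    ref = model(*inputs)
    out = trt_model(*inputs)

# The engine building without error and the outputs matching closely is the
# behavior reported in the comment above.
print(torch.max(torch.abs(ref - out)))
```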

Collaborator:

Should we have a unit test for this just to catch if TRT behavior changes?

gs-olive (Collaborator, Author) commented Jun 22, 2023:

I could add one, but it would likely be very similar to this case, which interprets, builds, and runs inference on a Conv1D model with TRT both with and without bias:

```python
class TestConvolutionConverter(AccTestCase):
    @parameterized.expand(
        [
            ("default", 1),
            param("no_bias", 1, bias=False),
            ("tuple_parameters", 1, (1), (1)),
            param("non_zero_padding", 1, padding=1),
            param("dilation", 1, dilation=2),
            param("groups", 1, groups=3),
        ]
    )
    def test_conv1d(
        self,
        _,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    3, 6, kernel_size, stride, padding, dilation, groups, bias
                )

            def forward(self, x):
                return self.conv(x)

        inputs = [torch.randn(1, 3, 32)]
        self.run_test(
            TestModule(),
            inputs,
            expected_ops={acc_ops.conv1d},
            test_explicit_precision=True,
        )
```

A breaking TRT change to that specific case should cause the accuracy check in the above test to fail.

@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch 2 times, most recently from a6b5e6c to a21d778 Compare June 21, 2023 16:36
@narendasan narendasan (Collaborator) left a comment:

Mostly organization stuff

py/torch_tensorrt/fx/converters/convolution.py (two review comments, outdated and resolved)
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch from a21d778 to 69e8d33 Compare June 22, 2023 23:14
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch from 69e8d33 to de9938e Compare June 23, 2023 03:04
@gs-olive gs-olive requested a review from narendasan June 23, 2023 03:04
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch from de9938e to b7eea6f Compare June 23, 2023 04:06
@gs-olive gs-olive added the `WIP` (work is in progress, pull request should not be merged yet) label Jun 23, 2023
@github-actions github-actions bot requested a review from wushirong June 23, 2023 04:57
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch from b7eea6f to 0f5be88 Compare June 23, 2023 05:01
@gs-olive gs-olive removed the `WIP` label Jun 23, 2023
- Centralize convolution implementation in FX, shared across all source
IRs (aten, acc, nn)
- Enable pass-through of build errors in e2e tests to ensure errors are
not being hidden
- Allow conv layers to take bias inputs in FX, per new functionality
from TRT
- Remove separate `convolution.py` file and centralize `nn` converters
to a single file
@gs-olive gs-olive force-pushed the enable_build_failures_e2e branch from 0f5be88 to 834064e Compare June 23, 2023 05:09
gs-olive (Collaborator, Author) commented Jun 26, 2023

Hi @wushirong, thank you for the review. I was wondering if you could have another look at the changes: in response to review comments from @narendasan, I've moved the convolution.py contents to nn_ops_converters.py for code cleanliness and organization, and removed unused imports.

Labels
cla signed, component: api [Python], component: dynamo, component: fx, fx, Story: Dynamo Compile Improvements

Projects
None yet

Development
Successfully merging this pull request may close these issues:

✨[Feature] + 🐛 [Bug] Allow ITensor biases in aten.convolution converters

5 participants