
feat: support activation dynamo converters #2254

Merged
merged 3 commits from activation_dynamo_converters into pytorch:main on Sep 1, 2023

Conversation

zewenli98 (Collaborator)

Description

Support activation dynamo converters, including relu, sigmoid, tanh, leaky_relu, elu, selu, softplus, clip, hardsigmoid.

Fixes #2201
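
These converters are exercised whenever one of these activations reaches the dynamo conversion path. A minimal usage sketch (the module and compile options here are illustrative, not part of this PR):

    import torch
    import torch_tensorrt

    class Activations(torch.nn.Module):
        def forward(self, x):
            # relu, sigmoid, and elu all route through the new converters
            return torch.nn.functional.elu(torch.relu(x).sigmoid())

    model = Activations().eval().cuda()
    inputs = [torch.randn(2, 3, device="cuda")]
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)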

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@github-actions github-actions bot added component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests labels Aug 22, 2023
@github-actions github-actions bot requested a review from gs-olive August 22, 2023 01:34
@zewenli98 zewenli98 force-pushed the activation_dynamo_converters branch from a9c002b to 4e0288a on August 22, 2023 21:01
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.normalization.layer_norm(
...
    return impl.actv.sigmoid(
Collaborator

use the full name (activation)
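
(With the full module name, the converter would read roughly as follows; the registration decorator and the impl argument order are reconstructed from the file's surrounding pattern and partly assumed:)

    @dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
    def aten_ops_sigmoid(
        network: TRTNetwork,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        # delegate to the shared activation implementation
        return impl.activation.sigmoid(network, target, SourceIR.ATEN, name, args[0])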

zewenli98 (Collaborator Author)

Done, but got error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/__init__.py", line 86, in <module>
    from torch_tensorrt._compile import *  # noqa: F403
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 13, in <module>
    from torch_tensorrt.dynamo.compile import compile as dynamo_compile
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/__init__.py", line 13, in <module>
    from .compile import compile  # noqa: F403
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/compile.py", line 13, in <module>
    from torch_tensorrt.dynamo import CompilationSettings, partitioning
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/__init__.py", line 1, in <module>
    from ._adjacency_partitioner import partition as fast_partition
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py", line 16, in <module>
    from torch_tensorrt.dynamo.conversion.converter_registry import (
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/__init__.py", line 2, in <module>
    from .aten_ops_converters import *  # noqa: F403
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 8, in <module>
    from torch_tensorrt.dynamo.conversion import impl
  File "/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/__init__.py", line 3, in <module>
    from . import (
ImportError: cannot import name 'activation' from partially initialized module 'torch_tensorrt.dynamo.conversion.impl' (most likely due to a circular import) (/opt/circleci/.pyenv/versions/3.10.9/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/impl/__init__.py)

Exited with code exit status 1
CircleCI received exit code 1

Did you come across a similar error by any chance?

gs-olive (Collaborator), Aug 23, 2023

I've seen this error before. The issue seems to be here:

from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation

It is trying to import a function from the module that is currently being initialized. I think it can be fixed by replacing it with

from .base import convert_activation
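
For reference, the error message matches Python's standard partial-initialization failure mode; a minimal standalone reproduction (file names are illustrative):

    # a.py
    import b        # starts importing b before a's own initialization finishes
    value = 1

    # b.py
    from a import value  # a is already in sys.modules, but only partially initialized
    # ImportError: cannot import name 'value' from partially initialized module 'a'
    # (most likely due to a circular import)

Running python -c "import a" reproduces the error.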

zewenli98 (Collaborator Author)

Thanks George! I modified it as you suggested, but that doesn't seem to be the problem. I got the same error.

Collaborator

Thanks for testing this - I actually tried this branch on my own machine and I'm not seeing this error with either form, from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation or from .base import convert_activation. Do you see the error locally?

zewenli98 (Collaborator Author)

Thanks for your help! No, I don't. I typically push only after all tests pass on my local machine. That's weird...

Collaborator

No problem! Definitely strange - you could try a rebase onto main to resolve the merge conflict and see if anything changes.

zewenli98 (Collaborator Author)

I rebased, but it still doesn't work (I used from .base import convert_activation). I don't think from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation is the problem, because the unary folder uses the analogous from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary and it works. 😵


    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

    layer = network.add_plugin_v2([input_val], plugin)
Collaborator

This needs to be handled in lowering instead of the plugin @peri044

@zewenli98 zewenli98 force-pushed the activation_dynamo_converters branch 3 times, most recently from f83615a to e878a8c on August 24, 2023 21:51

    plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

    layer = network.add_plugin_v2([input_val], plugin)
Collaborator

Should use a lowering pass and run it natively instead of in a plugin

gs-olive (Collaborator), Aug 25, 2023

There is currently an aten.gelu decomposition enabled on main, so this can potentially be removed.
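
(For context: lowering gelu means decomposing it into elementwise ops that already have native converters, rather than calling a TensorRT plugin. A sketch of what the exact-erf decomposition amounts to, following the standard GELU definition:)

    import math
    import torch

    def gelu_decomposition(x: torch.Tensor) -> torch.Tensor:
        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))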

gs-olive (Collaborator) left a comment:

Left a few comments on switching to functional implementations of operators. Additionally, please rebase to main to resolve merge conflicts.

operation_type = trt.ActivationType.SELU

def selu_dyn_range_fn(dyn_range):
    return (torch.nn.SELU(dyn_range[0]), torch.nn.SELU(dyn_range[1]))
Collaborator

Switch to torch.nn.functional.selu to use the functional implementation.
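
(torch.nn.SELU(dyn_range[0]) only constructs a module, silently treating the bound as its inplace argument, and never applies the activation; the functional form computes the value directly. A sketch of the fix, assuming the range bounds are tensors; the same pattern applies to the softsign, softplus, and elu comments below:)

    def selu_dyn_range_fn(dyn_range):
        # apply SELU elementwise to each dynamic-range bound
        return (
            torch.nn.functional.selu(dyn_range[0]),
            torch.nn.functional.selu(dyn_range[1]),
        )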

operation_type = trt.ActivationType.SOFTSIGN

def softsign_dyn_range_fn(dyn_range):
    return (torch.nn.Softsign(dyn_range[0]), torch.nn.Softsign(dyn_range[1]))
Collaborator

Switch to torch.nn.functional.softsign to use the functional implementation.

Comment on lines 215 to 216
        torch.nn.Softplus(dyn_range[0], beta),
        torch.nn.Softplus(dyn_range[1], beta),
Collaborator

Similarly here for torch.nn.functional.softplus


def scaled_tanh_dyn_range_fn(dyn_range):
    def scaled_tanh_fn(x):
        return alpha * torch.nn.Tanh(beta * x)
Collaborator

Similarly here: use the functional form, torch.tanh.

Comment on lines 140 to 141
        torch.nn.ELU(dyn_range[0], alpha),
        torch.nn.ELU(dyn_range[1], alpha),
Collaborator

Similarly here: torch.nn.functional.elu

zewenli98 (Collaborator Author)

@gs-olive Thanks! I made the changes and rebased.

@zewenli98 zewenli98 force-pushed the activation_dynamo_converters branch from b8d189a to 60169bb on August 29, 2023 23:55
gs-olive (Collaborator) left a comment:

The cyclic import issue may be caused by the fact that a new directory/module was created but not added to setup.py. Could you try adding "torch_tensorrt.dynamo.conversion.impl.activation" here:

TensorRT/setup.py, lines 396 to 397 at e49ef6d:

    "torch_tensorrt.dynamo.conversion.impl.unary",
    "torch_tensorrt.dynamo.lowering",

As well as "torch_tensorrt.dynamo.conversion.impl.activation": "py/torch_tensorrt/dynamo/conversion/impl/activation" here:

TensorRT/setup.py, lines 422 to 423 at e49ef6d:

    "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary",
    "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering",

from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .base import convert_activation
Collaborator

This can be changed back to from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation

@zewenli98 zewenli98 force-pushed the activation_dynamo_converters branch from 60169bb to a4722d9 on August 31, 2023 23:07
@github-actions github-actions bot added the component: build system Issues re: Build system label Aug 31, 2023
zewenli98 (Collaborator Author)

@gs-olive Oh, I didn't even know about this before. I guess that was the problem! Thanks! I updated the code.

@zewenli98 zewenli98 requested a review from gs-olive August 31, 2023 23:12
gs-olive (Collaborator)

@zewenli98 - sure, no problem! I think since the gelu converter was removed in this PR, we also need to skip those converter tests, so you can either add a pytest skip decorator or remove/comment the code.
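
(For example, something along these lines in the activation converter tests; the class, base class, and test name here are illustrative:)

    import pytest

    class TestGeLUConverter(DispatchTestCase):
        @pytest.mark.skip(reason="aten.gelu is now handled by a lowering decomposition")
        def test_gelu(self):
            ...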

lint test file

fix bugs: circular import

delete gelu

change function calls from nn.Module to nn.functional
@zewenli98 zewenli98 force-pushed the activation_dynamo_converters branch from a4722d9 to 27b2dcf on September 1, 2023 02:09
gs-olive (Collaborator) left a comment:

Looks good to me!

@gs-olive gs-olive merged commit d6a07bb into pytorch:main Sep 1, 2023
Labels
cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Development

Successfully merging this pull request may close these issues.

Exposing IActivationLayer in dynamo.conversion.impl
4 participants