From 4239a7b352a359895d7249e959df67d64ffac1c9 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 22 May 2023 15:36:10 -0700 Subject: [PATCH] fix/feat: Add Dynamo-only converter registry - Add Dynamo converter registry which functions as a superset of the standard FX converter registry - For use with new + experimental converters - Uses custom decorator `dynamo_tensorrt_converter` - Update references within Dynamo functions to use the converter registry `DYNAMO_CONVERTERS` --- py/torch_tensorrt/dynamo/__init__.py | 5 ++++ .../dynamo/backend/lowering/_partition.py | 2 +- .../dynamo/converter_registry.py | 23 +++++++++++++++++++ .../dynamo/fx_ts_compat/fx2trt.py | 2 +- 4 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/converter_registry.py diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index ea1778edfe..4068b11c87 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,2 +1,7 @@ +from .converter_registry import ( + DYNAMO_CONVERTERS, + dynamo_tensorrt_converter, +) + from torch_tensorrt.dynamo import fx_ts_compat from .backend import compile diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 4d82bf4be5..e3308cbe03 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -10,7 +10,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.operator_support import OperatorSupport -from torch_tensorrt.fx.converter_registry import CONVERTERS +from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/converter_registry.py b/py/torch_tensorrt/dynamo/converter_registry.py new file mode 100644 index 0000000000..1a9ee03970 --- /dev/null +++ b/py/torch_tensorrt/dynamo/converter_registry.py @@ -0,0 +1,23 @@ +from typing import Any, Callable, Dict + +from torch.fx.node import Target +from torch_tensorrt.fx.converter_registry import CONVERTERS + +DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS) + + +def dynamo_tensorrt_converter( + key: Target, + enabled: bool = True, +) -> Callable[[Any], Any]: + def register_converter(converter): + DYNAMO_CONVERTERS[key] = converter + return converter + + def disable_converter(converter): + return converter + + if enabled: + return register_converter + else: + return disable_converter diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index a29cee509d..03499aaf85 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -14,7 +14,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata -from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS +from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS from .input_tensor_spec import InputTensorSpec from torch_tensorrt.fx.observer import Observer from torch_tensorrt.fx.utils import (