From 0eae97fbc276789b537a706e1867827fd0275606 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 20 Jul 2023 10:16:27 -0700 Subject: [PATCH] feat: Dynamo refactor (#2104) Signed-off-by: Dheeraj Peri Co-authored-by: gs-olive <113141689+gs-olive@users.noreply.github.com> --- .circleci/config.yml | 47 +- py/torch_tensorrt/_Input.py | 29 +- py/torch_tensorrt/_compile.py | 73 ++- py/torch_tensorrt/dynamo/__init__.py | 5 +- .../dynamo/{backend => }/_defaults.py | 5 +- .../dynamo/{backend => }/_settings.py | 7 +- py/torch_tensorrt/dynamo/aten_tracer.py | 157 ++++++ py/torch_tensorrt/dynamo/backend/__init__.py | 168 +------ py/torch_tensorrt/dynamo/backend/backends.py | 12 +- py/torch_tensorrt/dynamo/backend/utils.py | 111 ----- .../dynamo/common_utils/__init__.py | 36 -- .../dynamo/common_utils/test_utils.py | 16 - py/torch_tensorrt/dynamo/compile.py | 171 +++++++ .../dynamo/conversion/__init__.py | 2 + .../{backend => conversion}/conversion.py | 23 +- .../trt_interpreter.py} | 120 ++--- .../dynamo/fx_ts_compat/README.md | 13 - .../dynamo/fx_ts_compat/__init__.py | 14 - .../dynamo/fx_ts_compat/input_tensor_spec.py | 181 ------- .../dynamo/fx_ts_compat/lower.py | 370 --------------- .../dynamo/fx_ts_compat/lower_setting.py | 102 ---- .../dynamo/fx_ts_compat/passes/__init__.py | 0 .../passes/lower_pass_manager_builder.py | 333 ------------- .../dynamo/fx_ts_compat/passes/pass_utils.py | 304 ------------ .../fx_ts_compat/test/core/test_input.py | 89 ---- .../test/core/test_input_tensor_spec.py | 84 ---- .../dynamo/fx_ts_compat/tools/__init__.py | 1 - .../fx_ts_compat/tools/common_fx2trt.py | 446 ------------------ .../fx_ts_compat/tools/trt_minimizer.py | 103 ---- .../dynamo/{backend => }/lowering/__init__.py | 1 + .../{backend => }/lowering/_decompositions.py | 0 py/torch_tensorrt/dynamo/lowering/_fusers.py | 72 +++ .../{backend => }/lowering/_partition.py | 4 +- .../lowering/_pre_aot_lowering.py | 0 .../lowering/substitutions/__init__.py | 0 .../lowering/substitutions/einsum.py | 2 +- .../lowering/substitutions/maxpool1d.py | 2 +- .../dynamo/runtime/_PythonTorchTRTModule.py | 256 ++++++++++ .../{ => runtime}/_TorchTensorRTModule.py | 12 +- py/torch_tensorrt/dynamo/runtime/__init__.py | 2 + py/torch_tensorrt/dynamo/utils.py | 162 +++++++ tests/core/conversion/converters/BUILD | 12 +- .../dynamo/backend}/test_backend_compiler.py | 18 +- .../py/dynamo/backend}/test_compiler_utils.py | 23 +- .../py/dynamo/backend}/test_decompositions.py | 13 +- .../py/dynamo/backend}/test_partitioning.py | 2 +- .../dynamo/backend}/test_pre_aot_lowering.py | 18 +- .../backend}/test_specialized_models.py | 18 +- .../test => tests/py/dynamo/backend}/utils.py | 8 +- .../py/dynamo/models}/conftest.py | 2 +- .../py/dynamo/models/test_models.py | 7 +- tests/py/dynamo/models/test_models_export.py | 152 ++++++ tests/py/{ => ts}/BUILD | 0 tests/py/{ => ts}/api/test_classes.py | 2 +- tests/py/{ => ts}/api/test_collections.py | 0 tests/py/{ => ts}/api/test_e2e_behavior.py | 0 tests/py/{ => ts}/api/test_embed_engines.py | 0 tests/py/{ => ts}/api/test_logging.py | 0 tests/py/{ => ts}/api/test_module_fallback.py | 2 + .../py/{ => ts}/api/test_operator_fallback.py | 3 + tests/py/{ => ts}/api/test_ts_backend.py | 0 tests/py/{ => ts}/api/utils.py | 0 tests/py/{ => ts}/hw/test_api_dla.py | 0 tests/py/{ => ts}/hw/test_multi_gpu.py | 0 tests/py/{ => ts}/hw/utils.py | 0 .../integrations/test_to_backend_api.py | 0 .../test_trt_intercompatibility.py | 0 tests/py/{ => ts}/integrations/utils.py | 0 tests/py/{ => ts}/model_test_case.py | 0 
tests/py/{ => ts}/models/custom_models.py | 0 tests/py/{ => ts}/models/test_models.py | 4 + .../test_multiple_registered_engines.py | 1 + tests/py/{ => ts}/models/utils.py | 0 .../ptq/test_ptq_dataloader_calibrator.py | 0 tests/py/{ => ts}/ptq/test_ptq_to_backend.py | 0 .../{ => ts}/ptq/test_ptq_trt_calibrator.py | 0 .../py/{ => ts}/qat/test_qat_trt_accuracy.py | 0 tests/py/{ => ts}/requirements.txt | 0 tests/py/{ => ts}/utils.py | 0 79 files changed, 1225 insertions(+), 2595 deletions(-) rename py/torch_tensorrt/dynamo/{backend => }/_defaults.py (69%) rename py/torch_tensorrt/dynamo/{backend => }/_settings.py (83%) create mode 100644 py/torch_tensorrt/dynamo/aten_tracer.py delete mode 100644 py/torch_tensorrt/dynamo/backend/utils.py delete mode 100644 py/torch_tensorrt/dynamo/common_utils/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/common_utils/test_utils.py create mode 100644 py/torch_tensorrt/dynamo/compile.py create mode 100644 py/torch_tensorrt/dynamo/conversion/__init__.py rename py/torch_tensorrt/dynamo/{backend => conversion}/conversion.py (81%) rename py/torch_tensorrt/dynamo/{fx_ts_compat/fx2trt.py => conversion/trt_interpreter.py} (74%) delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/README.md delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/lower.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py delete mode 100644 py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py rename py/torch_tensorrt/dynamo/{backend => }/lowering/__init__.py (91%) rename py/torch_tensorrt/dynamo/{backend => }/lowering/_decompositions.py (100%) create mode 100644 py/torch_tensorrt/dynamo/lowering/_fusers.py rename py/torch_tensorrt/dynamo/{backend => }/lowering/_partition.py (98%) rename py/torch_tensorrt/dynamo/{backend => }/lowering/_pre_aot_lowering.py (100%) rename py/torch_tensorrt/dynamo/{backend => }/lowering/substitutions/__init__.py (100%) rename py/torch_tensorrt/dynamo/{backend => }/lowering/substitutions/einsum.py (96%) rename py/torch_tensorrt/dynamo/{backend => }/lowering/substitutions/maxpool1d.py (98%) create mode 100644 py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py rename py/torch_tensorrt/dynamo/{ => runtime}/_TorchTensorRTModule.py (95%) create mode 100644 py/torch_tensorrt/dynamo/runtime/__init__.py create mode 100644 py/torch_tensorrt/dynamo/utils.py rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/test_backend_compiler.py (93%) rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/test_compiler_utils.py (61%) rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/test_decompositions.py (95%) rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/test_partitioning.py (98%) rename {py/torch_tensorrt/dynamo/backend/test 
=> tests/py/dynamo/backend}/test_pre_aot_lowering.py (87%) rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/test_specialized_models.py (87%) rename {py/torch_tensorrt/dynamo/backend/test => tests/py/dynamo/backend}/utils.py (96%) rename {py/torch_tensorrt/dynamo/test => tests/py/dynamo/models}/conftest.py (85%) rename py/torch_tensorrt/dynamo/test/test_dynamo_backend.py => tests/py/dynamo/models/test_models.py (96%) create mode 100644 tests/py/dynamo/models/test_models_export.py rename tests/py/{ => ts}/BUILD (100%) rename tests/py/{ => ts}/api/test_classes.py (99%) rename tests/py/{ => ts}/api/test_collections.py (100%) rename tests/py/{ => ts}/api/test_e2e_behavior.py (100%) rename tests/py/{ => ts}/api/test_embed_engines.py (100%) rename tests/py/{ => ts}/api/test_logging.py (100%) rename tests/py/{ => ts}/api/test_module_fallback.py (98%) rename tests/py/{ => ts}/api/test_operator_fallback.py (97%) rename tests/py/{ => ts}/api/test_ts_backend.py (100%) rename tests/py/{ => ts}/api/utils.py (100%) rename tests/py/{ => ts}/hw/test_api_dla.py (100%) rename tests/py/{ => ts}/hw/test_multi_gpu.py (100%) rename tests/py/{ => ts}/hw/utils.py (100%) rename tests/py/{ => ts}/integrations/test_to_backend_api.py (100%) rename tests/py/{ => ts}/integrations/test_trt_intercompatibility.py (100%) rename tests/py/{ => ts}/integrations/utils.py (100%) rename tests/py/{ => ts}/model_test_case.py (100%) rename tests/py/{ => ts}/models/custom_models.py (100%) rename tests/py/{ => ts}/models/test_models.py (98%) rename tests/py/{ => ts}/models/test_multiple_registered_engines.py (98%) rename tests/py/{ => ts}/models/utils.py (100%) rename tests/py/{ => ts}/ptq/test_ptq_dataloader_calibrator.py (100%) rename tests/py/{ => ts}/ptq/test_ptq_to_backend.py (100%) rename tests/py/{ => ts}/ptq/test_ptq_trt_calibrator.py (100%) rename tests/py/{ => ts}/qat/test_qat_trt_accuracy.py (100%) rename tests/py/{ => ts}/requirements.txt (100%) rename tests/py/{ => ts}/utils.py (100%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 06d5b5a91f..ae4261ac43 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -519,7 +519,7 @@ commands: command: | set -e mkdir -p /tmp/artifacts/test_results - cd tests/py + cd tests/py/ts/ pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/ pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/ pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/ @@ -733,50 +733,47 @@ commands: # =================== FX tests end ======================== # # =================== Dynamo tests start ======================== # - test-dynamo-fx_ts: - description: "Test the Dynamo fx_ts_compat path" + + test-dynamo-torch_compile: + description: "Test Dynamo torch_compile tests" steps: - run: - name: Run Dynamo fx_ts_compat core tests + name: Run Dynamo torch_compile tests command: | - cd py/torch_tensorrt/dynamo/fx_ts_compat/test - pushd core/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml - popd + cd tests/py/dynamo/backend/ + pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml - store_test_results: path: /tmp/artifacts - store_artifacts: path: /tmp/testlogs - test-dynamo-compile-core: - description: "Test the Dynamo compile path" + test-dynamo-models_torch_compile: + description: "Test the Dynamo models via torch_compile path" steps: - run: - name: Run Dynamo compile core tests + name: Run 
Dynamo models via torch_compile path command: | - cd py/torch_tensorrt/dynamo/backend - pushd test/ - pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml - popd + cd tests/py/dynamo/models + pip3 install timm + pip3 install transformers + pytest test_models.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir torch_compile - store_test_results: path: /tmp/artifacts - store_artifacts: path: /tmp/testlogs - test-dynamo-compile: - description: "Test the Dynamo compile path" + test-dynamo-models_torch_export: + description: "Test the Dynamo models via torch_export path" steps: - run: - name: Run Dynamo compile E2E tests + name: Run Dynamo models via torch_export path command: | - cd py/torch_tensorrt/dynamo/ - pushd test/ + cd tests/py/dynamo/models pip3 install timm pip3 install transformers - pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile - popd + pytest test_models_export.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo - store_test_results: path: /tmp/artifacts @@ -1039,9 +1036,9 @@ jobs: command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - dump-test-env - - test-dynamo-compile - - test-dynamo-compile-core - - test-dynamo-fx_ts + - test-dynamo-torch_compile + - test-dynamo-models_torch_compile + - test-dynamo-models_torch_export package-x86_64-linux: parameters: diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index e76817e041..1ea87c5a4e 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -302,47 +302,58 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: return result_domain @classmethod - def from_tensor(cls, t: torch.Tensor) -> "Input": + def from_tensor( + cls, t: torch.Tensor, disable_memory_format_check: bool = False + ) -> "Input": """ Produce a Input which contains the information of the given PyTorch tensor. Args: tensor (torch.Tensor): A PyTorch tensor. + disable_memory_format_check (bool): Whether to validate the memory formats of input tensors Returns: A Input object. """ - if not any( - [ - t.is_contiguous(memory_format=torch.contiguous_format), - t.is_contiguous(memory_format=torch.channels_last), - ] + if not ( + t.is_contiguous(memory_format=torch.contiguous_format) + or t.is_contiguous(memory_format=torch.channels_last) + or disable_memory_format_check ): raise ValueError( "Tensor does not have a supported memory format, supported formats are contiguous or channel_last" ) frmt = ( torch.contiguous_format - if t.is_contiguous(memory_format=torch.contiguous_format) + if ( + t.is_contiguous(memory_format=torch.contiguous_format) + or disable_memory_format_check + ) else torch.channels_last ) return cls(shape=t.shape, dtype=t.dtype, format=frmt) @classmethod - def from_tensors(cls, ts: torch.Tensor) -> List["Input"]: + def from_tensors( + cls, ts: torch.Tensor, disable_memory_format_check: bool = False + ) -> List["Input"]: """ Produce a list of Inputs which contain the information of all the given PyTorch tensors. Args: tensors (Iterable[torch.Tensor]): A list of PyTorch tensors. + disable_memory_format_check (bool): Whether to validate the memory formats of input tensors Returns: A list of Inputs. 
""" assert isinstance(ts, (list, tuple)) - return [cls.from_tensor(t) for t in ts] + return [ + cls.from_tensor(t, disable_memory_format_check=disable_memory_format_check) + for t in ts + ] def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor: """ diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index de0aeb5308..368c870d70 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -15,8 +15,8 @@ class _IRType(Enum): ts = 0 fx = 1 - fx_ts_compat = 2 - dynamo_compile = 3 + dynamo = 2 + torch_compile = 3 class _ModuleType(Enum): @@ -47,17 +47,17 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) ir_targets_fx = ir == "fx" - ir_targets_dynamo_compile = ir == "dynamo_compile" - ir_targets_fx_ts_compat = ir == "fx_ts_compat" + ir_targets_dynamo = ir == "dynamo" + ir_targets_torch_compile = ir == "torch_compile" if module_is_tsable and ir_targets_torchscript: return _IRType.ts elif module_is_fxable and ir_targets_fx: return _IRType.fx - elif module_is_fxable and ir_targets_fx_ts_compat: - return _IRType.fx_ts_compat - elif module_is_fxable and ir_targets_dynamo_compile: - return _IRType.dynamo_compile + elif module_is_fxable and ir_targets_dynamo: + return _IRType.dynamo + elif module_is_fxable and ir_targets_torch_compile: + return _IRType.torch_compile else: if ir == "default": # Options are listed in order of preference @@ -67,13 +67,13 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: ) return _IRType.ts elif module_is_fxable: - raise ValueError( - "Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" + logging.log( + logging.Level.Warning, + "Input graph is a torch.fx.GraphModule but the ir provided is default (ts). Please set ir=dynamo to suppress the warning.", ) - # logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx") - # return _IRType.fx + return _IRType.dynamo else: - raise ValueError("Module was provided with in an unsupported format") + raise ValueError("Module was provided in an unsupported format") else: raise ValueError("Unknown ir was requested") @@ -156,18 +156,41 @@ def compile( dynamic_batch=False, **kwargs, ) - elif target_ir == _IRType.dynamo_compile: + elif target_ir == _IRType.dynamo: + from torch_tensorrt import Device + from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device + import collections.abc + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + device = kwargs.get("device", Device._current_device()) + torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device)) + module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs) return torch_tensorrt.dynamo.compile( - module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs - ) - elif target_ir == _IRType.fx_ts_compat: - return torch_tensorrt.dynamo.fx_ts_compat.compile( - module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs + module, + inputs=inputs, + enabled_precisions=enabled_precisions, + **kwargs, ) + elif target_ir == _IRType.torch_compile: + return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") +def torch_compile(module, **kwargs): + """ + Returns a boxed model which is the output of torch.compile. + This does not compile the model to TRT. 
Execute this model on + sample inputs to compile the model to TRT. + """ + from torch_tensorrt.dynamo.backend import torch_tensorrt_backend + + boxed_fn = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs}) + + return boxed_fn + + def convert_method_to_trt_engine( module: Any, method_name: str, @@ -224,6 +247,16 @@ def convert_method_to_trt_engine( **kwargs, ) elif target_ir == _IRType.fx: - raise RuntimeError("fx is currently not supported") + raise RuntimeError( + "convert_method_to_trt_engine call is not supported for ir=fx" + ) + elif target_ir == _IRType.dynamo: + raise RuntimeError( + "convert_method_to_trt_engine call is not supported for ir=dynamo." + ) + elif target_ir == _IRType.torch_compile: + raise RuntimeError( + "convert_method_to_trt_engine call is not supported for ir=torch_compile" + ) else: raise RuntimeError("Module is an unknown format or the ir requested is unknown") diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index ecd4384155..3cd3a24b59 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -2,5 +2,6 @@ from torch_tensorrt._util import sanitized_torch_version if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): - from torch_tensorrt.dynamo import fx_ts_compat - from .backend import compile + from ._settings import * + from .compile import compile + from .aten_tracer import trace diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py similarity index 69% rename from py/torch_tensorrt/dynamo/backend/_defaults.py rename to py/torch_tensorrt/dynamo/_defaults.py index 0afbc60f8c..e55c592f4b 100644 --- a/py/torch_tensorrt/dynamo/backend/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,7 +1,6 @@ -from torch_tensorrt.fx.utils import LowerPrecision +import torch - -PRECISION = LowerPrecision.FP32 +PRECISION = torch.float32 DEBUG = False WORKSPACE_SIZE = 0 MIN_BLOCK_SIZE = 5 diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/_settings.py similarity index 83% rename from py/torch_tensorrt/dynamo/backend/_settings.py rename to py/torch_tensorrt/dynamo/_settings.py index d074a6b079..85a2693606 100644 --- a/py/torch_tensorrt/dynamo/backend/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,8 +1,7 @@ from dataclasses import dataclass, field from typing import Optional, Sequence - -from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt.dynamo.backend._defaults import ( +import torch +from torch_tensorrt.dynamo._defaults import ( PRECISION, DEBUG, WORKSPACE_SIZE, @@ -17,7 +16,7 @@ @dataclass class CompilationSettings: - precision: LowerPrecision = PRECISION + precision: torch.dtype = PRECISION debug: bool = DEBUG workspace_size: int = WORKSPACE_SIZE min_block_size: int = MIN_BLOCK_SIZE diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py new file mode 100644 index 0000000000..74c7d151ef --- /dev/null +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -0,0 +1,157 @@ +import copy +import sys +from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union +from packaging import version + +import torch +import torch._dynamo as torchdynamo + +from torch_tensorrt.fx.utils import req_torch_version +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( + compose_bmm, + compose_chunk, + compose_getitem_slice, + remove_ops, + 
replace_aten_op_with_indices, + replace_aten_reshape_alias_with_replace, + replace_builtin_ops, + replace_inplace_ops, + replace_native_layernorm_with_layernorm, + replace_transpose_mm_op_with_linear, + run_const_fold, +) +from typing_extensions import TypeAlias + +Value: TypeAlias = Union[ + Tuple["Value", ...], + List["Value"], + Dict[str, "Value"], +] + + +class DynamoConfig: + """ + Manage Exir-specific configurations of Dynamo. + """ + + def __init__( + self, + capture_scalar_outputs: bool = True, + guard_nn_modules: bool = True, + dynamic_shapes: bool = True, + specialize_int: bool = True, + verbose: bool = True, + ) -> None: + + self.capture_scalar_outputs = capture_scalar_outputs + self.guard_nn_modules = guard_nn_modules + self.dynamic_shapes = dynamic_shapes + self.specialize_int = specialize_int + self.verbose = verbose + + def activate(self) -> None: + torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs + torchdynamo.config.guard_nn_modules = self.guard_nn_modules + torchdynamo.config.dynamic_shapes = self.dynamic_shapes + torchdynamo.config.specialize_int = self.specialize_int + torchdynamo.config.verbose = self.verbose + + def deactivate(self) -> None: + torchdynamo.config.capture_scalar_outputs = True + torchdynamo.config.guard_nn_modules = True + torchdynamo.config.dynamic_shapes = True + torchdynamo.config.specialize_int = True + torchdynamo.config.verbose = True + + +@contextmanager +def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]: + config.activate() + try: + yield config + finally: + config.deactivate() + + +@contextmanager +def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]: + """ + Temporarily increase the python interpreter stack recursion limit. + This is mostly used for pickling large scale modules. + """ + default = sys.getrecursionlimit() + if limit > default: + sys.setrecursionlimit(limit) + try: + yield + finally: + sys.setrecursionlimit(default) + + +@req_torch_version("2.dev") +def dynamo_trace( + f: Callable[..., Value], + # pyre-ignore + args: Tuple[Any, ...], + aten_graph: bool, + tracing_mode: str = "real", + dynamo_config: Optional[DynamoConfig] = None, +) -> Tuple[torch.fx.GraphModule, Set]: + """ + TODO: Once we fully migrate to torchdynamo frontend, we will remove + this config option alltogether. For now, it helps with quick + experiments with playing around with TorchDynamo + """ + if dynamo_config is None: + dynamo_config = DynamoConfig() + with using_config(dynamo_config), setting_python_recursive_limit(2000): + torchdynamo.reset() + try: + return torchdynamo.export( + f, + *copy.deepcopy(args), + aten_graph=aten_graph, + tracing_mode=tracing_mode, + ) + except torchdynamo.exc.Unsupported as exc: + raise RuntimeError( + "The user code is using a feature we don't support. " + "Please try torchdynamo.explain() to get possible the reasons", + ) from exc + except Exception as exc: + raise RuntimeError( + "torchdynamo internal error occured. 
Please see above stacktrace" + ) from exc + + +@req_torch_version("2.dev") +def trace(model, inputs, **kwargs): + """ + Optimized trace with necessary passes which re-compose some ops or replace some ops + These passes should be general and functional purpose + """ + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + replace_inplace_ops, # remove it once functionalization is enabled + ] + + fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic") + print(fx_module.graph) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + + fx_module(*inputs) + + fx_module = run_const_fold(fx_module) + print(fx_module.graph) + return fx_module diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py index 38e60fce41..596ff92589 100644 --- a/py/torch_tensorrt/dynamo/backend/__init__.py +++ b/py/torch_tensorrt/dynamo/backend/__init__.py @@ -1,167 +1 @@ -import torch -import logging -import collections.abc -import torch_tensorrt -from functools import partial - -from typing import Any, Optional, Sequence -from torch_tensorrt import EngineCapability, Device -from torch_tensorrt.fx.utils import LowerPrecision - -from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device -from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend -from torch_tensorrt.dynamo.backend._defaults import ( - PRECISION, - DEBUG, - WORKSPACE_SIZE, - MIN_BLOCK_SIZE, - PASS_THROUGH_BUILD_FAILURES, - MAX_AUX_STREAMS, - VERSION_COMPATIBLE, - OPTIMIZATION_LEVEL, - USE_PYTHON_RUNTIME, -) - - -logger = logging.getLogger(__name__) - - -def compile( - gm: torch.nn.Module, - inputs: Any, - *, - device=Device._current_device(), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - refit=False, - debug=DEBUG, - capability=EngineCapability.default, - num_avg_timing_iters=1, - workspace_size=WORKSPACE_SIZE, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - min_block_size=MIN_BLOCK_SIZE, - torch_executed_ops=[], - torch_executed_modules=[], - pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES, - max_aux_streams=MAX_AUX_STREAMS, - version_compatible=VERSION_COMPATIBLE, - optimization_level=OPTIMIZATION_LEVEL, - use_python_runtime=USE_PYTHON_RUNTIME, - **kwargs, -): - if debug: - logger.setLevel(logging.DEBUG) - - logger.warn( - "The Dynamo backend is an experimental feature, for which only the " - + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, min_block_size, " - + "torch_executed_ops, pass_through_build_failures}" - ) - - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] - - inputs = prepare_inputs(inputs, prepare_device(device)) - - if not isinstance(enabled_precisions, collections.abc.Collection): - enabled_precisions = [enabled_precisions] - - # Parse user-specified enabled precisions - if ( - torch.float16 in enabled_precisions - or torch_tensorrt.dtype.half in enabled_precisions - ): - lower_precision = LowerPrecision.FP16 - elif ( - torch.float32 in enabled_precisions - or torch_tensorrt.dtype.float in 
enabled_precisions - ): - lower_precision = LowerPrecision.FP32 - elif len(enabled_precisions) == 0: - logger.info(f"No precision specified, defaulting to {PRECISION}") - lower_precision = PRECISION - else: - raise ValueError( - f"Precision {enabled_precisions} not supported in the Dynamo Path" - ) - - custom_backend = create_backend( - precision=lower_precision, - debug=debug, - workspace_size=workspace_size, - min_block_size=min_block_size, - torch_executed_ops=torch_executed_ops, - pass_through_build_failures=pass_through_build_failures, - max_aux_streams=max_aux_streams, - version_compatible=version_compatible, - optimization_level=optimization_level, - use_python_runtime=use_python_runtime, - **kwargs, - ) - - model = torch.compile(gm, backend=custom_backend) - - # Ensure compilation occurs by calling the function with provided inputs - model(*inputs) - - return model - - -from torch_tensorrt.fx.utils import LowerPrecision - -logger = logging.getLogger(__name__) - - -def create_backend( - precision: LowerPrecision = PRECISION, - debug: bool = DEBUG, - workspace_size: int = WORKSPACE_SIZE, - min_block_size: int = MIN_BLOCK_SIZE, - torch_executed_ops: Sequence[str] = set(), - pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES, - max_aux_streams: Optional[int] = MAX_AUX_STREAMS, - version_compatible: bool = VERSION_COMPATIBLE, - optimization_level: Optional[int] = OPTIMIZATION_LEVEL, - use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME, - **kwargs, -): - """Create torch.compile backend given specified arguments - - Args: - precision: Model Layer precision - debug: Whether to print out verbose debugging information - workspace_size: Workspace TRT is allowed to use for the module (0 is default) - min_block_size: Minimum number of operators per TRT-Engine Block - torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage - pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False) - max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine - version_compatible: Provide version forward-compatibility for engine plan files - optimization_level: Builder optimization 0-5, higher levels imply longer build time, - searching for more optimization options. TRT defaults to 3 - use_python_runtime: Whether to strictly use Python runtime or C++ runtime. 
To auto-select a runtime - based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the - argument as None - Returns: - Backend for torch.compile - """ - return partial( - torch_tensorrt_backend, - debug=debug, - precision=precision, - workspace_size=workspace_size, - min_block_size=min_block_size, - torch_executed_ops=torch_executed_ops, - pass_through_build_failures=pass_through_build_failures, - max_aux_streams=max_aux_streams, - version_compatible=version_compatible, - optimization_level=optimization_level, - use_python_runtime=use_python_runtime, - **kwargs, - ) +from .backends import torch_tensorrt_backend diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 49865380d4..01827af13a 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -4,19 +4,19 @@ from functools import partial import torch._dynamo as td -from torch_tensorrt.dynamo.backend._settings import CompilationSettings -from torch_tensorrt.dynamo.backend.lowering._decompositions import ( +from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( +from torch_tensorrt.dynamo.lowering._pre_aot_lowering import ( pre_aot_substitutions, ) -from torch_tensorrt.dynamo.backend.lowering._partition import ( +from torch_tensorrt.dynamo.lowering._partition import ( partition, get_submod_inputs, ) -from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs -from torch_tensorrt.dynamo.backend.conversion import convert_module +from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs +from torch_tensorrt.dynamo.conversion import convert_module from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py deleted file mode 100644 index 23a1cd4795..0000000000 --- a/py/torch_tensorrt/dynamo/backend/utils.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import logging -from dataclasses import replace, fields - -from torch_tensorrt.dynamo.backend._settings import CompilationSettings -from typing import Any, Union, Sequence, Dict -from torch_tensorrt import _Input, Device -from ..common_utils import use_python_runtime_parser - - -logger = logging.getLogger(__name__) - - -def prepare_inputs( - inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict], - device: torch.device = torch.device("cuda"), -) -> Any: - if isinstance(inputs, _Input.Input): - if isinstance(inputs.shape, dict): - return inputs.example_tensor(optimization_profile_field="opt_shape").to( - device - ) - else: - return inputs.example_tensor().to(device) - - elif isinstance(inputs, torch.Tensor): - return inputs - - elif isinstance(inputs, list): - prepared_input = list() - - for input_obj in inputs: - prepared_input.append(prepare_inputs(input_obj)) - - return prepared_input - - elif isinstance(inputs, tuple): - prepared_input = list() - - for input_obj in inputs: - prepared_input.append(prepare_inputs(input_obj)) - - return tuple(prepared_input) - - elif isinstance(inputs, dict): - prepared_input = dict() - - for key, input_obj in inputs.items(): - prepared_input[key] = prepare_inputs(input_obj) - - return prepared_input - - else: - raise ValueError( - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. 
" - + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" - ) - - -def prepare_device(device: Union[Device, torch.device]) -> torch.device: - if isinstance(device, Device): - if device.gpu_id != -1: - device = torch.device(device.gpu_id) - else: - raise ValueError("Invalid GPU ID provided for the CUDA device provided") - - elif isinstance(device, torch.device): - device = device - - else: - raise ValueError( - "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" - ) - - return device - - -def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: - """Parses the kwargs field of a Dynamo backend - - Args: - kwargs: Keyword arguments dictionary provided to the backend - Returns: - CompilationSettings object with relevant kwargs - """ - - # Initialize an empty CompilationSettings object - settings = CompilationSettings() - - # If the user specifies keyword args, overwrite those fields in settings - # Validate all specified kwargs to ensure they are true fields of the dataclass - # - # Note: kwargs provided by torch.compile are wrapped in the "options" key - if kwargs: - if "options" in kwargs and len(kwargs) == 1: - kwargs = kwargs["options"] - - valid_attrs = {attr.name for attr in fields(settings)} - valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs} - settings = replace(settings, **valid_kwargs) - - # Enable debug/verbose mode if requested - if settings.debug: - logger.setLevel(logging.DEBUG) - - # Parse input runtime specification - settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) - - logger.debug(f"Compiling with Settings:\n{settings}") - - return settings diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common_utils/__init__.py deleted file mode 100644 index de0ce0a48a..0000000000 --- a/py/torch_tensorrt/dynamo/common_utils/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging -from typing import Optional - - -logger = logging.getLogger(__name__) - - -def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool: - """Parses a user-provided input argument regarding Python runtime - - Automatically handles cases where the user has not specified a runtime (None) - - Returns True if the Python runtime should be used, False if the C++ runtime should be used - """ - using_python_runtime = use_python_runtime - reason = "" - - # Runtime was manually specified by the user - if using_python_runtime is not None: - reason = "as requested by user" - # Runtime was not manually specified by the user, automatically detect runtime - else: - try: - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule - - using_python_runtime = False - reason = "since C++ dependency was detected as present" - except ImportError: - using_python_runtime = True - reason = "since import failed, C++ dependency not installed" - - logger.info( - f"Using {'Python' if using_python_runtime else 'C++'} {reason} TRT Runtime" - ) - - return using_python_runtime diff --git a/py/torch_tensorrt/dynamo/common_utils/test_utils.py b/py/torch_tensorrt/dynamo/common_utils/test_utils.py deleted file mode 100644 index 873aed4c6b..0000000000 --- a/py/torch_tensorrt/dynamo/common_utils/test_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -COSINE_THRESHOLD = 0.99 -DECIMALS_OF_AGREEMENT = 4 - - -def cosine_similarity(gt_tensor, pred_tensor): - gt_tensor = gt_tensor.flatten().to(torch.float32) - pred_tensor = 
pred_tensor.flatten().to(torch.float32) - if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: - if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): - return 1.0 - res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) - res = res.cpu().detach().item() - - return res diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py new file mode 100644 index 0000000000..5528fe88b1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/compile.py @@ -0,0 +1,171 @@ +import torch +import logging +import collections.abc +import torch_tensorrt +from functools import partial + +from typing import Any, Optional, Sequence +from torch_tensorrt import EngineCapability, Device +from torch.fx.passes.pass_manager import PassManager +from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.splitter_base import SplitResult +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting +from torch_tensorrt.dynamo.lowering import ( + fuse_permute_linear, + fuse_permute_matmul, +) +from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device +from torch_tensorrt.dynamo.backend import torch_tensorrt_backend +from torch_tensorrt.dynamo.backend.backends import _compile_module +from torch_tensorrt.dynamo.conversion import convert_module + +from torch_tensorrt.dynamo._defaults import ( + PRECISION, + DEBUG, + WORKSPACE_SIZE, + MIN_BLOCK_SIZE, + PASS_THROUGH_BUILD_FAILURES, + MAX_AUX_STREAMS, + VERSION_COMPATIBLE, + OPTIMIZATION_LEVEL, + USE_PYTHON_RUNTIME, +) + + +logger = logging.getLogger(__name__) + + +def compile( + gm: Any, + inputs: Any, + *, + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=DEBUG, + capability=EngineCapability.default, + num_avg_timing_iters=1, + workspace_size=WORKSPACE_SIZE, + dla_sram_size=1048576, + dla_local_dram_size=1073741824, + dla_global_dram_size=536870912, + calibrator=None, + truncate_long_and_double=False, + require_full_compilation=False, + min_block_size=MIN_BLOCK_SIZE, + torch_executed_ops=[], + torch_executed_modules=[], + pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES, + max_aux_streams=MAX_AUX_STREAMS, + version_compatible=VERSION_COMPATIBLE, + optimization_level=OPTIMIZATION_LEVEL, + use_python_runtime=USE_PYTHON_RUNTIME, + **kwargs, +): + if debug: + logger.setLevel(logging.DEBUG) + + logger.warn( + "The Dynamo backend is an experimental feature, for which only the " + + "following arguments are supported: " + + "{enabled_precisions, debug, workspace_size, min_block_size, " + + "torch_executed_ops, pass_through_build_failures}" + ) + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device)) + + if ( + torch.float16 in enabled_precisions + or torch_tensorrt.dtype.half in enabled_precisions + ): + precision = torch.float16 + elif ( + torch.float32 in enabled_precisions + or torch_tensorrt.dtype.float in enabled_precisions + ): + precision = torch.float32 + elif len(enabled_precisions) == 0: + logger.info(f"No precision specified, defaulting to {PRECISION}") + precision = PRECISION + else: + raise ValueError( + f"Precision {enabled_precisions} not supported in the Dynamo Path" + ) + + compilation_options = { + "precision": precision, + "debug": debug, + "workspace_size": workspace_size, + "min_block_size": 
min_block_size, + "torch_executed_ops": torch_executed_ops, + "pass_through_build_failures": pass_through_build_failures, + "max_aux_streams": max_aux_streams, + "version_compatible": version_compatible, + "optimization_level": optimization_level, + "use_python_runtime": use_python_runtime, + } + + settings = CompilationSettings(**compilation_options) + if kwargs.get("use_capability_partitioner", None): + model = lower_model(gm, torch_inputs) + return _compile_module(model, torch_inputs, settings) + else: + split_result = lower_model_using_trt_splitter(gm, torch_inputs) + trt_module = _compile_graph(split_result, torch_inputs, settings) + + return trt_module + + +def _compile_graph( + split_result: SplitResult, + inputs: Any, + settings: CompilationSettings = CompilationSettings(), + **kwargs, +): + + for submod_name, submod_inputs in split_result.submodule_inputs.items(): + submod = getattr(split_result.split_module, submod_name) + # Only acc submodules will be lowered. + if not submod_name.startswith(split_result.non_acc_submodule_prefix): + # Create TRT Module from submodule + trt_mod = convert_module( + submod, + submod_inputs, + settings=settings, + name=submod_name, + ) + setattr(split_result.split_module, submod_name, trt_mod) + + return split_result.split_module + + +def lower_model_using_trt_splitter(model: torch.nn.Module, inputs: Any, **kwargs): + # Perform basic lowering + model = lower_model(model, inputs) + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter_setting.min_acc_module_size = 1 + splitter_setting.use_experimental_rt = False + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + splitter.node_support_preview() + split_result = splitter.generate_split_results() + + return split_result + + +def lower_model(model: torch.nn.Module, inputs: Any, **kwargs): + + graph_optimization_pm = PassManager.build_from_passlist( + [fuse_permute_matmul, fuse_permute_linear] + ) + lowered_model = graph_optimization_pm(model) + # if isinstance(lowered_model, torch.fx.GraphModule): + # ShapeProp(lowered_model).propagate(*inputs) + + return lowered_model diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py new file mode 100644 index 0000000000..56c2361f13 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -0,0 +1,2 @@ +from .trt_interpreter import * +from .conversion import * diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py similarity index 81% rename from py/torch_tensorrt/dynamo/backend/conversion.py rename to py/torch_tensorrt/dynamo/conversion/conversion.py index 425fb0941e..3b194dd8bf 100644 --- a/py/torch_tensorrt/dynamo/backend/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -1,12 +1,11 @@ from typing import Sequence, Union import torch import io -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.dynamo.backend._settings import CompilationSettings -from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import ( - InputTensorSpec, - TRTInterpreter, -) +from torch_tensorrt.dynamo.runtime import _PythonTorchTRTModule +from torch_tensorrt.dynamo import CompilationSettings +from torch_tensorrt import Input +from torch_tensorrt.dynamo.conversion import TRTInterpreter + import tensorrt as trt @@ -24,7 +23,7 @@ def convert_module( settings: Compilation settings name: TRT engine name Returns: - TRTModule or TRTModuleNext + _PythonTorchTRTModule or 
TorchTensorRTModule """ # Specify module output data types to ensure TRT output types agree with # that of the equivalent Torch module @@ -34,17 +33,15 @@ def convert_module( module_outputs = [module_outputs] output_dtypes = list(output.dtype for output in module_outputs) - interpreter = TRTInterpreter( module, - InputTensorSpec.from_tensors(inputs), + Input.from_tensors(inputs, disable_memory_format_check=True), logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING), output_dtypes=output_dtypes, ) - interpreter_result = interpreter.run( workspace_size=settings.workspace_size, - lower_precision=settings.precision, + precision=settings.precision, profiling_verbosity=( trt.ProfilingVerbosity.VERBOSE if settings.debug @@ -56,14 +53,14 @@ def convert_module( ) if settings.use_python_runtime: - return TRTModule( + return _PythonTorchTRTModule( engine=interpreter_result.engine, input_names=interpreter_result.input_names, output_names=interpreter_result.output_names, ) else: - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt.dynamo.runtime import TorchTensorRTModule with io.BytesIO() as engine_bytes: engine_bytes.write(interpreter_result.engine.serialize()) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py similarity index 74% rename from py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py rename to py/torch_tensorrt/dynamo/conversion/trt_interpreter.py index a29cee509d..9f97fb1b0a 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py @@ -14,12 +14,11 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata -from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS -from .input_tensor_spec import InputTensorSpec +from torch_tensorrt.fx import CONVERTERS +from torch_tensorrt import Input from torch_tensorrt.fx.observer import Observer from torch_tensorrt.fx.utils import ( get_dynamic_dims, - LowerPrecision, unified_dtype_converter, Frameworks, ) @@ -42,7 +41,7 @@ class TRTInterpreter(torch.fx.Interpreter): def __init__( self, module: torch.fx.GraphModule, - input_specs: List[InputTensorSpec], + input_specs: List[Input], logger_level=None, output_dtypes=None, ): @@ -69,7 +68,6 @@ def __init__( self.optimization_profiles: Optional[List] = None self.input_specs = input_specs self.input_specs_iter = 0 - self.validate_input_specs() self._cur_node_name: Optional[str] = None self._input_names: List[str] = [] self._output_names: List[str] = [] @@ -80,63 +78,6 @@ def __init__( # Data types for TRT Module output Tensors self.output_dtypes = output_dtypes - def validate_input_specs(self): - for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: - if not self.network.has_implicit_batch_dimension: - assert ( - has_batch_dim - ), "It's required to specify batch dimension when it's explicit in TensorRT network." - - dynamic_dims = get_dynamic_dims(shape) - if len(dynamic_dims): - assert not self.network.has_implicit_batch_dimension, ( - "Can't have dynamic dim when " - f"batch dim is implicit, got {shape}." - ) - assert len( - shape_ranges - ), "shape_ranges must be provided when shape has dynamic dim." - - if self.optimization_profiles: - assert len(shape_ranges) == len(self.optimization_profiles), ( - "Number of optimization " - f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range" - f" {len(shape_ranges)} provided." 
- ) - else: - self.optimization_profiles = [ - self.builder.create_optimization_profile() - for _ in range(len(shape_ranges)) - ] - - for shape_range in shape_ranges: - assert ( - len(shape_range) == 3 - ), f"Expect three elements in shape_range, got {len(shape_range)}" - assert all(len(s) == len(shape) for s in shape_range), ( - "Expect elements in shape_range" - f" {shape_range} have the same number of dimension as the provided shape {len(shape)}" - ) - - for i in range(len(shape)): - if i in dynamic_dims: - assert all( - shape_range[j][i] <= shape_range[j + 1][i] - for j in range(2) - ), ( - "Expect dynamic dim" - f" {i} to have incremental value for shapes in shape_range {shape_range}." - ) - else: - assert all(s[i] == shape[i] for s in shape_range), ( - f"Expect non dynamic dim {i} to be the same" - f" for all shapes in shape_range {shape_range}." - ) - else: - assert ( - len(shape_ranges) == 0 - ), "shape_ranges are provided for input that doesn't have dynamic dim." - def validate_conversion(self): missing_converter = set() @@ -156,7 +97,7 @@ def validate_conversion(self): def run( self, workspace_size=0, - lower_precision=LowerPrecision.FP16, + precision=torch.float32, sparse_weights=False, disable_tf32=False, force_fp32_output=False, @@ -173,7 +114,7 @@ def run( Build TensorRT engine with some configs. Args: workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation. - lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). + precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision). sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity force_fp32_output: force output to be fp32 strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons. @@ -189,22 +130,14 @@ def run( """ TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module) - # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and + # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and # force_fp32_output=False. 
Overriden by specifying output_dtypes - self.output_fp16 = ( - not force_fp32_output and lower_precision == LowerPrecision.FP16 - ) + self.output_fp16 = not force_fp32_output and precision == torch.float16 - if ( - lower_precision == LowerPrecision.INT8 - and not self.builder.platform_has_fast_int8 - ): + if precision == torch.int8 and not self.builder.platform_has_fast_int8: raise RuntimeError("Current platform doesn't support fast native int8!") - if ( - lower_precision == LowerPrecision.FP16 - and not self.builder.platform_has_fast_fp16 - ): + if precision == torch.float16 and not self.builder.platform_has_fast_fp16: warnings.warn("Current platform doesn't support fast native fp16!") self.input_specs_iter = 0 @@ -248,10 +181,10 @@ def run( _LOGGER.info(f"Using optimization level {optimization_level}") builder_config.builder_optimization_level = optimization_level - if lower_precision == LowerPrecision.FP16: + if precision == torch.float16: builder_config.set_flag(trt.BuilderFlag.FP16) - if lower_precision == LowerPrecision.INT8: + if precision == torch.int8: builder_config.set_flag(trt.BuilderFlag.INT8) if sparse_weights: @@ -313,23 +246,30 @@ def run_node(self, n): def placeholder(self, target, args, kwargs): self._input_names.append(target) - shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[ - self.input_specs_iter - ] + current_input = self.input_specs[self.input_specs_iter] self.input_specs_iter += 1 - - if self.network.has_implicit_batch_dimension: - if has_batch_dim: - shape = shape[1:] - else: - for i, shape_range in enumerate(shape_ranges): - assert self.optimization_profiles - self.optimization_profiles[i].set_shape(target, *shape_range) + # Set optimization profile for dynamic input shape + shape = current_input.shape + if current_input.shape_mode == Input._ShapeMode.DYNAMIC: + shape = [] + min_shape = current_input.shape["min_shape"] + opt_shape = current_input.shape["opt_shape"] + max_shape = current_input.shape["max_shape"] + self.optimization_profiles[0].set_shape( + target, [min_shape, opt_shape, max_shape] + ) + assert len(min_shape) == len(opt_shape) == len(max_shape) + for i in range(len(min_shape)): + if min_shape[i] == opt_shape[i] == max_shape[i]: + shape.append(min_shape[i]) + else: + # -1 to represent the dynamic dimension + shape.append(-1) return self.network.add_input( name=target, shape=tuple(shape), - dtype=unified_dtype_converter(dtype, Frameworks.TRT), + dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT), ) def call_module(self, target, args, kwargs): diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/README.md b/py/torch_tensorrt/dynamo/fx_ts_compat/README.md deleted file mode 100644 index d2a9e295a3..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/README.md +++ /dev/null @@ -1,13 +0,0 @@ -The code in this directory is similar to `torch_tensorrrt.fx`. We intend to make changes under `dynamo` namespace to ensure we -have the same top level API as `torch_tensorrt.ts.compile`. Right now, the usage is as follows - -``` -import torch_tensorrt -trt_module = torch_tensorrt.compile( - module, - ir="dynamo" - torchtrt_inputs, - enabled_precisions={torch.float32}, - ) -``` -This will internally call `torch_tensorrt.dynamo.compile` which has the same signature as `torch_tensorrt.ts.compile`. We intend to add features (existing in Torchscript backend for eg: torch_executed_ops, torch_executed_modules and many more) to this dynamo backend in the coming months. 
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py deleted file mode 100644 index 85ce01ef20..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -import logging - -from torch_tensorrt.fx.converter_registry import ( # noqa - CONVERTERS, - NO_EXPLICIT_BATCH_DIM_SUPPORT, - NO_IMPLICIT_BATCH_DIM_SUPPORT, - tensorrt_converter, -) -from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa -from .input_tensor_spec import InputTensorSpec # noqa -from .lower_setting import LowerSetting # noqa -from .lower import compile # usort: skip #noqa - -logging.basicConfig(level=logging.INFO) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py deleted file mode 100644 index 7f67e8abbf..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple - -import torch - -from torch_tensorrt.fx.types import Shape, ShapeRange -from torch_tensorrt.fx.utils import get_dynamic_dims -from torch_tensorrt._Input import Input - - -class InputTensorSpec(NamedTuple): - """ - This class contains the information of a input tensor. - - shape: shape of the tensor. - - dtype: dtyep of the tensor. - - device: device of the tensor. This is only used to generate inputs to the given model - in order to run shape prop. For TensorRT engine, inputs have to be on cuda device. - - shape_ranges: If dynamic shape is needed (shape has dimensions of -1), then this field - has to be provided (default is empty list). Every shape_range is a tuple of three - tuples ((min_input_shape), (optimized_input_shape), (max_input_shape)). Each shape_range - is used to populate a TensorRT optimization profile. - e.g. If the input shape varies from (1, 224) to (100, 224) and we want to optimize - for (25, 224) because it's the most common input shape, then we set shape_ranges to - ((1, 224), (25, 225), (100, 224)). - - has_batch_dim: Whether the shape includes batch dimension. Batch dimension has to be provided - if the engine want to run with dynamic shape. - """ - - shape: Shape - dtype: torch.dtype - device: torch.device = torch.device("cpu") - shape_ranges: List[ShapeRange] = [] - has_batch_dim: bool = True - - @classmethod - def from_tensor(cls, tensor: torch.Tensor) -> "InputTensorSpec": - """ - Produce an InputTenosrSpec named tuple which contains the - information of the given PyTorch tensor. - - Args: - tensor (torch.Tensor): A PyTorch tensor. - - Returns: - An InputTensorSpec named tuple. - """ - return cls(tensor.shape, tensor.dtype, tensor.device) - - @classmethod - def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]: - """ - Produce a list of InputTenosrSpec named tuples which contain - the information of all the given PyTorch tensors. - - Args: - tensors (Iterable[torch.Tensor]): A list of PyTorch tensors. - - Returns: - A list of InputTensorSpec named tuples. - """ - assert isinstance(tensors, (list, tuple)) - return [cls.from_tensor(t) for t in tensors] - - @classmethod - def from_input(cls, input_obj: Input) -> "InputTensorSpec": - """ - Produce a list of InputTenosrSpec named tuples which contain - the information of all the given PyTorch tensors. - - Args: - tensors (Iterable[torch.Tensor]): A list of PyTorch tensors. - - Returns: - A list of InputTensorSpec named tuples. 
- """ - assert isinstance(input_obj, Input) - input_spec = None - if isinstance(input_obj.shape, dict): - min_shape = input_obj.shape["min_shape"] - opt_shape = input_obj.shape["opt_shape"] - max_shape = input_obj.shape["max_shape"] - dyn_shape = [] - for min, opt, max in zip(min_shape, opt_shape, max_shape): - if min == opt == max: - dyn_shape.append(min) - else: - dyn_shape.append(-1) - dtype = input_obj.torch_dtype - input_spec = cls( - shape=dyn_shape, - dtype=dtype, - shape_ranges=[(min_shape, opt_shape, max_shape)], - ) - else: - shape = input_obj.shape - dtype = input_obj.torch_dtype - input_spec = cls(shape=shape, dtype=dtype) - - return input_spec - - @classmethod - def from_tensors_with_dynamic_batch_size( - cls, - tensors: Sequence[torch.Tensor], - batch_size_range: Tuple[int, int, int], - opt_profile_replica: int = 1, - batch_dims: Optional[List[int]] = None, - ) -> List["InputTensorSpec"]: - """ - Produce a list of InputTenosrSpec named tuples which would contain - the information of all the given PyTorch tensors. The produced input - tensor specs will treat all tensors' first dimension as batch dimension - and mark them as dynmaic. - - Args: - tensors (Sequence[torch.Tensor]): A list of PyTorch tensors. - batch_size_range (Tuple[int, int, int]): The first integer indicates - the smallest batch size allowed. The second integer indiceates - the batch size that we'll optimize for. The third integer indicates - the largest batch size allowed. - opt_profile_replica (int): If dynamic shape is enabled, each execution - context requires a different optimization profile. This arg determines - how many optimization profile replicas we want to produce. - batch_dims (Optional[List[int]]): The batch dim might not be the leading dim - and allow user to specify the batch dims using this arg. Default we treat - dim 0 as the batch dim. - - Returns: - A list of InputTensorSpec named tuples with dynamic ranges. - """ - if batch_dims is None: - batch_dims = [0] * len(tensors) - - input_specs = [] - batch_size = tensors[0].size(batch_dims[0]) - - for i, tensor in enumerate(tensors): - batch_dim = batch_dims[i] - assert batch_size == tensor.size( - batch_dim - ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." 
- shape = list(tensor.shape) - shape[batch_dim] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] - input_specs.append( - cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) - ) - - return input_specs - - def to_random_tensor(self, id=1): - shape = tuple(self.shape) - if len(get_dynamic_dims(shape)): - # id=0 -> min shape - # id=1 -> optimal shape - # id=2 -> max shape - shape = tuple(self.shape_ranges[0][id]) - elif not self.has_batch_dim: - shape = (1,) + tuple(shape) - - return torch.randn(shape).to(dtype=self.dtype, device=self.device) - - @staticmethod - def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor()) - - return inputs - - @staticmethod - def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor(2)) - - return inputs diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py deleted file mode 100644 index c0f1ae7870..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py +++ /dev/null @@ -1,370 +0,0 @@ -import dataclasses as dc -import logging -from typing import Any, Callable, Optional, Sequence - -# @manual=//deeplearning/trt/python:py_tensorrt -import tensorrt as trt -import torch -import torch.fx as fx -import torch.nn as nn -import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer -from torch.fx.passes.splitter_base import SplitResult - -from .fx2trt import TRTInterpreter, TRTInterpreterResult -from .lower_setting import LowerSetting -from .passes.lower_pass_manager_builder import LowerPassManagerBuilder -from .passes.pass_utils import PassFunc, validate_inference -from ..common_utils import use_python_runtime_parser -from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting - -from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer -from torch_tensorrt.fx.trt_module import TRTModule -from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt._Device import Device - -logger = logging.getLogger(__name__) - -Input = Sequence[Any] - - -def compile( - module: nn.Module, - inputs, - device=torch.device(torch.cuda.current_device()), - disable_tf32=False, - sparse_weights=False, - enabled_precisions=set(), - min_block_size: int = 3, - workspace_size=0, - dla_sram_size=1048576, - dla_local_dram_size=1073741824, - dla_global_dram_size=536870912, - calibrator=None, - truncate_long_and_double=False, - require_full_compilation=False, - debug=False, - refit=False, - timing_cache_prefix="", - save_timing_cache=False, - cuda_graph_batch_size=-1, - is_aten=False, - use_python_runtime=None, - max_aux_streams=None, - version_compatible=False, - optimization_level=None, - num_avg_timing_iters=1, - torch_executed_ops=[], - torch_executed_modules=[], - **kwargs, -) -> nn.Module: - """ - Takes in original module, input and lowering setting, run lowering workflow to turn module - into lowered module, or so called TRTModule. - - Args: - module: Original module for lowering. - input: Input for module. - min_block_size: Minimal number of nodes for an accelerated submodule - workspace_size: Maximum size of workspace given to TensorRT. 
- debug: Enable verbose log for TensorRT if set True. - timing_cache_prefix: Timing cache file name for timing cache used by fx2trt. - save_timing_cache: Update timing cache with current timing cache data if set to True. - cuda_graph_batch_size: Cuda graph batch size, default to be -1. - use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime - based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the - argument as None - max_aux_streams: max number of aux stream to use - version_compatible: enable version compatible feature - optimization_level: builder optimization level - Returns: - A torch.nn.Module lowered by TensorRT. - """ - logger.warn( - "For ir=fx_ts_compat backend only the " - + "following arguments are supported: " - + "{enabled_precisions, debug, workspace_size, device, disable_tf32, sparse_weights, min_block_size}" - ) - - # Parse precision into LowerPrecision - lower_precision = LowerPrecision.FP32 - if torch.float16 in enabled_precisions: - lower_precision = LowerPrecision.FP16 - elif torch.float32 in enabled_precisions: - lower_precision = LowerPrecision.FP32 - else: - raise ValueError(f"Precision {enabled_precisions} not supported on FX") - - # Parse device - if isinstance(device, Device): - if device.gpu_id != -1: - device = torch.device(device.gpu_id) - else: - raise ValueError("Invalid GPU ID provided for the CUDA device provided") - elif isinstance(device, torch.device): - device = device - elif isinstance(device, dict): - if "device_type" in device and device["device_type"] == trt.DeviceType.GPU: - if "gpu_id" in device: - device = torch.device(device["gpu_id"]) - else: - device = torch.device("cuda:0") - else: - raise ValueError( - "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" - ) - - # Parse user-specification of which runtime to use - use_python_runtime = use_python_runtime_parser(use_python_runtime) - - lower_setting = LowerSetting( - device=device, - min_block_size=min_block_size, - disable_tf32=disable_tf32, - sparse_weights=sparse_weights, - workspace_size=workspace_size, - lower_precision=lower_precision, - debug=debug, - timing_cache_prefix=timing_cache_prefix, - save_timing_cache=save_timing_cache, - cuda_graph_batch_size=cuda_graph_batch_size, - is_aten=is_aten, - use_python_runtime=use_python_runtime, - max_aux_streams=max_aux_streams, - version_compatible=version_compatible, - optimization_level=optimization_level, - ) - lowerer = Lowerer.create(lower_setting=lower_setting) - return lowerer(module, inputs) - - -@dc.dataclass -class LowerTrtInterpreter: - lower_setting: LowerSetting - timing_cache_manager: TimingCacheManager - - @classmethod - def create(cls, lower_setting): - timing_cache_manager = TimingCacheManager( - lower_setting.timing_cache_prefix, lower_setting.save_timing_cache - ) - return LowerTrtInterpreter(lower_setting, timing_cache_manager) - - def __call__(self, mod, input, split_name) -> TRTInterpreterResult: - assert self.lower_setting.input_specs, "Can't find input specs for lowering!" 
- logger.info( - f"split_name={split_name}, input_specs={self.lower_setting.input_specs}" - ) - - # Prepare algorithm selector and timing_cache for TRTInterpreter - algo_selector = None - if self.lower_setting.algo_selector: - algo_selector = self.lower_setting.algo_selector(f"{split_name}.json") - cache_data = None - if self.timing_cache_manager: - try: - cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name) - logger.info("Timing cache is used!") - except Exception as e: - logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}") - cache_data = None - - interpreter = TRTInterpreter( - mod, - input_specs=self.lower_setting.input_specs, - logger_level=trt.Logger.VERBOSE - if self.lower_setting.debug - else trt.Logger.WARNING, - ) - - interp_result: TRTInterpreterResult = interpreter.run( - workspace_size=self.lower_setting.workspace_size, - lower_precision=self.lower_setting.lower_precision, - sparse_weights=self.lower_setting.sparse_weights, - disable_tf32=self.lower_setting.disable_tf32, - strict_type_constraints=self.lower_setting.strict_type_constraints, - algorithm_selector=algo_selector, - timing_cache=cache_data, - profiling_verbosity=trt.ProfilingVerbosity.DETAILED - if self.lower_setting.verbose_profile - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, - tactic_sources=self.lower_setting.tactic_sources, - max_aux_streams=self.lower_setting.max_aux_streams, - version_compatible=self.lower_setting.version_compatible, - optimization_level=self.lower_setting.optimization_level, - ) - - # Update timing cache file if needed - timing_cache = interp_result.serialized_cache - if timing_cache and self.timing_cache_manager: - self.timing_cache_manager.update_timing_cache(split_name, timing_cache) - - return interp_result - - -def default_split_function( - model: fx.GraphModule, inputs: Input, lower_setting: LowerSetting -) -> SplitResult: - splitter_setting = TRTSplitterSetting() - splitter_setting.use_implicit_batch_dim = False - splitter_setting.min_block_size = lower_setting.min_block_size - splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime - splitter = TRTSplitter(model, inputs, settings=splitter_setting) - splitter.node_support_preview() - return splitter.generate_split_results() - - -def create_lower_trt_interpreter(lower_setting: LowerSetting) -> LowerTrtInterpreter: - return LowerTrtInterpreter.create(lower_setting) - - -def default_lower_pass( - create_trt_interpreter: Callable[[LowerSetting], LowerTrtInterpreter], -) -> PassFunc: - def lower_pass( - mod: nn.Module, input: Input, lower_setting: LowerSetting, module_name: str - ) -> nn.Module: - """ - Create a module transformation pass which lowers an `fx.GraphModule` into a - `TRTModule` - """ - interpreter = create_trt_interpreter(lower_setting) - interp_res: TRTInterpreterResult = interpreter(mod, input, module_name) - if lower_setting.use_python_runtime: - trt_module = TRTModule( - engine=interp_res.engine, - input_names=interp_res.input_names, - output_names=interp_res.output_names, - cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, - ) - return trt_module - - else: - import io - from torch_tensorrt._Device import Device - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interp_res.engine.serialize()) - engine_str = engine_bytes.getvalue() - - trt_module = TorchTensorRTModule( - engine_str, - name=module_name, - input_binding_names=interp_res.input_names, - 
output_binding_names=interp_res.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - ) - return trt_module - - return lower_pass - - -@dc.dataclass(frozen=True) -class Lowerer: - """Lowers a module using fx2trt. - - This is a composable class to facilitate fx2trt. A normal fx2trt process - composes of the following passes to transform an `fx.GraphModule`: - - 1. trace - use torch.fx to trace the module so we can get the graph - representation of the model. - 2. split - the graph module is split into several submodules, - running either via TensorRT, or via regular CUDA. - - For each split that need to run via TRT, the following passes are - invoked: - - 3. `TRTInterpreter` - build the TRT engine for the submodule that - can be supported through `TRTInterpreter`. - 4. Wraps the executable TRT engine into `TRTModule`, which is an `nn.Module`. - 5. The converted submodule is then set back onto the top-level module - - """ - - lower_pass_manager_builder: LowerPassManagerBuilder - - @classmethod - def create( - cls, - lower_setting: LowerSetting, - interpreter_builder: Callable = create_lower_trt_interpreter, - split_func: Callable = default_split_function, - ) -> "Lowerer": - """Instantiate a `Lowerer` instance.""" - if not lower_setting.is_aten: - return cls( - lower_pass_manager_builder=LowerPassManagerBuilder( - lower_setting=lower_setting, - trace_func=lambda module, inputs: acc_tracer.trace( - module, - inputs, # type: ignore[arg-type] - ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, - leaf_module_list=lower_setting.leaf_module_list, - ), - split_func=split_func, - lower_func=default_lower_pass(interpreter_builder), - ) - ) - # proxytensor_trace - else: - return cls( - lower_pass_manager_builder=LowerPassManagerBuilder( - lower_setting=lower_setting, - trace_func=lambda module, inputs: aten_tracer.opt_trace( - module, inputs - ), - split_func=split_func, - lower_func=default_lower_pass(interpreter_builder), - ) - ) - - def __call__( - self, - module: nn.Module, - inputs: Input, - additional_inputs: Optional[Input] = None, - fp16_conversion_fn: Optional[Callable[[Input], Input]] = None, - ) -> nn.Module: - lower_setting = self.lower_pass_manager_builder.lower_setting - atol = lower_setting.correctness_atol - rtol = lower_setting.correctness_rtol - device = lower_setting.device - - @validate_inference( - atol=atol, - rtol=rtol, - device=device, - ) - def do_lower(module: nn.Module, inputs: Input) -> nn.Module: - module.eval() - if ( - self.lower_pass_manager_builder.lower_setting.lower_precision - == LowerPrecision.FP16 - ): - module.half() - # A custom conversion function can be passed to the lowerer to - # handle inputs with custom types. By default, just handle - # tensors and NoneType. 
- if fp16_conversion_fn is None: - conversion_fn = ( - lambda x: x.half() - if x is not None and x.dtype == torch.float32 - else x - ) - else: - conversion_fn = fp16_conversion_fn - - inputs = tuple(conversion_fn(x) for x in inputs) - if lower_setting.is_aten: - pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( - inputs, additional_inputs - ) - else: - pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( - inputs, additional_inputs - ) - lower_result = pm(module) - return lower_result - - return do_lower(module, inputs) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py deleted file mode 100644 index 9301a2cd90..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py +++ /dev/null @@ -1,102 +0,0 @@ -import dataclasses as dc -from typing import List, Optional, Set, Type -import torch -from torch import nn -from torch.fx.passes.pass_manager import PassManager - -from .input_tensor_spec import InputTensorSpec -from torch_tensorrt.fx.passes.lower_basic_pass import ( - fuse_permute_linear, - fuse_permute_matmul, -) -from torch_tensorrt.fx.utils import LowerPrecision - - -@dc.dataclass -class LowerSettingBasic: - """ - Basic class for lowering. - lower_precision: lower precision dtype during lowering. - min_block_size(int): The minimum number of contiguous TensorRT convertable nodes in order to run them in TensorRT - ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of - modules that need AST rewriting. This is aiming to eliminate input variable involve in - exception checking control flow. - leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where - modules will not be traced into. - verbose_profile (bool): verbosity of profiler, default to False. - """ - - lower_precision: LowerPrecision = LowerPrecision.FP32 - device: torch.device = torch.device(torch.cuda.current_device()) - min_block_size: int = 3 - disable_tf32: bool = False - sparse_weights: bool = False - ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None - leaf_module_list: Optional[Set[Type[nn.Module]]] = None - verbose_profile: bool = False - is_aten: bool = False - - -@dc.dataclass -class LowerSetting(LowerSettingBasic): - """ - Basic configuration for lowering stack. - Args: - input_specs: Specs for inputs to engine, can either be a single size or a - range defined by Min, Optimal, Max sizes. - workspace_size: The maximum workspace size. The maximum GPU temporary - memory which the TensorRT engine can use at execution time. - strict_type_constraints: Require TensorRT engine to strictly follow data type - setting at execution time. - customized_fuse_pass: List of custmozied pass to apply during lowering process. - lower_basic_fuse_pass: Enable basic pass fuse duirng lowering, i.e. fuse multiple operations - as (a->b->c->d)=>(e). Current basic fuse patterns are: - permute->linear - permute->matmul - debug: Enable TensorRT engine verbose log mode. - algo_selector: Enable TensorRT algorithm selector at execution time. - timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing - cache file at execution time if valid timing cache file is provided. - save_timing_cache: Save updated timing cache data into timing cache file if the timing - cache file is provided. - cuda_graph_batch_size (int): Cuda graph batch size, default to be -1. - preset_lowerer (str): when specified, use a preset logic to build the - instance of Lowerer. 
- only used by explicit batch dim with dynamic shape mode. In general, we use 2 GPU setting with - 2 stream on each. Set total number to 8 as a safe default value. - tactic_sources: tactic sources for TensorRT kernel selection. Default to None, - meaning all possible tactic sources. - correctness_atol: absolute tolerance for correctness check - correctness_rtol: relative tolerance for correctness check - use_python_runtime: Whether to use Python runtime or C++ runtime. None implies the user has not - selected a runtime, and the frontend will automatically do so on their behalf - max_aux_streams: max number of aux stream to use - version_compatible: enable version compatible feature - optimization_level: builder optimization level - """ - - input_specs: List[InputTensorSpec] = dc.field(default_factory=list) - workspace_size: int = 0 - strict_type_constraints: bool = False - customized_fuse_pass: PassManager = dc.field( - default_factory=lambda: PassManager.build_from_passlist([]) - ) - lower_basic_fuse_pass: PassManager = dc.field( - default_factory=lambda: PassManager.build_from_passlist( - [fuse_permute_matmul, fuse_permute_linear] - ) - ) - debug: bool = False - algo_selector = None - timing_cache_prefix: str = "" - save_timing_cache: bool = False - cuda_graph_batch_size: int = -1 - preset_lowerer: str = "" - opt_profile_replica: int = 8 - tactic_sources: Optional[int] = None - correctness_atol: float = 0.1 - correctness_rtol: float = 0.1 - use_python_runtime: Optional[bool] = None - max_aux_streams: Optional[int] = None - version_compatible: bool = False - optimization_level: Optional[int] = None diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py deleted file mode 100644 index 0fd3777254..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py +++ /dev/null @@ -1,333 +0,0 @@ -import datetime -import logging -from functools import partial, wraps -from typing import Any, Callable, Optional, Sequence - -import torch -from torch import nn -from torch.fx.passes.pass_manager import inplace_wrapper, PassManager -from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult -from torch_tensorrt.fx.utils import LowerPrecision -from torch_tensorrt import _Input -from ..input_tensor_spec import InputTensorSpec - -from ..lower_setting import LowerSetting -from torch_tensorrt.fx.observer import Observer -from torch_tensorrt.fx.passes.remove_duplicate_output_args import ( - remove_duplicate_output_args, -) -from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination -from .pass_utils import extract_example_tensors_from_input - -from torch_tensorrt.fx.passes.lower_basic_pass import ( # noqa - fix_clamp_numerical_limits_to_fp16, - fix_reshape_batch_dim, - replace_mutable_op, - replace_op_with_indices, - run_const_fold, -) - - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -Input = Sequence[Any] - - -# ---------------------------------------------------------------------- -# OBSERVERS -# ---------------------------------------------------------------------- -# List of observers. 
We can subscribe to them by calling its `add(callback)` -# function from anywhere in code: -# -# >>> from torch_tensorrt.fx.lower import FUSE_PASSES_POST_OBSERVER -# >>> with FUSE_PASSES_POST_OBSERVER.add(print_module_and_input): -# >>> # print_module_and_input will be called right after the fuse passes -# >>> lower(module, sample_input) - -# Observer for the model after the fuse passes. -FUSE_PASSES_POST_OBSERVER: Observer[Callable[[nn.Module, Input], None]] = Observer( - "FUSE_PASSES_POST_OBSERVER" -) - -# Observer for the TRT split submodules before lowering -LOWER_SPLIT_PRE_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( - "LOWER_SPLIT_PRE_OBSERVER" -) - -# Observer for the TRT split submodules after lowering -LOWER_SPLIT_POST_OBSERVER: Observer[Callable[[str, nn.Module, Input], None]] = Observer( - "LOWER_SPLIT_POST_OBSERVER" -) -# ---------------------------------------------------------------------- - - -def wrapper(fn: Callable, input) -> Callable: - @wraps(fn) - def wrapped_fn(gm): - if isinstance(gm, torch.fx.GraphModule): - ShapeProp(gm).propagate(*input) - return fn(gm, input) - - return wrapped_fn - - -class LowerPassManagerBuilder: - """ - Build PassManager for lowering. - - Attributes: - lower_setting: Setting that will be used during process of lowering, see lower_setting.py for the details. - _trace_func: fx trace function for TRT conversion. - _split_func: the fx2trt split function. - _lower_func: function to create and run `TRTInterpreter` to convert `fx.GraphModule` - into a TensorRT engine. - - """ - - def __init__( - self, - lower_setting: LowerSetting, - trace_func: Callable, - split_func: Callable, - lower_func: Callable, - ): - self.lower_setting = lower_setting - self._trace_func = trace_func - self._split_func = split_func - self._lower_func = lower_func - - def _const_fold_pass(self) -> PassManager: - passes = [ - wrapper(self._trace_func, self._input), - run_const_fold, - ] - return PassManager.build_from_passlist(passes) - - def graph_optimization_pass(self) -> PassManager: - passes = [ - wrapper(self._trace_func, self._input), - ] - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - if ( - hasattr(self.lower_setting, "lower_precision") - and self.lower_setting.lower_precision is LowerPrecision.FP16 - ) or ( - hasattr(self.lower_setting, "precision") - and self.lower_setting.precision is LowerPrecision.FP16 - ): - passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) - - passes.append(inplace_wrapper(common_subexpression_elimination)) - passes.append( - inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) - ) - passes.append(fix_reshape_batch_dim) - - return PassManager.build_from_passlist(passes) - - def graph_optimization_pass_aten(self) -> PassManager: - passes = [] - - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - # TODO fix this pass for aten graph - # if ( - # hasattr(self.lower_setting, "lower_precision") - # and self.lower_setting.lower_precision is LowerPrecision.FP16 - # ) or ( - # hasattr(self.lower_setting, "precision") - # and self.lower_setting.precision is LowerPrecision.FP16 - # ): - # passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) - - passes.append( - 
inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) - ) - # TODO we most likely do not need it for aten - # passes.append(fix_reshape_batch_dim) - - return PassManager.build_from_passlist(passes) - - def _split_pass(self) -> PassManager: - passes = [ - partial( - self._split_func, inputs=self._input, lower_setting=self.lower_setting - ) - ] - passes.append( - inplace_wrapper( - lambda split_result: remove_duplicate_output_args( - split_result.split_module, split_result.submodule_inputs.keys() - ) - ) - ) - - return PassManager.build_from_passlist(passes) - - def _trt_lower_pass(self) -> PassManager: - def lower_func(split_result: SplitResult) -> nn.Module: - if ( - hasattr(self.lower_setting, "explicit_batch_dimension") - and self.lower_setting.explicit_batch_dimension - and self._additional_input - ): - additional_submodule_inputs = generate_inputs_for_submodules( - split_result.split_module, - self._additional_input, - list(split_result.submodule_inputs.keys()), - ) - else: - additional_submodule_inputs = None - - for submod_name, submod_inputs in split_result.submodule_inputs.items(): - submod = getattr(split_result.split_module, submod_name) - - LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) - - # Only acc submodules will be lowered. - if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info(f"Now lowering submodule {submod_name}") - lowering_start_time = datetime.datetime.now() - - self.lower_setting.input_specs = self._trt_input - - lowered_module = self._lower_func( - submod, submod_inputs, self.lower_setting, submod_name - ) - setattr(split_result.split_module, submod_name, lowered_module) - LOWER_SPLIT_POST_OBSERVER.observe( - submod_name, lowered_module, submod_inputs - ) - _LOGGER.info( - f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" - ) - - return split_result.split_module - - return PassManager.build_from_passlist([lower_func]) - - def _default_lower_pass(self) -> PassManager: - def lower_func(split_result: SplitResult) -> nn.Module: - if self._additional_input: - additional_submodule_inputs = generate_inputs_for_submodules( - split_result.split_module, - self._additional_input, - list(split_result.submodule_inputs.keys()), - ) - else: - additional_submodule_inputs = None - - for submod_name, submod_inputs in split_result.submodule_inputs.items(): - submod = getattr(split_result.split_module, submod_name) - - LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs) - - # Only acc submodules will be lowered. 
- if not submod_name.startswith(split_result.non_acc_submodule_prefix): - _LOGGER.info(f"Now lowering submodule {submod_name}") - lowering_start_time = datetime.datetime.now() - - self.lower_setting.additional_inputs = ( - additional_submodule_inputs[submod_name] - if additional_submodule_inputs - else None, - ) - - lowered_module = self._lower_func( - submod, submod_inputs, self.lower_setting, submod_name - ) - setattr(split_result.split_module, submod_name, lowered_module) - LOWER_SPLIT_POST_OBSERVER.observe( - submod_name, lowered_module, submod_inputs - ) - _LOGGER.info( - f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}" - ) - - return split_result.split_module - - return PassManager.build_from_passlist([lower_func]) - - def _default_replace_mutable_op_pass(self) -> PassManager: - return PassManager.build_from_passlist([replace_mutable_op]) - - def build_trt_lower_pipeline( - self, input: Input, additional_input: Optional[Input] = None - ) -> PassManager: - - self._input = extract_example_tensors_from_input( - input, self.lower_setting.device - ) - self._trt_input = [] - for input_obj in input: - if isinstance(input_obj, _Input.Input): - self._trt_input.append(InputTensorSpec.from_input(input_obj)) - elif isinstance(input_obj, torch.Tensor): - self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) - else: - raise ValueError( - "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor" - ) - - self._additional_input = additional_input - passes = [] - - passes.append(self._default_replace_mutable_op_pass()) - passes.append(self._const_fold_pass()) - passes.append(self.graph_optimization_pass()) - passes.append(self._split_pass()) - passes.append(self._trt_lower_pass()) - - pm = PassManager.build_from_passlist(passes) - return pm - - def build_aten2trt_lower_pipeline( - self, input: Input, additional_input: Optional[Input] = None - ) -> PassManager: - - self._input = extract_example_tensors_from_input(input) - self._trt_input = [] - for input_obj in input: - if isinstance(input_obj, _Input.Input): - self._trt_input.append(InputTensorSpec.from_input(input_obj)) - elif isinstance(input_obj, torch.Tensor): - self._trt_input.append(InputTensorSpec.from_tensor(input_obj)) - else: - raise ValueError( - "Invalid input type provided in the FX lowering. 
Expected type: torch_tensorrt.Input or torch.Tensor" - ) - - self._additional_input = additional_input - passes = [] - passes.append( - wrapper(self._trace_func, self._input), - ) - passes.append(self.graph_optimization_pass_aten()) - passes.append(self._split_pass()) - passes.append(self._trt_lower_pass()) - - pm = PassManager.build_from_passlist(passes) - return pm - - def build_default_lower_pipeline( - self, input: Input, additional_input: Optional[Input] = None - ) -> PassManager: - self._input = input - self._additional_input = additional_input - passes = [] - - passes.append(self._default_replace_mutable_op_pass()) - passes.append(self._const_fold_pass()) - passes.append(self.graph_optimization_pass()) - passes.append(self._split_pass()) - passes.append(self._default_lower_pass()) - - pm = PassManager.build_from_passlist(passes) - return pm diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py deleted file mode 100644 index 7d3046d617..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py +++ /dev/null @@ -1,304 +0,0 @@ -import io -import logging -import tempfile -from datetime import datetime -from functools import wraps -from typing import Any, Callable, List, Optional - -import torch -from torch import fx -from torch.fx.passes.shape_prop import ShapeProp -from torch_tensorrt import _Input - -# Create an alias for module input type to avoid littering pyre-ignore for Any -# throughout the file. -Input = Any -_LOGGER: logging.Logger = logging.getLogger(__name__) - -PassFunc = Callable[[fx.GraphModule, Input], fx.GraphModule] - -RELAX_ACCURACY_FAILURE: bool = False -FINAL_CHECK_ATOL_MULTIPLIER: float = 10 -FINAL_CHECK_RTOL_MULTIPLIER: float = 10 - - -def extract_example_tensors_from_input( - inputs: Any, device: torch.device = torch.device("cuda") -): - input_tensors = [] - for input_obj in inputs: - if isinstance(input_obj, _Input.Input): - if isinstance(input_obj.shape, dict): - input_tensors.append( - input_obj.example_tensor(optimization_profile_field="opt_shape").to( - device - ) - ) - else: - input_tensors.append(input_obj.example_tensor().to(device)) - elif isinstance(input_obj, torch.Tensor): - input_tensors.append(input_obj) - else: - raise ValueError( - "Invalid input type provided in the FX lowering. Expected type: torch_tensorrt.Input or torch.Tensor" - ) - - return input_tensors - - -class RelaxAccuracyCheckMode: - """ - Basically a context manager that controls a global variable that controls - the accuracy check mode. Use it like - with RelaxAccuracyCheckMode(True): - fx2trt() - """ - - def __init__( - self, - mode: bool, - final_atol_multiplier: Optional[float] = None, - final_rtol_multiplier: Optional[float] = None, - ): - """ - Arguments: - mode: whether we relax the immediate accuracy check failure or not. If yes, we will do an extra - accruacy check by raising the tolerance by the multipler times and only raise error if that fails. - This is to avoid catastrophic errors. - final_atol_multiplier [optional]: set FINAL_CHECK_ATOL_MULTIPLIER if specifier. - final_rtol_multiplier [optional]: set FINAL_CHECK_RTOL_MULTIPLIER if specifier. 
- """ - global RELAX_ACCURACY_FAILURE - global FINAL_CHECK_ATOL_MULTIPLIER - global FINAL_CHECK_RTOL_MULTIPLIER - self._old_mode = ( - RELAX_ACCURACY_FAILURE, - FINAL_CHECK_ATOL_MULTIPLIER, - FINAL_CHECK_RTOL_MULTIPLIER, - ) - RELAX_ACCURACY_FAILURE = mode - FINAL_CHECK_ATOL_MULTIPLIER = ( - final_atol_multiplier - if final_atol_multiplier - else FINAL_CHECK_ATOL_MULTIPLIER - ) - FINAL_CHECK_RTOL_MULTIPLIER = ( - final_rtol_multiplier - if final_rtol_multiplier - else FINAL_CHECK_RTOL_MULTIPLIER - ) - _LOGGER.info( - f"Set new relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" - ) - - def __enter__(self): - pass - - def __exit__(self, type, value, traceback): - global RELAX_ACCURACY_FAILURE - global FINAL_CHECK_ATOL_MULTIPLIER - global FINAL_CHECK_RTOL_MULTIPLIER - ( - RELAX_ACCURACY_FAILURE, - FINAL_CHECK_ATOL_MULTIPLIER, - FINAL_CHECK_RTOL_MULTIPLIER, - ) = self._old_mode - _LOGGER.info( - f"Restored old relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}" - ) - - -def chain_passes(*passes: PassFunc) -> PassFunc: - """ - Chains a sequence of pass functions to form a single pass function - """ - - def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule: - for pass_ in passes: - if isinstance(module, torch.fx.GraphModule): - ShapeProp(module).propagate(*input) - module = pass_(module, input) - return module - - return parent_pass - - -# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall -# on pass that failed accuracy check. -def validate_inference( - rtol=None, - atol=None, - device=torch.device(torch.cuda.current_device()), - suppress_accuracy_check_failure=True, -): - def _validate_inference(pass_: PassFunc) -> PassFunc: - """ - Wraps a pass function to validate that its inference results before and - after the pass run should be `close`. - """ - - @wraps(pass_) - def pass_with_validation( - module: fx.GraphModule, - input: Input, - *args, - **kwargs, - ) -> fx.GraphModule: - if suppress_accuracy_check_failure: - return pass_(module, input, *args, **kwargs) - else: - input_tensors = extract_example_tensors_from_input(input, device) - res0 = module(*input_tensors) - processed_module = pass_(module, input, *args, **kwargs) - res1 = processed_module(*input_tensors) - tensor_res_0 = _collect_tensors(res0) - tensor_res_1 = _collect_tensors(res1) - relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE - - for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): - kwargs2 = {"equal_nan": True} - if rtol: - kwargs2["rtol"] = rtol - if atol: - kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( - lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" - ) - # If tensors are on different devices, make sure to compare - # their copies that are on the same device. 
- if x.get_device() != y.get_device(): - x = x.cpu() - y = y.cpu() - try: - torch.testing.assert_close(x, y, **kwargs2) - except Exception as e: - if relax_accuracy_check_failure: - _LOGGER.error(f"{e}") - kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER - kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER - new_atol = kwargs2["atol"] - new_rtol = kwargs2["rtol"] - _LOGGER.info( - f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" - ) - torch.testing.assert_close(x, y, **kwargs2) - return processed_module - else: - raise e - - return processed_module - - return pass_with_validation - - return _validate_inference - - -Decorator = Callable[[Callable], Callable] - - -def decorate_method(dec_for_function: Decorator) -> Decorator: - def dec_for_method(unbounded_method) -> Callable: - def decorated_unbounded_method(self, *args, **kwargs): - @dec_for_function - def bounded_method(*args, **kwargs): - return unbounded_method(self, *args, **kwargs) - - return bounded_method(*args, **kwargs) - - return decorated_unbounded_method - - return dec_for_method - - -def log_perf_before_after(pass_: PassFunc) -> PassFunc: - """ - Wraps a pass function to log perf of the module before and after the pass - """ - - @wraps(pass_) - def check_perf_with_before_after_log( - module: fx.GraphModule, input: Input - ) -> fx.GraphModule: - def benchmark_torch_function(iters: int, f, *args) -> float: - """Estimates the average time duration for a single inference call in second - - If the input is batched, then the estimation is for the batches inference call. - - Args: - iters: number of inference iterations to run - f: a function to perform a single inference call - - Returns: - estimated average time duration in second for a single inference call - """ - with torch.inference_mode(): - f(*args) - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - # print("== Start benchmark iterations") - with torch.inference_mode(): - start_event.record() - for _ in range(iters): - f(*args) - end_event.record() - torch.cuda.synchronize() - # print("== End benchmark iterations") - return (start_event.elapsed_time(end_event) * 1.0e-3) / iters - - time_before = benchmark_torch_function(100, lambda: module(*input)) - _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}") - - module = pass_(module, input) - time_after = benchmark_torch_function(100, lambda: module(*input)) - _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}") - return module - - return check_perf_with_before_after_log - - -def log_before_after(pass_: PassFunc) -> PassFunc: - """ - Wraps a pass function to log the module graph before and after the pass - """ - - @wraps(pass_) - def pass_with_before_after_log( - module: fx.GraphModule, input: Input - ) -> fx.GraphModule: - before_io = io.StringIO() - after_io = io.StringIO() - with tempfile.NamedTemporaryFile( - mode="w", - encoding="utf-8", - delete=False, - ) as f: - print(f"[{pass_}] Before:\n{module.graph}", file=f) - print(module.graph, file=before_io) - start_time = datetime.now() - module = pass_(module, input) - t_elapsed = datetime.now() - start_time - print(f"[{pass_}] After:\n{module.graph}", file=f) - print(module.graph, file=after_io) - t = before_io.getvalue() == after_io.getvalue() - _LOGGER.info( - f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}, time elapsed = {t_elapsed}" - ) - return module - - return pass_with_before_after_log - - 
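The removed `log_perf_before_after` helper above estimates per-call latency by averaging over repeated invocations bracketed by CUDA events. A self-contained sketch of that timing pattern is shown below; the function name, iteration count, and example module are illustrative assumptions, not part of the original helper.

```python
import torch


def estimate_latency(fn, iters: int = 100) -> float:
    """Return the estimated average seconds per call of `fn` on the GPU."""
    with torch.inference_mode():
        fn()  # warm-up call, mirroring the helper's first untimed invocation
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
    # elapsed_time returns milliseconds; convert to seconds per iteration
    return (start.elapsed_time(end) * 1.0e-3) / iters


# Example usage (assumes `model` and `example_input` already live on the GPU):
# avg_seconds = estimate_latency(lambda: model(example_input))
```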
-def _collect_tensors(arg: fx.node.Argument) -> List[torch.Tensor]: - """Collects all the tensors found in a nested container object""" - res: List[torch.Tensor] = [] - - def collect(x: fx.node.Argument) -> fx.node.Argument: - if isinstance(x, torch.Tensor): - res.append(x) - return x - - fx.node.map_aggregate(arg, collect) - return res diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py deleted file mode 100644 index b7dd8153cb..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input.py +++ /dev/null @@ -1,89 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -import io -import os - -import torch -import torch_tensorrt -from torch.testing._internal.common_utils import run_tests, TestCase - - -class TestInput(TestCase): - def test_add_model(self): - class TestModule(torch.nn.Module): - def forward(self, x): - return x + x - - inputs = [torch_tensorrt.Input(shape=(1, 3, 3, 4), dtype=torch.float32)] - rand_inputs = [torch.randn((1, 3, 3, 4), dtype=torch.float32).cuda()] - mod = TestModule().cuda().eval() - ref_output = mod(*rand_inputs) - - trt_mod = torch_tensorrt.compile( - mod, - ir="fx_ts_compat", - inputs=inputs, - min_block_size=1, - ) - trt_output = trt_mod(*rand_inputs) - - torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) - - def test_conv_model(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) - - def forward(self, x): - return self.conv(x) - - inputs = [torch_tensorrt.Input(shape=(1, 3, 32, 32), dtype=torch.float32)] - rand_inputs = [torch.randn((1, 3, 32, 32), dtype=torch.float32).cuda()] - mod = TestModule().cuda().eval() - ref_output = mod(*rand_inputs) - - trt_mod = torch_tensorrt.compile( - mod, - ir="fx_ts_compat", - inputs=inputs, - min_block_size=1, - ) - trt_output = trt_mod(*rand_inputs) - - torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) - - def test_conv_model_with_dyn_shapes(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 6, 1, 1, 1, 1, 1, True) - - def forward(self, x): - return self.conv(x) - - inputs = [ - torch_tensorrt.Input( - min_shape=(1, 3, 32, 32), - opt_shape=(8, 3, 32, 32), - max_shape=(16, 3, 32, 32), - dtype=torch.float32, - ) - ] - rand_inputs = [torch.randn((4, 3, 32, 32), dtype=torch.float32).cuda()] - mod = TestModule().cuda().eval() - ref_output = mod(*rand_inputs) - - trt_mod = torch_tensorrt.compile( - mod, - ir="fx_ts_compat", - inputs=inputs, - min_block_size=1, - ) - trt_output = trt_mod(*rand_inputs) - - torch.testing.assert_close(trt_output, ref_output, rtol=1e-04, atol=1e-04) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py deleted file mode 100644 index 0761b964f8..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py +++ /dev/null @@ -1,84 +0,0 @@ -# Owner(s): ["oncall: gpu_enablement"] - -from typing import List, Optional - -import torch -import torch_tensorrt -from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting - - -class TestTRTModule(TestCase): - def _validate_spec( - self, - spec: InputTensorSpec, - tensor: torch.Tensor, - 
dynamic_dims: Optional[List[int]] = None, - ): - expected_shape = list(tensor.shape) - if dynamic_dims: - for dim in dynamic_dims: - expected_shape[dim] = -1 - self.assertSequenceEqual(spec.shape, expected_shape) - self.assertEqual(spec.dtype, tensor.dtype) - self.assertEqual(spec.device, tensor.device) - self.assertTrue(spec.has_batch_dim) - - def test_from_tensor(self): - tensor = torch.randn(1, 2, 3) - spec = InputTensorSpec.from_tensor(tensor) - self._validate_spec(spec, tensor) - - def test_from_tensors(self): - tensors = [torch.randn(1, 2, 3), torch.randn(2, 4)] - specs = InputTensorSpec.from_tensors(tensors) - for spec, tensor in zip(specs, tensors): - self._validate_spec(spec, tensor) - - def test_from_tensors_with_dynamic_batch_size(self): - tensors = [torch.randn(1, 2, 3), torch.randn(1, 4)] - batch_size_range = [2, 3, 4] - specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( - tensors, batch_size_range - ) - for spec, tensor in zip(specs, tensors): - self._validate_spec(spec, tensor, dynamic_dims=[0]) - - for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): - self.assertEqual(batch_size, shape[0]) - self.assertSequenceEqual(tensor.shape[1:], shape[1:]) - - def test_from_tensors_with_dynamic_batch_size_different_batch_dims(self): - tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] - batch_size_range = [2, 3, 4] - specs = InputTensorSpec.from_tensors_with_dynamic_batch_size( - tensors, batch_size_range, batch_dims=[0, 1] - ) - for i, spec_and_tensor in enumerate(zip(specs, tensors)): - spec, tensor = spec_and_tensor - self._validate_spec(spec, tensor, dynamic_dims=[i]) - - for batch_size, shape in zip(batch_size_range, spec.shape_ranges[0]): - self.assertEqual(batch_size, shape[i]) - tensor_shape = list(tensor.shape) - tensor_shape[i] = batch_size - self.assertSequenceEqual(tensor_shape, shape) - - def test_from_static_input(self): - tensors = [torch.randn(1, 2, 3), torch.randn(2, 1, 4)] - inputs = torch_tensorrt.Input.from_tensors(tensors) - specs = [InputTensorSpec.from_input(input) for input in inputs] - for spec, tensor in zip(specs, tensors): - self._validate_spec(spec, tensor) - - def test_from_dynamic_input(self): - inputs = torch_tensorrt.Input( - min_shape=(2, 2, 3), opt_shape=(4, 2, 3), max_shape=(8, 2, 3) - ) - example_tensor = inputs.example_tensor(optimization_profile_field="opt_shape") - spec = InputTensorSpec.from_input(inputs) - self._validate_spec(spec, example_tensor, dynamic_dims=[0]) - - -if __name__ == "__main__": - run_tests() diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py deleted file mode 100644 index 6423aa65ea..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .trt_minimizer import * # noqa: F401 F403 diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py deleted file mode 100644 index 334243fef4..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py +++ /dev/null @@ -1,446 +0,0 @@ -import logging -import time -import unittest -from typing import Callable, List, Optional, Set, Tuple - -import torch -import torch.fx - -import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer -import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer -from torch_tensorrt.fx import TRTModule -from torch.fx.experimental.normalize import NormalizeArgs -from torch.fx.passes import 
shape_prop -from torch.fx.passes.infra.pass_base import PassResult -from torch.testing._internal.common_utils import TestCase -from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter -from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( - compose_bmm, - compose_chunk, - compose_getitem_slice, - remove_ops, - replace_aten_op_with_indices, - replace_aten_reshape_alias_with_replace, - replace_builtin_ops, - replace_native_layernorm_with_layernorm, - replace_transpose_mm_op_with_linear, - run_const_fold, -) -from torch_tensorrt.dynamo.fx_ts_compat.passes.pass_utils import chain_passes -from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def fetch_attr(mod, target): - """ - Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. - - Args: - target (str): The fully-qualfiied name of the attribute to fetch - - Return: - Any: The value of the attribute. - """ - target_atoms = target.split(".") - attr_itr = mod - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError( - f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" - ) - attr_itr = getattr(attr_itr, atom) - return attr_itr - - -@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") -class TRTTestCase(TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(3) - - def run_test( - self, - mod, - inputs, - expected_ops, - unexpected_ops, - interpreter, - rtol, - atol, - precision=LowerPrecision.FP32, - ): - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - if len(expected_ops): - self.assert_has_op(mod, expected_ops) - if unexpected_ops: - self.assert_unexpected_op(mod, unexpected_ops) - start = time.perf_counter() - interpreter_result = interpreter.run(lower_precision=precision) - sec = time.perf_counter() - start - _LOGGER.info(f"Interpreter run time(s): {sec}") - trt_mod = TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - - ref_outputs = mod(*inputs) - - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - outputs = trt_mod(*cuda_inputs) - end_event.record() - torch.cuda.synchronize() - _LOGGER.info( - f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" - ) - - if type(outputs) not in (list, tuple): - outputs = [outputs] - if type(ref_outputs) not in ( - list, - tuple, - torch.return_types.max, - torch.return_types.min, - ): - ref_outputs = [ref_outputs] - for out, ref in zip(outputs, ref_outputs): - if not isinstance(ref, torch.Tensor): - ref = torch.tensor([ref]) - ref = ref.cpu() # to_dtype test has cases with gpu output - if ref.dtype == torch.int64: - ref = ref.int() # convert torch.max's index output tensor to int32 - torch.testing.assert_close( - out.cpu(), ref, rtol=rtol, atol=atol, equal_nan=True - ) - - def run_test_custom_compare_results( - self, - mod, - inputs, - expected_ops, - interpreter, - comparators: List[Tuple[Callable, List]], - fp16_mode=False, - ): - """ - Runs the test and compares the result using the provided comparators. - The size of comparators must be equal to the number of outputs from 'mod'. - - mod - a model to run. - inputs - a list of the model inputs. - expected ops - a list of ops that should be verified. 
- interpreter - used for converting the model to TRT. - comparators - a list of (func, args) pairs corresponding to each of - the module outputs. usage: func(x, y, *args) - - """ - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - if len(expected_ops): - self.assert_has_op(mod, expected_ops) - - interpreter_result = interpreter.run( - lower_precision=LowerPrecision.FP16 - if fp16_mode - else LowerPrecision.FP32 - ) - trt_mod = TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - res_trt = trt_mod(*cuda_inputs).cpu() - res_cpu = mod(*inputs) - assert len(res_trt) == len(res_cpu) - assert len(res_cpu) == len(comparators) - for output_trt, output_cpu, comparator in zip( - res_trt, res_cpu, comparators - ): - comp_func = comparator[0] - args = comparator[1] - self.assertTrue(comp_func(output_trt, output_cpu, *args)) - - def run_test_with_error(self, mod, inputs, interpreter, expect_error): - with self.assertRaises(expect_error): - with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - - mod.eval() - interpreter.run(lower_precision=LowerPrecision.FP32) - - def assert_has_op(self, mod, ops): - ops_in_mod = set() - - for node in mod.graph.nodes: - if node.op == "call_module": - ops_in_mod.add(type(fetch_attr(mod, node.target))) - elif node.op in {"call_function", "call_method"}: - ops_in_mod.add(node.target) - - self.assertTrue( - ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" - ) - - def assert_unexpected_op(self, mod, ops): - for node in mod.graph.nodes: - if node.op == "call_module": - if type(fetch_attr(mod, node.target)) in ops: - return False - elif node.op in {"call_function", "call_method"}: - if node.target in ops: - return False - return True - - -class VanillaTestCase(TRTTestCase): - def run_test(self, mod, inputs, expected_ops, rtol=1e-03, atol=1e-03): - mod = torch.fx.symbolic_trace(mod) - shape_prop.ShapeProp(mod).propagate(*inputs) - mod = NormalizeArgs(mod).transform() - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test(mod, inputs, expected_ops, None, interp, rtol, atol) - - def run_test_custom_compare_results( - self, - mod, - inputs, - expected_ops, - interpreter, - comparators: List[Tuple[Callable, List]], - fp16_mode=False, - ): - # interpreter is ignored, we do not need this for Vanilla tests - # Note this is different from internal version, we need to fix the test case - # after we refactor the internal callsites to use this file - mod = torch.fx.symbolic_trace(mod) - shape_prop.ShapeProp(mod).propagate(*inputs) - mod = NormalizeArgs(mod).transform() - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test_custom_compare_results( - mod, inputs, expected_ops, interp, comparators, fp16_mode=fp16_mode - ) - - -class AccTestCase(TRTTestCase): - def run_test( - self, - mod, - inputs, - expected_ops, - unexpected_ops=None, - apply_passes=None, - test_explicit_batch_dim=True, - test_implicit_batch_dim=False, - test_explicit_precision=False, - rtol=1e-03, - atol=1e-03, - precision=LowerPrecision.FP32, - ): - mod.eval() - mod = acc_tracer.trace(mod, inputs) - - if apply_passes is not None: - pass_tracer = chain_passes(*apply_passes) - mod = pass_tracer(mod, inputs) - - if test_implicit_batch_dim: - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, 
atol, precision - ) - - if test_explicit_batch_dim: - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision - ) - - if test_explicit_precision: - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - explicit_precision=test_explicit_precision, - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol - ) - - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - explicit_batch_dimension=True, - explicit_precision=test_explicit_precision, - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision - ) - - def run_test_with_assert_error( - self, - mod, - inputs, - expect_error, - test_explicit_batch_dim=True, - test_implicit_batch_dim=True, - ): - mod.eval() - mod = acc_tracer.trace(mod, inputs) - - if test_implicit_batch_dim: - interp = TRTInterpreter(mod, InputTensorSpec.from_tensors(inputs)) - super().run_test_with_error(mod, inputs, interp, expect_error) - - if test_explicit_batch_dim: - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True - ) - super().run_test_with_error(mod, inputs, interp, expect_error) - - def run_test_with_dynamic_shape( - self, - mod, - input_specs, - expected_ops, - unexpected_ops=None, - rtol=1e-03, - atol=1e-03, - ): - mod.eval() - inputs = InputTensorSpec.create_inputs_from_specs(input_specs) - mod = acc_tracer.trace(mod, inputs) - interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) - super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) - - -class DispatchTestCase(TRTTestCase): - def generate_graph( - self, - mod: torch.nn.Module, - original_inputs: List[torch.Tensor], - expected_ops: Set[Callable], - unexpected_ops: Optional[Set[Callable]] = None, - customized_passes: List[Callable] = None, - ): - # Torchdynamo+aot proxytensor tracer - # Below are common passes - passes_list = [ - compose_bmm, - compose_chunk, - compose_getitem_slice, - replace_aten_reshape_alias_with_replace, - replace_aten_op_with_indices, - replace_transpose_mm_op_with_linear, # after compose_bmm - replace_native_layernorm_with_layernorm, - remove_ops, - replace_builtin_ops, # after replace_native_layernorm_with_layernorm - ] - # Combine with customized passes specific to any model - if customized_passes: - passes_list.extend(customized_passes) - fx_module, _ = aten_tracer.trace(mod, original_inputs) - for passes in passes_list: - pr: PassResult = passes(fx_module) - fx_module = pr.graph_module - fx_module(*original_inputs) - - fx_module = run_const_fold(fx_module) - _LOGGER.info(f"FX graph= {fx_module.graph}") - - if len(expected_ops): - self.assert_has_op(fx_module, expected_ops) - if unexpected_ops: - self.assert_unexpected_op(fx_module, unexpected_ops) - - return fx_module - - def run_test( - self, - mod, - inputs, - expected_ops, - unexpected_ops=None, - apply_passes=None, - test_explicit_batch_dim=True, - test_explicit_precision=False, - rtol=1e-03, - atol=1e-03, - precision=LowerPrecision.FP32, - ): - mod.eval() - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) - - if apply_passes is not None: - pass_tracer = chain_passes(*apply_passes) - mod = pass_tracer(mod, inputs) - - if test_explicit_batch_dim: - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - explicit_batch_dimension=True, - ) - 
super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision - ) - - if test_explicit_precision: - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - explicit_precision=test_explicit_precision, - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol - ) - - interp = TRTInterpreter( - mod, - InputTensorSpec.from_tensors(inputs), - explicit_batch_dimension=True, - explicit_precision=test_explicit_precision, - ) - super().run_test( - mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision - ) - - def run_test_with_dynamic_shape( - self, - mod, - input_specs, - expected_ops, - unexpected_ops=None, - rtol=1e-03, - atol=1e-03, - ): - mod.eval() - inputs = InputTensorSpec.create_inputs_from_specs(input_specs) - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) - - interp = TRTInterpreter( - mod, - input_specs, - explicit_batch_dimension=True, - ) - # Since the lowering is based on optimal shape. We need to test with - # different shape(for ex. max shape) for testing dynamic shape - inputs_max = InputTensorSpec.create_inputs_from_max_specs(input_specs) - super().run_test( - mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol - ) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py deleted file mode 100644 index bfb1964de9..0000000000 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py +++ /dev/null @@ -1,103 +0,0 @@ -import logging -from typing import Any, Callable, Tuple - -import torch -import torch.fx.passes.net_min_base as net_min_base -from torch.fx.passes.tools_common import Tensors - -from .. import InputTensorSpec, TRTInterpreter - -from torch_tensorrt.fx import TRTModule - -_LOGGER: logging.Logger = logging.getLogger(__name__) - - -def lower_mod_default( - mod: torch.fx.GraphModule, - inputs: Tensors, - use_python_runtime: bool = False, -) -> TRTModule: - interp = TRTInterpreter( - mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True - ) - interpreter_result = interp.run() - if use_python_runtime: - res_mod = TRTModule( - interpreter_result.engine, - interpreter_result.input_names, - interpreter_result.output_names, - ) - - else: - import io - - from torch_tensorrt._Device import Device - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule - - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) - engine_str = engine_bytes.getvalue() - - res_mod = TorchTensorRTModule( - engine_str, - name=str(type(mod)), - input_binding_names=interpreter_result.input_names, - output_binding_names=interpreter_result.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do - ) - - return res_mod - - -class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase): - def __init__( - self, explicit_batch_dimension: Any = True, use_experimental_rt: bool = False - ): - if use_experimental_rt and not explicit_batch_dimension: - raise ValueError( - "The experimental unifed runtime only supports explicit batch. 
Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True" - ) - - self.explicit_batch_dimension = explicit_batch_dimension - self.use_experimental_rt = use_experimental_rt - super(TensorRTMinizerSetting, self).__init__() - - -class TensorRTMinimizer(net_min_base._MinimizerBase): - def __init__( - self, - module: torch.fx.GraphModule, - sample_input: Tensors, - compare_fn: Callable[[Any, Any, Any], Tuple[float, bool]], - settings: TensorRTMinizerSetting = TensorRTMinizerSetting(), - lower_fn: Callable[ - [torch.fx.GraphModule, Tensors, Any, bool], TRTModule - ] = lower_mod_default, - ): - self.lower_fn = lower_fn - self.use_experiemental_rt = settings.use_experimental_rt - super().__init__(module, sample_input, compare_fn, settings) - - def run_a(self, mod, inputs): - mod.eval() - with torch.no_grad(): - return mod(*inputs) - - def run_b(self, mod, inputs): - mod.eval() - try: - mod = self.lower_fn(mod, inputs, self.use_experiemental_rt) - output = mod(*inputs) - except RuntimeError as e: - raise net_min_base.FxNetMinimizerRunFuncError( - f"Encounter an error when processing \n{mod.graph}\n {e}" - ) - else: - return output - - def get_nodes(self, start=None, end=None, enable_print=False): - nodes = self._collect_nodes(start, end) - if enable_print: - _LOGGER.info(f"Nodes fetched from start {start} to end {end} as: {nodes}") - return nodes diff --git a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py similarity index 91% rename from py/torch_tensorrt/dynamo/backend/lowering/__init__.py rename to py/torch_tensorrt/dynamo/lowering/__init__.py index dd55d2c83c..44c37789a9 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -7,3 +7,4 @@ ) from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS from .substitutions import * +from ._fusers import * diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py similarity index 100% rename from py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py rename to py/torch_tensorrt/dynamo/lowering/_decompositions.py diff --git a/py/torch_tensorrt/dynamo/lowering/_fusers.py b/py/torch_tensorrt/dynamo/lowering/_fusers.py new file mode 100644 index 0000000000..c845204a65 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/_fusers.py @@ -0,0 +1,72 @@ +import torch +from torch_tensorrt.fx.tracer.acc_tracer import acc_ops + + +def check_permute(node: torch.fx.Node): + ranks = len(node.meta["tensor_meta"].shape) + permutation = list(i % ranks for i in node.kwargs["permutation"]) # type: ignore[union-attr] + allowed_permutation = list(i for i in range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def fuse_permute_matmul(gm: torch.fx.GraphModule): + """ + Fuse pattern like permute + matmul if permute is transposing the last two dimension. 
+ """ + for node in gm.graph.nodes: + if node.target == acc_ops.matmul: + lhs, rhs = node.kwargs["input"], node.kwargs["other"] + lhs_transposed = rhs_tranposed = False + skip = False + + if lhs.target == acc_ops.permute and check_permute(lhs): + lhs_transposed = True + lhs = lhs.kwargs["input"] + + if rhs.target == acc_ops.permute and check_permute(rhs): + rhs_tranposed = True + rhs = rhs.kwargs["input"] + + if (not skip) and (lhs_transposed or rhs_tranposed): + with gm.graph.inserting_before(node): + fused_node = gm.graph.call_function( + trt_transposed_matmul, + args=(lhs, rhs, lhs_transposed, rhs_tranposed), + ) + node.replace_all_uses_with(fused_node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + return gm + + +def trt_transposed_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +): + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def fuse_permute_linear(gm: torch.fx.GraphModule): + """ + Fuse pattern like permute + linear if permute is transposing the last two dimension. + """ + for node in gm.graph.nodes: + if node.target == acc_ops.linear: + inp = node.kwargs["input"] + if inp.target == acc_ops.permute and check_permute(inp): + inp = inp.kwargs["input"] + weight = node.kwargs["weight"] + bias = node.kwargs["bias"] + with gm.graph.inserting_before(node): + fused_node = gm.graph.call_function( + trt_transposed_linear, args=(inp, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + return gm diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/lowering/_partition.py similarity index 98% rename from py/torch_tensorrt/dynamo/backend/lowering/_partition.py rename to py/torch_tensorrt/dynamo/lowering/_partition.py index 4d82bf4be5..c239dbc5b3 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/lowering/_partition.py @@ -3,8 +3,8 @@ import torch -from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE -from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo.lowering import SUBSTITUTION_REGISTRY +from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.graph_module import GraphModule from torch.fx.node import _get_qualified_name diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py similarity index 100% rename from py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py rename to py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py b/py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py similarity index 100% rename from py/torch_tensorrt/dynamo/backend/lowering/substitutions/__init__.py rename to py/torch_tensorrt/dynamo/lowering/substitutions/__init__.py diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py similarity index 96% rename from py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py rename to py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py index 57c4a93e62..c4a29b507e 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/einsum.py @@ -7,7 
+7,7 @@ from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.backend.lowering import register_substitution +from torch_tensorrt.dynamo.lowering import register_substitution @custom_op( diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py similarity index 98% rename from py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py rename to py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py index 020d3a0ca9..b3265e7875 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/lowering/substitutions/maxpool1d.py @@ -7,7 +7,7 @@ from torch_tensorrt.fx.converters import acc_ops_converters from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -from torch_tensorrt.dynamo.backend.lowering import register_substitution +from torch_tensorrt.dynamo.lowering import register_substitution # This file serves as an example and a tutorial for excluding custom modules from diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py new file mode 100644 index 0000000000..79070cea31 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTRTModule.py @@ -0,0 +1,256 @@ +from typing import Any, List, Sequence + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks + + +class PythonTorchTRTModule(torch.nn.Module): + """PythonTorchTRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. + + This module is backed by the Torch-TensorRT runtime and is only compatibile with + FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment. + """ + + def __init__( + self, engine=None, input_names=None, output_names=None, cuda_graph_batch_size=-1 + ): + super(PythonTorchTRTModule, self).__init__() + self._register_state_dict_hook(PythonTorchTRTModule._on_state_dict) + self.engine = engine + self.input_names = input_names + self.output_names = output_names + self.cuda_graph_batch_size = cuda_graph_batch_size + self.initialized = False + + if engine: + self._initialize() + + def _initialize(self): + self.initialized = True + self.context = self.engine.create_execution_context() + + # Indices of inputs/outputs in the trt engine bindings, in the order + # as they are in the original PyTorch model. 
+ self.input_binding_indices_in_order: Sequence[int] = [ + self.engine.get_binding_index(name) for name in self.input_names + ] + self.output_binding_indices_in_order: Sequence[int] = [ + self.engine.get_binding_index(name) for name in self.output_names + ] + primary_input_outputs = set() + primary_input_outputs.update(self.input_binding_indices_in_order) + primary_input_outputs.update(self.output_binding_indices_in_order) + self.hidden_output_binding_indices_in_order: Sequence[int] = [] + self.hidden_output_names: Sequence[str] = [] + for i in range( + self.engine.num_bindings // self.engine.num_optimization_profiles + ): + if i not in primary_input_outputs: + self.hidden_output_binding_indices_in_order.append(i) + self.hidden_output_names.append(self.engine.get_binding_name(i)) + + assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == ( + len(self.input_names) + + len(self.output_names) + + len(self.hidden_output_names) + ) + + self.input_dtypes: Sequence[torch.dtype] = [ + unified_dtype_converter( + self.engine.get_binding_dtype(idx), Frameworks.TORCH + ) + for idx in self.input_binding_indices_in_order + ] + self.input_shapes: Sequence[Sequence[int]] = [ + tuple(self.engine.get_binding_shape(idx)) + for idx in self.input_binding_indices_in_order + ] + self.output_dtypes: Sequence[torch.dtype] = [ + unified_dtype_converter( + self.engine.get_binding_dtype(idx), Frameworks.TORCH + ) + for idx in self.output_binding_indices_in_order + ] + self.output_shapes = [ + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + for idx in self.output_binding_indices_in_order + ] + self.hidden_output_dtypes: Sequence[torch.dtype] = [ + unified_dtype_converter( + self.engine.get_binding_dtype(idx), Frameworks.TORCH + ) + for idx in self.hidden_output_binding_indices_in_order + ] + self.hidden_output_shapes = [ + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + for idx in self.hidden_output_binding_indices_in_order + ] + + def _check_initialized(self): + if not self.initialized: + raise RuntimeError("PythonTorchTRTModule is not initialized.") + + def _on_state_dict(self, state_dict, prefix, local_metadata): + self._check_initialized() + state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) + state_dict[prefix + "input_names"] = self.input_names + state_dict[prefix + "output_names"] = self.output_names + state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + engine_bytes = state_dict[prefix + "engine"] + + logger = trt.Logger() + runtime = trt.Runtime(logger) + self.engine = runtime.deserialize_cuda_engine(engine_bytes) + + self.input_names = state_dict[prefix + "input_names"] + self.output_names = state_dict[prefix + "output_names"] + self._initialize() + + def __getstate__(self): + state = self.__dict__.copy() + state["engine"] = bytearray(self.engine.serialize()) + state.pop("context", None) + return state + + def __setstate__(self, state): + logger = trt.Logger() + runtime = trt.Runtime(logger) + state["engine"] = runtime.deserialize_cuda_engine(state["engine"]) + self.__dict__.update(state) + if self.engine: + self.context = self.engine.create_execution_context() + + def forward(self, *inputs): + with torch.autograd.profiler.record_function("PythonTorchTRTModule:Forward"): + 
self._check_initialized() + + with torch.autograd.profiler.record_function( + "PythonTorchTRTModule:ProcessInputs" + ): + assert len(inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." + + # This is only used when the trt engine is using implicit batch dim. + batch_size = inputs[0].shape[0] + contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs] + bindings: List[Any] = [None] * ( + len(self.input_names) + + len(self.output_names) + + len(self.hidden_output_names) + ) + + for i, input_name in enumerate(self.input_names): + assert inputs[ + i + ].is_cuda, f"{i}th input({input_name}) is not on cuda device." + assert ( + inputs[i].dtype == self.input_dtypes[i] + ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}." + + idx = self.input_binding_indices_in_order[i] + bindings[idx] = contiguous_inputs[i].data_ptr() + + if not self.engine.has_implicit_batch_dimension: + self.context.set_binding_shape( + idx, tuple(contiguous_inputs[i].shape) + ) + else: + assert inputs[i].size()[1:] == self.input_shapes[i], ( + f"Shape mismatch for {i}th input({input_name}). " + f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}." + ) + + with torch.autograd.profiler.record_function( + "PythonTorchTRTModule:ProcessOutputs" + ): + # create output tensors + outputs: List[torch.Tensor] = [] + + for i, idx in enumerate(self.output_binding_indices_in_order): + if self.engine.has_implicit_batch_dimension: + shape = (batch_size,) + self.output_shapes[i] + else: + shape = tuple(self.context.get_binding_shape(idx)) + + output = torch.empty( # type: ignore[call-overload] + size=shape, + dtype=self.output_dtypes[i], + device=torch.cuda.current_device(), + ) + outputs.append(output) + bindings[idx] = output.data_ptr() + + for i, idx in enumerate(self.hidden_output_binding_indices_in_order): + if self.engine.has_implicit_batch_dimension: + shape = (batch_size,) + self.hidden_output_shapes[i] + else: + shape = tuple(self.context.get_binding_shape(idx)) + + output = torch.empty( # type: ignore[call-overload] + size=shape, + dtype=self.hidden_output_dtypes[i], + device=torch.cuda.current_device(), + ) + bindings[idx] = output.data_ptr() + + with torch.autograd.profiler.record_function( + "PythonTorchTRTModule:TensorRTRuntime" + ): + if self.engine.has_implicit_batch_dimension: + self.context.execute_async( + batch_size, bindings, torch.cuda.current_stream().cuda_stream + ) + else: + self.context.execute_async_v2( + bindings, torch.cuda.current_stream().cuda_stream + ) + + if len(outputs) == 1: + return outputs[0] + + return tuple(outputs) + + def enable_profiling(self, profiler: "trt.IProfiler" = None): + """ + Enable TensorRT profiling. After calling this function, TensorRT will report + time spent on each layer in stdout for each forward run. + """ + self._check_initialized() + + if not self.context.profiler: + self.context.profiler = trt.Profiler() if profiler is None else profiler + + def disable_profiling(self): + """ + Disable TensorRT profiling. + """ + self._check_initialized() + + torch.cuda.synchronize() + del self.context + self.context = self.engine.create_execution_context() + + def get_layer_info(self) -> str: + """ + Get layer info of the engine. Only support for TRT > 8.2. 
+ """ + inspector = self.engine.create_engine_inspector() + return inspector.get_engine_information(trt.LayerInformationFormat.JSON) diff --git a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py similarity index 95% rename from py/torch_tensorrt/dynamo/_TorchTensorRTModule.py rename to py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 8359bc62fb..c1fd86af8a 100644 --- a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -13,7 +13,7 @@ class TorchTensorRTModule(torch.nn.Module): This module is backed by the Torch-TensorRT runtime and is fully compatibile with both FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as - well as TorchScript / C++ deployments since TRTModule can be passed to ``torch.jit.trace`` + well as TorchScript / C++ deployments since TorchTensorRTModule can be passed to ``torch.jit.trace`` and then saved. The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where @@ -36,7 +36,7 @@ def __init__( output_binding_names: List[str] = [], target_device: Device = Device._current_device(), ): - """__init__ method for torch_tensorrt.TorchTensorRTModule + """__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. @@ -61,11 +61,11 @@ def __init__( engine_bytes.write(trt_engine.serialize()) engine_str = engine_bytes.getvalue() - trt_module = TRTModule( + trt_module = TorchTensorRTModule( engine_str, - engine_name="my_module", - input_names=["x"], - output_names=["output"], + name="my_module", + input_binding_names=["x"], + output_binding_names=["output"], ) """ diff --git a/py/torch_tensorrt/dynamo/runtime/__init__.py b/py/torch_tensorrt/dynamo/runtime/__init__.py new file mode 100644 index 0000000000..a4586eae0a --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/__init__.py @@ -0,0 +1,2 @@ +from ._PythonTorchTRTModule import PythonTorchTRTModule +from ._TorchTensorRTModule import TorchTensorRTModule diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py new file mode 100644 index 0000000000..0688d2e169 --- /dev/null +++ b/py/torch_tensorrt/dynamo/utils.py @@ -0,0 +1,162 @@ +import torch +import logging +from dataclasses import replace, fields +from torch_tensorrt.dynamo import CompilationSettings +from typing import Any, Union, Sequence, Dict +from torch_tensorrt import Input, Device +from typing import Optional + +logger = logging.getLogger(__name__) + +COSINE_THRESHOLD = 0.99 + + +def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool: + """Parses a user-provided input argument regarding Python runtime + + Automatically handles cases where the user has not specified a runtime (None) + + Returns True if the Python runtime should be used, False if the C++ runtime should be used + """ + using_python_runtime = use_python_runtime + reason = "" + + # Runtime was manually specified by the user + if using_python_runtime is not None: + reason = "as requested by user" + # Runtime was not manually specified by the user, automatically detect runtime + else: + try: + from torch_tensorrt.dynamo.runtime import TorchTensorRTModule + + using_python_runtime = False + reason = "since C++ dependency was detected as present" + except ImportError: + using_python_runtime = True + 
reason = "since import failed, C++ dependency not installed" + + logger.info( + f"Using {'Python-only' if using_python_runtime else 'Default'} Torch-TRT Runtime ({reason})" + ) + + return using_python_runtime + + +def cosine_similarity(gt_tensor, pred_tensor): + gt_tensor = gt_tensor.flatten().to(torch.float32) + pred_tensor = pred_tensor.flatten().to(torch.float32) + if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0: + if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True): + return 1.0 + res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6) + res = res.cpu().detach().item() + + return res + + +def prepare_inputs( + inputs: Union[Input, torch.Tensor, Sequence, Dict], + device: torch.device = torch.device("cuda"), +) -> Any: + if isinstance(inputs, Input): + if isinstance(inputs.shape, dict): + return inputs, inputs.example_tensor( + optimization_profile_field="opt_shape" + ).to(device) + else: + return inputs, inputs.example_tensor().to(device) + + elif isinstance(inputs, torch.Tensor): + return Input.from_tensor(inputs), inputs + + elif isinstance(inputs, list): + prepared_input = list() + torchtrt_inputs = [] + torch_inputs = [] + for input_obj in inputs: + torchtrt_input, torch_input = prepare_inputs(input_obj) + torchtrt_inputs.append(torchtrt_input) + torch_inputs.append(torch_input) + + return torchtrt_inputs, torch_inputs + + elif isinstance(inputs, tuple): + torchtrt_inputs = [] + torch_inputs = [] + for input_obj in inputs: + torchtrt_input, torch_input = prepare_inputs(input_obj) + torchtrt_inputs.append(torchtrt_input) + torch_inputs.append(torch_input) + + return tuple(torchtrt_inputs), tuple(torch_inputs) + + elif isinstance(inputs, dict): + torchtrt_inputs = dict() + torch_inputs = dict() + + for key, input_obj in inputs.items(): + torchtrt_input, torch_input = prepare_inputs(input_obj) + torchtrt_inputs[key] = torchtrt_input + torch_inputs[key] = torch_input + + return torchtrt_inputs, torch_inputs + + else: + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) + + +def prepare_device(device: Union[Device, torch.device]) -> torch.device: + if isinstance(device, Device): + if device.gpu_id != -1: + device = torch.device(device.gpu_id) + else: + raise ValueError("Invalid GPU ID provided for the CUDA device provided") + + elif isinstance(device, torch.device): + device = device + + else: + raise ValueError( + "Invalid device provided. 
Supported options: torch.device | torch_tensorrt.Device" + ) + + return device + + +def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: + """Parses the kwargs field of a Dynamo backend + + Args: + kwargs: Keyword arguments dictionary provided to the backend + Returns: + CompilationSettings object with relevant kwargs + """ + + # Initialize an empty CompilationSettings object + settings = CompilationSettings() + + # If the user specifies keyword args, overwrite those fields in settings + # Validate all specified kwargs to ensure they are true fields of the dataclass + # + # Note: kwargs provided by torch.compile are wrapped in the "options" key + if kwargs: + if "options" in kwargs and len(kwargs) == 1: + kwargs = kwargs["options"] + + valid_attrs = {attr.name for attr in fields(settings)} + valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs} + settings = replace(settings, **valid_kwargs) + + # Enable debug/verbose mode if requested + if settings.debug: + logger.setLevel(logging.DEBUG) + + # Parse input runtime specification + settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) + + logger.debug(f"Compiling with Settings:\n{settings}") + + return settings diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD index 1973c112fd..477774248d 100644 --- a/tests/core/conversion/converters/BUILD +++ b/tests/core/conversion/converters/BUILD @@ -224,33 +224,33 @@ test_suite( ":test_div", ":test_einsum", ":test_expand", + ":test_index", ":test_instance_norm", ":test_interpolate", - ":test_index", ":test_layer_norm", ":test_linear", ":test_lstm_cell", - ":test_matrix_multiply", ":test_masked_fill", + ":test_matrix_multiply", ":test_max", ":test_normalize", ":test_pooling", ":test_reduce", - ":test_roll", ":test_replication_pad", + ":test_roll", ":test_scatter", ":test_select", ":test_shuffle", + ":test_slice", ":test_softmax", + ":test_split", ":test_squeeze", ":test_stack", - ":test_split", - ":test_slice", ":test_topk", ":test_unary", - ":test_unsqueeze", ":test_unbind", ":test_unpack", + ":test_unsqueeze", ":test_where", ], ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py similarity index 93% rename from py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py rename to tests/py/dynamo/backend/test_backend_compiler.py index 2af251adbc..4bc1c8c18c 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,10 +1,9 @@ -from torch_tensorrt.dynamo.backend.lowering import partition -from torch.testing._internal.common_utils import run_tests, TestCase import torch +import torch_tensorrt +from torch_tensorrt.dynamo.lowering import partition +from torch.testing._internal.common_utils import run_tests, TestCase from copy import deepcopy -from torch_tensorrt.dynamo import compile -from utils import lower_graph_testing -from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT +from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT class TestTRTModuleNextCompilation(TestCase): @@ -34,8 +33,9 @@ def forward(self, x, y): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( + optimized_model = torch_tensorrt.compile( fx_graph, + "torch_compile", inputs, min_block_size=1, pass_through_build_failures=True, @@ -102,8 +102,9 @@ def forward(self, x, y): 
torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( + optimized_model = torch_tensorrt.compile( fx_graph, + "torch_compile", inputs, min_block_size=1, pass_through_build_failures=True, @@ -144,8 +145,9 @@ def forward(self, x, y): ] # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( + optimized_model = torch_tensorrt.compile( fx_graph, + "torch_compile", inputs, min_block_size=1, pass_through_build_failures=True, diff --git a/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py b/tests/py/dynamo/backend/test_compiler_utils.py similarity index 61% rename from py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py rename to tests/py/dynamo/backend/test_compiler_utils.py index 947a277ddd..3ef81b4e1a 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py +++ b/tests/py/dynamo/backend/test_compiler_utils.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs +from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs from utils import same_output_format import torch_tensorrt import unittest @@ -24,16 +24,22 @@ def test_prepare_trt_device(self): class TestPrepareInputs(unittest.TestCase): def test_prepare_single_tensor_input(self): inputs = [torch.ones((4, 4))] - prepared_inputs = prepare_inputs(inputs) + prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) self.assertTrue( - same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) + ) + self.assertTrue( + same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) ) def test_prepare_trt_input(self): inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)] - prepared_inputs = prepare_inputs(inputs) + prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) + self.assertTrue( + same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) + ) self.assertTrue( - same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) ) def test_prepare_mixed_type_compound_tensor_input(self): @@ -47,9 +53,12 @@ def test_prepare_mixed_type_compound_tensor_input(self): (torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))), ), } - prepared_inputs = prepare_inputs(inputs) + prepared_inputs_trt, prepared_inputs_torch = prepare_inputs(inputs) + self.assertTrue( + same_output_format(inputs, prepared_inputs_trt, enforce_tensor_type=False) + ) self.assertTrue( - same_output_format(inputs, prepared_inputs, enforce_tensor_type=False) + same_output_format(inputs, prepared_inputs_torch, enforce_tensor_type=False) ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py b/tests/py/dynamo/backend/test_decompositions.py similarity index 95% rename from py/torch_tensorrt/dynamo/backend/test/test_decompositions.py rename to tests/py/dynamo/backend/test_decompositions.py index d947c955e0..a9578e4ed8 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py +++ b/tests/py/dynamo/backend/test_decompositions.py @@ -1,9 +1,8 @@ from functools import partial -from utils import lower_graph_testing +from utils import lower_graph_testing, DECIMALS_OF_AGREEMENT from torch.testing._internal.common_utils import run_tests, TestCase import torch -from torch_tensorrt.dynamo import compile -from 
torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT +import torch_tensorrt class TestLowering(TestCase): @@ -163,8 +162,12 @@ def forward(self, x, y, z): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( - fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py b/tests/py/dynamo/backend/test_partitioning.py similarity index 98% rename from py/torch_tensorrt/dynamo/backend/test/test_partitioning.py rename to tests/py/dynamo/backend/test_partitioning.py index fb5430b384..76645075b9 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py +++ b/tests/py/dynamo/backend/test_partitioning.py @@ -1,4 +1,4 @@ -from torch_tensorrt.dynamo.backend.lowering import partition +from torch_tensorrt.dynamo.lowering import partition from torch.testing._internal.common_utils import run_tests, TestCase from utils import lower_graph_testing import torch diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/tests/py/dynamo/backend/test_pre_aot_lowering.py similarity index 87% rename from py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py rename to tests/py/dynamo/backend/test_pre_aot_lowering.py index da44d6e826..3e340b5168 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py +++ b/tests/py/dynamo/backend/test_pre_aot_lowering.py @@ -1,7 +1,7 @@ import torch +import torch_tensorrt from utils import lower_graph_testing from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.dynamo import compile class TestMaxPool1D(TestCase): @@ -39,8 +39,12 @@ def forward(self, x): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( - fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -85,8 +89,12 @@ def forward(self, x, y): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( - fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py similarity index 87% rename from py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py rename to tests/py/dynamo/backend/test_specialized_models.py index 17df523ab8..ed9fc35a59 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -1,7 +1,7 @@ from utils import lower_graph_testing from torch.testing._internal.common_utils import run_tests, TestCase import 
torch -from torch_tensorrt.dynamo import compile +import torch_tensorrt class TestFakeTensors(TestCase): @@ -40,8 +40,12 @@ def forward(self, x): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( - fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() @@ -92,8 +96,12 @@ def forward(self, x): torch._dynamo.reset() # Validate that the results between Torch and Torch-TRT are similar - optimized_model = compile( - fx_graph, inputs, min_block_size=1, pass_through_build_failures=True + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() torch_model_results = fx_graph(*inputs).detach().cpu() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/tests/py/dynamo/backend/utils.py similarity index 96% rename from py/torch_tensorrt/dynamo/backend/test/utils.py rename to tests/py/dynamo/backend/utils.py index 7c679b7d4d..0eaba4aeea 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/tests/py/dynamo/backend/utils.py @@ -2,13 +2,13 @@ from functools import partial from typing import Any, List, Sequence, Set import torch -from torch_tensorrt.dynamo.backend.lowering._decompositions import ( +from torch_tensorrt.dynamo.lowering._decompositions import ( get_decompositions, ) -from torch_tensorrt.dynamo.backend.lowering._partition import ( +from torch_tensorrt.dynamo.lowering._partition import ( partition, ) -from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import ( +from torch_tensorrt.dynamo.lowering._pre_aot_lowering import ( pre_aot_substitutions, ) @@ -16,6 +16,8 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +DECIMALS_OF_AGREEMENT = 4 + @fake_tensor_unsupported def fx_dynamo_testing_backend( diff --git a/py/torch_tensorrt/dynamo/test/conftest.py b/tests/py/dynamo/models/conftest.py similarity index 85% rename from py/torch_tensorrt/dynamo/test/conftest.py rename to tests/py/dynamo/models/conftest.py index 7218d5335b..3fbabc3360 100644 --- a/py/torch_tensorrt/dynamo/test/conftest.py +++ b/tests/py/dynamo/models/conftest.py @@ -9,7 +9,7 @@ def pytest_addoption(parser): type=str, required=True, help="IR to compile with", - choices=["dynamo_compile", "fx_ts_compat"], + choices=["dynamo", "torch_compile"], ) diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/tests/py/dynamo/models/test_models.py similarity index 96% rename from py/torch_tensorrt/dynamo/test/test_dynamo_backend.py rename to tests/py/dynamo/models/test_models.py index f34aad6caf..1141d54a7b 100644 --- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py +++ b/tests/py/dynamo/models/test_models.py @@ -8,7 +8,7 @@ from transformers import BertModel -from torch_tensorrt.dynamo.common_utils.test_utils import ( +from torch_tensorrt.dynamo.utils import ( COSINE_THRESHOLD, cosine_similarity, ) @@ -33,6 +33,7 @@ def test_resnet18(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -66,6 +67,7 @@ def test_mobilenet_v2(ir): 
"pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -99,6 +101,7 @@ def test_efficientnet_b0(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -141,6 +144,7 @@ def test_bert_base_uncased(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -178,6 +182,7 @@ def test_resnet18_half(ir): "pass_through_build_failures": True, "optimization_level": 1, "min_block_size": 8, + "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py new file mode 100644 index 0000000000..2f05c00c20 --- /dev/null +++ b/tests/py/dynamo/models/test_models_export.py @@ -0,0 +1,152 @@ +import torch +import timm +import pytest +import unittest + +import torch_tensorrt as torchtrt +import torchvision.models as models + +from transformers import BertModel + +from torch_tensorrt.dynamo.utils import ( + COSINE_THRESHOLD, + cosine_similarity, +) + +assertions = unittest.TestCase() + + +@pytest.mark.unit +def test_resnet18(ir): + model = models.resnet18(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "ir": "dynamo", + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + + +@pytest.mark.unit +def test_mobilenet_v2(ir): + model = models.mobilenet_v2(pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "ir": "dynamo", + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Mobilenet v2 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + + +@pytest.mark.unit +def test_efficientnet_b0(ir): + model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda") + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "ir": "dynamo", + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() + + +@pytest.mark.unit +def test_resnet18_half(ir): + model = models.resnet18(pretrained=True).eval().to("cuda").half() + input = torch.randn((1, 3, 224, 224)).to("cuda").half() + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.half, format=torch.contiguous_format + ) + ], + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.half}, + "ir": ir, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, + "ir": "dynamo", + } + + trt_mod = torchtrt.compile(model, **compile_spec) + cos_sim = cosine_similarity(model(input), trt_mod(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 Half TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Clean up model env + torch._dynamo.reset() + + with torch.no_grad(): + torch.cuda.empty_cache() diff --git a/tests/py/BUILD b/tests/py/ts/BUILD similarity index 100% rename from tests/py/BUILD rename to tests/py/ts/BUILD diff --git a/tests/py/api/test_classes.py b/tests/py/ts/api/test_classes.py similarity index 99% rename from tests/py/api/test_classes.py rename to tests/py/ts/api/test_classes.py index 3d0cb5c5f9..835257fc58 100644 --- a/tests/py/api/test_classes.py +++ b/tests/py/ts/api/test_classes.py @@ -1,6 +1,6 @@ import unittest import torch_tensorrt as torchtrt -from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule import torch import torchvision.models as models import copy diff --git a/tests/py/api/test_collections.py b/tests/py/ts/api/test_collections.py similarity index 100% rename from tests/py/api/test_collections.py rename to tests/py/ts/api/test_collections.py diff --git a/tests/py/api/test_e2e_behavior.py b/tests/py/ts/api/test_e2e_behavior.py similarity index 100% rename from tests/py/api/test_e2e_behavior.py rename to tests/py/ts/api/test_e2e_behavior.py diff --git a/tests/py/api/test_embed_engines.py b/tests/py/ts/api/test_embed_engines.py similarity index 100% rename from tests/py/api/test_embed_engines.py rename to tests/py/ts/api/test_embed_engines.py diff --git a/tests/py/api/test_logging.py b/tests/py/ts/api/test_logging.py similarity index 100% rename from tests/py/api/test_logging.py rename to tests/py/ts/api/test_logging.py diff --git a/tests/py/api/test_module_fallback.py b/tests/py/ts/api/test_module_fallback.py similarity index 98% rename from tests/py/api/test_module_fallback.py rename to tests/py/ts/api/test_module_fallback.py index 5eda2cdbfc..8ef1059b2f 100644 --- a/tests/py/api/test_module_fallback.py +++ b/tests/py/ts/api/test_module_fallback.py @@ -23,6 +23,7 @@ def test_fallback_resnet18(self): }, "enabled_precisions": {torch.float}, "torch_executed_modules": ["torchvision.models.resnet.BasicBlock"], + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) @@ -49,6 +50,7 @@ def test_fallback_mobilenet_v2(self): "torchvision.models.mobilenetv2.ConvBNActivation" ], "min_block_size": 5, + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) diff --git a/tests/py/api/test_operator_fallback.py b/tests/py/ts/api/test_operator_fallback.py similarity index 97% rename from tests/py/api/test_operator_fallback.py rename to tests/py/ts/api/test_operator_fallback.py index 19ba891514..3e4777869f 100644 --- a/tests/py/api/test_operator_fallback.py +++ b/tests/py/ts/api/test_operator_fallback.py @@ -23,6 +23,7 @@ def test_fallback_resnet18(self): }, "enabled_precisions": {torch.float}, "torch_executed_ops": ["aten::add"], + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) @@ -49,6 +50,7 @@ def test_fallback_resnet18_with_tensor_domain(self): }, "enabled_precisions": {torch.float}, "torch_executed_ops": ["aten::add"], + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) @@ -72,6 +74,7 @@ def test_fallback_mobilenet_v2(self): }, 
"enabled_precisions": {torch.float}, "torch_executed_ops": ["aten::hardtanh"], + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) diff --git a/tests/py/api/test_ts_backend.py b/tests/py/ts/api/test_ts_backend.py similarity index 100% rename from tests/py/api/test_ts_backend.py rename to tests/py/ts/api/test_ts_backend.py diff --git a/tests/py/api/utils.py b/tests/py/ts/api/utils.py similarity index 100% rename from tests/py/api/utils.py rename to tests/py/ts/api/utils.py diff --git a/tests/py/hw/test_api_dla.py b/tests/py/ts/hw/test_api_dla.py similarity index 100% rename from tests/py/hw/test_api_dla.py rename to tests/py/ts/hw/test_api_dla.py diff --git a/tests/py/hw/test_multi_gpu.py b/tests/py/ts/hw/test_multi_gpu.py similarity index 100% rename from tests/py/hw/test_multi_gpu.py rename to tests/py/ts/hw/test_multi_gpu.py diff --git a/tests/py/hw/utils.py b/tests/py/ts/hw/utils.py similarity index 100% rename from tests/py/hw/utils.py rename to tests/py/ts/hw/utils.py diff --git a/tests/py/integrations/test_to_backend_api.py b/tests/py/ts/integrations/test_to_backend_api.py similarity index 100% rename from tests/py/integrations/test_to_backend_api.py rename to tests/py/ts/integrations/test_to_backend_api.py diff --git a/tests/py/integrations/test_trt_intercompatibility.py b/tests/py/ts/integrations/test_trt_intercompatibility.py similarity index 100% rename from tests/py/integrations/test_trt_intercompatibility.py rename to tests/py/ts/integrations/test_trt_intercompatibility.py diff --git a/tests/py/integrations/utils.py b/tests/py/ts/integrations/utils.py similarity index 100% rename from tests/py/integrations/utils.py rename to tests/py/ts/integrations/utils.py diff --git a/tests/py/model_test_case.py b/tests/py/ts/model_test_case.py similarity index 100% rename from tests/py/model_test_case.py rename to tests/py/ts/model_test_case.py diff --git a/tests/py/models/custom_models.py b/tests/py/ts/models/custom_models.py similarity index 100% rename from tests/py/models/custom_models.py rename to tests/py/ts/models/custom_models.py diff --git a/tests/py/models/test_models.py b/tests/py/ts/models/test_models.py similarity index 98% rename from tests/py/models/test_models.py rename to tests/py/ts/models/test_models.py index 037e9f93d1..3e042fc763 100644 --- a/tests/py/models/test_models.py +++ b/tests/py/ts/models/test_models.py @@ -25,6 +25,7 @@ def test_resnet18(self): "gpu_id": 0, }, "enabled_precisions": {torch.float}, + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) @@ -49,6 +50,7 @@ def test_mobilenet_v2(self): "gpu_id": 0, }, "enabled_precisions": {torch.float}, + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) @@ -75,6 +77,7 @@ def test_efficientnet_b0(self): "gpu_id": 0, }, "enabled_precisions": {torch.float}, + "ir": "ts", } trt_mod = torchtrt.compile(self.model, **compile_spec) @@ -137,6 +140,7 @@ def test_resnet18_half(self): "gpu_id": 0, }, "enabled_precisions": {torch.half}, + "ir": "ts", } trt_mod = torchtrt.compile(self.scripted_model, **compile_spec) diff --git a/tests/py/models/test_multiple_registered_engines.py b/tests/py/ts/models/test_multiple_registered_engines.py similarity index 98% rename from tests/py/models/test_multiple_registered_engines.py rename to tests/py/ts/models/test_multiple_registered_engines.py index 98f012597b..e8c1f95433 100644 --- a/tests/py/models/test_multiple_registered_engines.py +++ 
b/tests/py/ts/models/test_multiple_registered_engines.py @@ -27,6 +27,7 @@ def test_multiple_engines(self): "gpu_id": 0, }, "enabled_precisions": {torch.float}, + "ir": "ts", } rn18_trt_mod = torchtrt.compile(self.resnet18, **compile_spec) rn50_trt_mod = torchtrt.compile(self.resnet50, **compile_spec) diff --git a/tests/py/models/utils.py b/tests/py/ts/models/utils.py similarity index 100% rename from tests/py/models/utils.py rename to tests/py/ts/models/utils.py diff --git a/tests/py/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py similarity index 100% rename from tests/py/ptq/test_ptq_dataloader_calibrator.py rename to tests/py/ts/ptq/test_ptq_dataloader_calibrator.py diff --git a/tests/py/ptq/test_ptq_to_backend.py b/tests/py/ts/ptq/test_ptq_to_backend.py similarity index 100% rename from tests/py/ptq/test_ptq_to_backend.py rename to tests/py/ts/ptq/test_ptq_to_backend.py diff --git a/tests/py/ptq/test_ptq_trt_calibrator.py b/tests/py/ts/ptq/test_ptq_trt_calibrator.py similarity index 100% rename from tests/py/ptq/test_ptq_trt_calibrator.py rename to tests/py/ts/ptq/test_ptq_trt_calibrator.py diff --git a/tests/py/qat/test_qat_trt_accuracy.py b/tests/py/ts/qat/test_qat_trt_accuracy.py similarity index 100% rename from tests/py/qat/test_qat_trt_accuracy.py rename to tests/py/ts/qat/test_qat_trt_accuracy.py diff --git a/tests/py/requirements.txt b/tests/py/ts/requirements.txt similarity index 100% rename from tests/py/requirements.txt rename to tests/py/ts/requirements.txt diff --git a/tests/py/utils.py b/tests/py/ts/utils.py similarity index 100% rename from tests/py/utils.py rename to tests/py/ts/utils.py